Remove unneeded config parameter

Author: Fred Reiss, 2025-01-25 19:00:08 -08:00 (committed by Ashwin Bharambe)
Parent: 29ae2552fd
Commit: 59211067d1


@@ -4,20 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field
 
-from llama_stack.providers.utils.inference import supported_inference_models
 from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
 class VLLMConfig(BaseModel):
-    """Configuration for the vLLM inference provider."""
+    """Configuration for the vLLM inference provider.
+
+    Note that the model name is no longer part of this static configuration.
+    You can bind an instance of this provider to a specific model with the
+    ``models.register()`` API call."""
 
-    model: str = Field(
-        default="Llama3.2-3B-Instruct",
-        description="Model descriptor from `llama model list`",
-    )
     tensor_parallel_size: int = Field(
         default=1,
         description="Number of tensor parallel replicas (number of GPUs to use).",
@@ -26,12 +25,8 @@ class VLLMConfig(BaseModel):
         default=4096,
         description="Maximum number of tokens to generate.",
     )
-    max_model_len: int = Field(
-        default=4096, description="Maximum context length to use during serving."
-    )
-    max_num_seqs: int = Field(
-        default=4, description="Maximum parallel batch size for generation"
-    )
+    max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.")
+    max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation")
     enforce_eager: bool = Field(
         default=False,
         description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
@@ -47,7 +42,6 @@ class VLLMConfig(BaseModel):
     @classmethod
     def sample_run_config(cls):
         return {
-            "model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
             "tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
             "max_tokens": "${env.MAX_TOKENS:4096}",
            "max_model_len": "${env.MAX_MODEL_LEN:4096}",
@@ -55,15 +49,3 @@ class VLLMConfig(BaseModel):
             "enforce_eager": "${env.ENFORCE_EAGER:False}",
             "gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
         }
-
-    @field_validator("model")
-    @classmethod
-    def validate_model(cls, model: str) -> str:
-        permitted_models = supported_inference_models()
-
-        descriptors = [m.descriptor() for m in permitted_models]
-        repos = [m.huggingface_repo for m in permitted_models]
-        if model not in (descriptors + repos):
-            model_list = "\n\t".join(repos)
-            raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
-        return model
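
With the ``field_validator`` gone, the model name is no longer checked when the config is parsed. A hedged sketch of the equivalent check, reusing the helpers referenced by the removed code, is below; where (or whether) this check runs after the commit is not shown here.

# Illustrative sketch mirroring the removed validator; the provider may perform
# this check elsewhere (e.g. at model registration), which this commit does not show.
from llama_stack.providers.utils.inference import supported_inference_models

def check_model_descriptor(model: str) -> None:
    permitted = supported_inference_models()
    known = [m.descriptor() for m in permitted] + [m.huggingface_repo for m in permitted]
    if model not in known:
        raise ValueError(f"Unknown model: `{model}`")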