mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
Remove unneeded config parameter
This commit is contained in:
parent
29ae2552fd
commit
59211067d1
1 changed files with 8 additions and 26 deletions
|
@ -4,20 +4,19 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference import supported_inference_models
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class VLLMConfig(BaseModel):
|
class VLLMConfig(BaseModel):
|
||||||
"""Configuration for the vLLM inference provider."""
|
"""Configuration for the vLLM inference provider.
|
||||||
|
|
||||||
|
Note that the model name is no longer part of this static configuration.
|
||||||
|
You can bind an instance of this provider to a specific model with the
|
||||||
|
``models.register()`` API call."""
|
||||||
|
|
||||||
model: str = Field(
|
|
||||||
default="Llama3.2-3B-Instruct",
|
|
||||||
description="Model descriptor from `llama model list`",
|
|
||||||
)
|
|
||||||
tensor_parallel_size: int = Field(
|
tensor_parallel_size: int = Field(
|
||||||
default=1,
|
default=1,
|
||||||
description="Number of tensor parallel replicas (number of GPUs to use).",
|
description="Number of tensor parallel replicas (number of GPUs to use).",
|
||||||
|
@ -26,12 +25,8 @@ class VLLMConfig(BaseModel):
|
||||||
default=4096,
|
default=4096,
|
||||||
description="Maximum number of tokens to generate.",
|
description="Maximum number of tokens to generate.",
|
||||||
)
|
)
|
||||||
max_model_len: int = Field(
|
max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.")
|
||||||
default=4096, description="Maximum context length to use during serving."
|
max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation")
|
||||||
)
|
|
||||||
max_num_seqs: int = Field(
|
|
||||||
default=4, description="Maximum parallel batch size for generation"
|
|
||||||
)
|
|
||||||
enforce_eager: bool = Field(
|
enforce_eager: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
|
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
|
||||||
|
@ -47,7 +42,6 @@ class VLLMConfig(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls):
|
def sample_run_config(cls):
|
||||||
return {
|
return {
|
||||||
"model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
|
|
||||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
||||||
"max_tokens": "${env.MAX_TOKENS:4096}",
|
"max_tokens": "${env.MAX_TOKENS:4096}",
|
||||||
"max_model_len": "${env.MAX_MODEL_LEN:4096}",
|
"max_model_len": "${env.MAX_MODEL_LEN:4096}",
|
||||||
|
@ -55,15 +49,3 @@ class VLLMConfig(BaseModel):
|
||||||
"enforce_eager": "${env.ENFORCE_EAGER:False}",
|
"enforce_eager": "${env.ENFORCE_EAGER:False}",
|
||||||
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
|
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
|
||||||
}
|
}
|
||||||
|
|
||||||
@field_validator("model")
|
|
||||||
@classmethod
|
|
||||||
def validate_model(cls, model: str) -> str:
|
|
||||||
permitted_models = supported_inference_models()
|
|
||||||
|
|
||||||
descriptors = [m.descriptor() for m in permitted_models]
|
|
||||||
repos = [m.huggingface_repo for m in permitted_models]
|
|
||||||
if model not in (descriptors + repos):
|
|
||||||
model_list = "\n\t".join(repos)
|
|
||||||
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
|
|
||||||
return model
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue