Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-31 16:01:46 +00:00)

commit d6a9a17828 (parent 772e23e29e)

    address feedback

4 changed files with 2 additions and 19 deletions
@@ -15,9 +15,6 @@ from llama_stack.apis.resource import Resource, ResourceType
 @json_schema_type
 class Model(Resource):
     type: Literal[ResourceType.model.value] = ResourceType.model.value
-    llama_model: str = Field(
-        description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
-    )
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this model",
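With this hunk the llama_model field is gone from the Model resource. A minimal sketch of constructing a Model after the change, assuming the remaining fields shown in the hunks below (identifier, provider_id, and provider_resource_id come from the Resource base class) and an assumed import path; all values are illustrative, not part of this commit:

# Sketch only: Model no longer carries a llama_model field after this commit.
from llama_stack.apis.models import Model  # import path assumed

model = Model(
    identifier="my-model",               # hypothetical model identifier
    provider_id="vllm",                  # hypothetical provider id
    provider_resource_id="my-model",     # id of the model on the provider side
    metadata={"context_length": 4096},   # optional free-form metadata
)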
@@ -38,6 +35,5 @@ class Models(Protocol):
         model_id: str,
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
-        llama_model: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> Model: ...
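For context, a hedged sketch of what a caller of the Models protocol looks like once llama_model is dropped from register_model; the models_impl object and all argument values are hypothetical and only illustrate the new signature:

# Inside an async context; `models_impl` is any object implementing the Models protocol.
model = await models_impl.register_model(
    model_id="my-model",            # illustrative identifier
    provider_model_id="my-model",   # optional
    provider_id="vllm",             # optional when only one provider is configured
    metadata={},                    # optional
    # llama_model=...  <- no longer accepted after this commit
)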
@@ -76,11 +76,10 @@ class InferenceRouter(Inference):
         model_id: str,
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
-        llama_model: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> None:
         await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, llama_model, metadata
+            model_id, provider_model_id, provider_id, metadata
         )

     async def chat_completion(
@@ -207,7 +207,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         model_id: str,
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
-        llama_model: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> Model:
         if provider_model_id is None:
@@ -218,17 +217,14 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                 provider_id = list(self.impls_by_provider_id.keys())[0]
             else:
                 raise ValueError(
-                    "No provider specified and multiple providers available. Please specify a provider_id."
+                    "No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                 )
         if metadata is None:
             metadata = {}
-        if llama_model is None:
-            llama_model = model_id
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
             provider_id=provider_id,
-            llama_model=llama_model,
             metadata=metadata,
         )
         await self.register_object(model)
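The second hunk above also rewords the error raised when no provider_id is given while several providers are registered. A simplified standalone sketch of that selection behavior, not the actual routing-table code; the function name is hypothetical and the surrounding condition (exactly one registered provider) is inferred from the error message:

from typing import Dict, Optional

def pick_provider(provider_id: Optional[str], impls_by_provider_id: Dict[str, object]) -> str:
    """Simplified mirror of the provider-defaulting logic in ModelsRoutingTable.register_model."""
    if provider_id is not None:
        return provider_id
    if len(impls_by_provider_id) == 1:
        # Only one provider is registered, so it can be chosen implicitly.
        return list(impls_by_provider_id.keys())[0]
    raise ValueError(
        f"No provider specified and multiple providers available. "
        f"Please specify a provider_id. Available providers: {impls_by_provider_id.keys()}"
    )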
@@ -88,14 +88,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             "You cannot dynamically add a model to a running vllm instance"
         )

-    async def list_models(self) -> List[Model]:
-        return [
-            Model(
-                identifier=self.config.model,
-                llama_model=self.config.model,
-            )
-        ]
-
     def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
         if sampling_params is None:
             return VLLMSamplingParams(max_tokens=self.config.max_tokens)