Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-31 16:01:46 +00:00)
Commit d6a9a17828 ("address feedback"), parent 772e23e29e
4 changed files with 2 additions and 19 deletions

The change removes the llama_model field from the Model resource and drops the corresponding llama_model parameter from register_model in the Models protocol, the InferenceRouter, and the ModelsRoutingTable; it also deletes the vLLM provider's list_models implementation, which still constructed Model objects with that field.
@@ -15,9 +15,6 @@ from llama_stack.apis.resource import Resource, ResourceType
 @json_schema_type
 class Model(Resource):
     type: Literal[ResourceType.model.value] = ResourceType.model.value
-    llama_model: str = Field(
-        description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
-    )
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this model",
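
For orientation, here is a minimal sketch (not part of the commit) of constructing the trimmed-down Model resource after this change. The field names match the class definition in the hunk above; the import path and the concrete values are assumptions for illustration only.

# Sketch only: a Model no longer carries a llama_model field after this commit.
from llama_stack.apis.models import Model  # import path assumed, not shown in the diff

model = Model(
    identifier="Llama3.1-8B-Instruct",            # illustrative identifier
    provider_resource_id="Llama3.1-8B-Instruct",  # provider-side name, illustrative
    provider_id="remote::vllm",                   # hypothetical provider id
    metadata={"context_length": 8192},            # free-form metadata is still supported
)
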
@@ -38,6 +35,5 @@ class Models(Protocol):
         model_id: str,
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
-        llama_model: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> Model: ...

@@ -76,11 +76,10 @@ class InferenceRouter(Inference):
         model_id: str,
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
-        llama_model: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> None:
         await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, llama_model, metadata
+            model_id, provider_model_id, provider_id, metadata
         )
 
     async def chat_completion(
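
A short sketch of what a caller looks like under the router's new signature: llama_model is no longer accepted, and the remaining arguments are forwarded to the routing table as shown above. The router instance and the identifier values below are placeholders, not values from the commit.

# Sketch only: registering a model through an already-constructed InferenceRouter,
# called from within an async function.
await router.register_model(
    model_id="Llama3.1-8B-Instruct",
    provider_model_id="meta-llama/Llama-3.1-8B-Instruct",  # optional provider-side name
    provider_id="remote::vllm",        # optional; resolved by the routing table if omitted
    metadata={"quantization": "fp8"},  # optional free-form metadata
)
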
@@ -207,7 +207,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         model_id: str,
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
-        llama_model: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> Model:
         if provider_model_id is None:
@@ -218,17 +217,14 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                 provider_id = list(self.impls_by_provider_id.keys())[0]
             else:
                 raise ValueError(
-                    "No provider specified and multiple providers available. Please specify a provider_id."
+                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                 )
         if metadata is None:
             metadata = {}
-        if llama_model is None:
-            llama_model = model_id
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
             provider_id=provider_id,
-            llama_model=llama_model,
             metadata=metadata,
         )
         await self.register_object(model)
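
The provider-selection branch above is easier to read in isolation. Below is a minimal standalone sketch of the same defaulting rule, assuming impls_by_provider_id is a plain dict keyed by provider id; the helper name is hypothetical and does not exist in the repository.

# Sketch of the defaulting rule used by ModelsRoutingTable.register_model:
# with exactly one registered provider it is picked implicitly, otherwise the caller must choose.
def resolve_provider_id(provider_id, impls_by_provider_id):
    if provider_id is not None:
        return provider_id
    if len(impls_by_provider_id) == 1:
        return next(iter(impls_by_provider_id))
    raise ValueError(
        "No provider specified and multiple providers available. "
        f"Please specify a provider_id. Available providers: {list(impls_by_provider_id)}"
    )

# Example: resolve_provider_id(None, {"remote::vllm": impl}) returns "remote::vllm".
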
@@ -88,14 +88,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 "You cannot dynamically add a model to a running vllm instance"
             )
 
-    async def list_models(self) -> List[Model]:
-        return [
-            Model(
-                identifier=self.config.model,
-                llama_model=self.config.model,
-            )
-        ]
-
     def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
         if sampling_params is None:
             return VLLMSamplingParams(max_tokens=self.config.max_tokens)