address feedback

Dinesh Yeduguru 2024-11-08 16:11:53 -08:00
parent 772e23e29e
commit d6a9a17828
4 changed files with 2 additions and 19 deletions

@@ -15,9 +15,6 @@ from llama_stack.apis.resource import Resource, ResourceType
 @json_schema_type
 class Model(Resource):
     type: Literal[ResourceType.model.value] = ResourceType.model.value
-    llama_model: str = Field(
-        description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
-    )
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this model",
@@ -38,6 +35,5 @@ class Models(Protocol):
         model_id: str,
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
-        llama_model: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> Model: ...
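With llama_model gone from the Models protocol, callers pass only the model identifier, optional provider routing hints, and metadata. A minimal usage sketch against the new signature (the implementation object and model name below are illustrative, not part of this commit):

    # Hypothetical caller; "models_impl" is any Models implementation.
    models_impl: Models = get_models_impl()  # placeholder accessor
    model = await models_impl.register_model(
        model_id="Llama3.1-8B-Instruct",
        provider_id="vllm",  # optional; relevant when several providers are configured
        metadata={"description": "chat model"},
    )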

@@ -76,11 +76,10 @@ class InferenceRouter(Inference):
         model_id: str,
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
-        llama_model: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> None:
         await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, llama_model, metadata
+            model_id, provider_model_id, provider_id, metadata
        )

    async def chat_completion(
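The router change is a pass-through update: InferenceRouter.register_model forwards its arguments positionally to the routing table, so the call site must drop llama_model in lockstep with the signature. A condensed sketch of the delegation, assuming a router already wired to a routing table (construction is simplified and illustrative):

    # Hypothetical wiring; the router only delegates registration.
    router = InferenceRouter(routing_table)
    await router.register_model("Llama3.1-8B-Instruct", provider_id="vllm")
    # which forwards positionally as:
    # await routing_table.register_model(
    #     "Llama3.1-8B-Instruct", None, "vllm", None
    # )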

@@ -207,7 +207,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         model_id: str,
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
-        llama_model: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> Model:
         if provider_model_id is None:
@@ -218,17 +217,14 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                 provider_id = list(self.impls_by_provider_id.keys())[0]
             else:
                 raise ValueError(
-                    "No provider specified and multiple providers available. Please specify a provider_id."
+                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                 )
         if metadata is None:
             metadata = {}
-        if llama_model is None:
-            llama_model = model_id
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
             provider_id=provider_id,
-            llama_model=llama_model,
             metadata=metadata,
         )
         await self.register_object(model)
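Provider resolution above now has three outcomes: an explicit provider_id is used as given, a single configured provider is selected automatically, and several configured providers with no explicit choice raise a ValueError naming the candidates. An illustrative sketch of the error path (the table setup is hypothetical):

    # Assume "table" is a ModelsRoutingTable with two providers configured.
    try:
        await table.register_model(model_id="Llama3.1-8B-Instruct")
    except ValueError as e:
        print(e)  # ... Available providers: dict_keys(['vllm', 'ollama'])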

@@ -88,14 +88,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             "You cannot dynamically add a model to a running vllm instance"
         )

-    async def list_models(self) -> List[Model]:
-        return [
-            Model(
-                identifier=self.config.model,
-                llama_model=self.config.model,
-            )
-        ]
-
     def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
         if sampling_params is None:
             return VLLMSamplingParams(max_tokens=self.config.max_tokens)
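With list_models deleted, the inline vLLM provider no longer advertises its configured model itself; the Model resource is created and tracked by the routing table instead. A sketch of how the served model would now surface, assuming registration happens when the stack is assembled (names are illustrative):

    # Hypothetical replacement for the removed list_models():
    # register the engine's model through the routing table.
    await routing_table.register_model(
        model_id=config.model,  # the model the vLLM engine serves
        provider_id="vllm",
    )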