diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index d50d5a656..0462a6882 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -9,12 +9,12 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkab
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import Field
 
-from llama_stack.apis.resource import Resource
+from llama_stack.apis.resource import Resource, ResourceType
 
 
 @json_schema_type
 class Model(Resource):
-    type: Literal["model"] = "model"
+    type: Literal[ResourceType.model.value] = ResourceType.model.value
 
     llama_model: str = Field(
         description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
     )
@@ -33,4 +33,11 @@ class Models(Protocol):
     async def get_model(self, identifier: str) -> Optional[Model]: ...
 
     @webmethod(route="/models/register", method="POST")
-    async def register_model(self, model: Model) -> None: ...
+    async def register_model(
+        self,
+        model_id: str,
+        provider_model_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        llama_model: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Model: ...
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index a19495f50..364862271 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict, List
+from typing import Any, AsyncGenerator, Dict, List, Optional
 
 from llama_stack.apis.datasetio.datasetio import DatasetIO
 from llama_stack.distribution.datatypes import RoutingTable
@@ -71,8 +71,17 @@ class InferenceRouter(Inference):
     async def shutdown(self) -> None:
         pass
 
-    async def register_model(self, model: Model) -> None:
-        await self.routing_table.register_model(model)
+    async def register_model(
+        self,
+        model_id: str,
+        provider_model_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        llama_model: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        await self.routing_table.register_model(
+            model_id, provider_model_id, provider_id, llama_model, metadata
+        )
 
     async def chat_completion(
         self,
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index f28896670..7c231deb5 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -202,7 +202,35 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
     async def get_model(self, identifier: str) -> Optional[Model]:
         return await self.get_object_by_identifier(identifier)
 
-    async def register_model(self, model: Model) -> None:
+    async def register_model(
+        self,
+        model_id: str,
+        provider_model_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        llama_model: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        if provider_model_id is None:
+            provider_model_id = model_id
+        if provider_id is None:
+            # If provider_id not specified, use the only provider if it supports this model
+            if len(self.impls_by_provider_id) == 1:
+                provider_id = list(self.impls_by_provider_id.keys())[0]
+            else:
+                raise ValueError(
+                    "No provider specified and multiple providers available. "
+                    "Please specify a provider_id."
+                )
+        if metadata is None:
+            metadata = {}
+        if llama_model is None:
+            llama_model = model_id
+        model = Model(
+            identifier=model_id,
+            provider_resource_id=provider_model_id,
+            provider_id=provider_id,
+            llama_model=llama_model,
+            metadata=metadata,
+        )
         await self.register_object(model)
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index 4ffa31eed..b2c6d3a5e 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -10,7 +10,6 @@ import pytest
 import pytest_asyncio
 
 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.datatypes import Model
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )
@@ -163,11 +162,9 @@ async def inference_stack(request, inference_model):
         inference_fixture.provider_data,
     )
 
-    model = Model(
-        identifier=inference_model,
-        provider_id=inference_fixture.providers[0].provider_id,
+    await impls[Api.models].register_model(
+        model_id=inference_model,
+        provider_model_id=inference_fixture.providers[0].provider_id,
     )
-    await impls[Api.models].register_model(model)
-
     return (impls[Api.inference], impls[Api.models])
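Taken together, these hunks replace the single `model: Model` argument with discrete fields (`model_id`, `provider_model_id`, `provider_id`, `llama_model`, `metadata`) and move `Model` construction into `ModelsRoutingTable`, so callers no longer build the `Resource` object themselves. Below is a minimal sketch of the new call pattern and its defaulting rules; the `FakeModelsTable` harness and the example identifiers are illustrative stand-ins for this PR's routing table, not llama_stack code:

```python
import asyncio
from typing import Any, Dict, Optional


class FakeModelsTable:
    """Stand-in for ModelsRoutingTable, mirroring the defaulting in this diff."""

    def __init__(self, provider_ids):
        # Maps provider_id -> provider impl; a sentinel object suffices here.
        self.impls_by_provider_id = {pid: object() for pid in provider_ids}
        self.registered: Dict[str, Dict[str, Any]] = {}

    async def register_model(
        self,
        model_id: str,
        provider_model_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        llama_model: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Every optional field falls back to model_id or an empty dict,
        # exactly as in the routing_tables.py hunk above.
        provider_model_id = provider_model_id or model_id
        llama_model = llama_model or model_id
        metadata = metadata or {}
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = next(iter(self.impls_by_provider_id))
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. "
                    "Please specify a provider_id."
                )
        self.registered[model_id] = {
            "provider_resource_id": provider_model_id,
            "provider_id": provider_id,
            "llama_model": llama_model,
            "metadata": metadata,
        }


async def main() -> None:
    table = FakeModelsTable(["meta-reference"])
    # With a single configured provider, provider_id can be omitted and the
    # table resolves it; with several providers this raises ValueError.
    await table.register_model(model_id="Llama3.1-8B-Instruct")
    print(table.registered["Llama3.1-8B-Instruct"])


asyncio.run(main())
```

One consequence of constructing `Model` server-side is that clients stay decoupled from the `Resource` schema: the fixtures.py hunk shows a caller that now passes plain keyword arguments instead of importing and instantiating `Model` itself.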