diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index d3aa64292..a447d0999 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -1,11 +1,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Dict, List, Optional, Protocol
+from typing import Any, Dict, List, Optional, Protocol
 
 
-from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import Model
 from llama_models.schema_utils import json_schema_type, webmethod
+from llama_stack.distribution.datatypes import GenericProviderConfig
 from pydantic import BaseModel, Field
 
 
@@ -14,17 +15,14 @@ class ModelSpec(BaseModel):
     llama_model_metadata: Model = Field(
         description="All metadatas associated with llama model (defined in llama_models.models.sku_list). "
     )
-    providers_spec: Dict[str, Any] = Field(
-        default_factory=dict,
-        description="Map of API to the concrete provider specs. E.g. {}".format(
-            {
-                "inference": {
-                    "provider_type": "remote::8080",
-                    "url": "localhost::5555",
-                    "api_token": "hf_xxx",
-                },
-            }
-        ),
+    provider_id: str = Field(
+        description="API provider that is serving this model (e.g. meta-reference, local)",
+    )
+    api: str = Field(
+        description="API that this model is serving (e.g. inference / safety)",
+    )
+    provider_config: Dict[str, Any] = Field(
+        description="API provider config used for serving this model to the API provider `provider_id`"
     )
 
 
diff --git a/llama_stack/providers/impls/builtin/models/models.py b/llama_stack/providers/impls/builtin/models/models.py
index 50da40cb0..66afb118f 100644
--- a/llama_stack/providers/impls/builtin/models/models.py
+++ b/llama_stack/providers/impls/builtin/models/models.py
@@ -18,16 +18,6 @@ from termcolor import cprint
 from .config import BuiltinImplConfig
 
 
-DUMMY_MODELS_SPEC_1 = ModelSpec(
-    llama_model_metadata=resolve_model("Llama-Guard-3-8B"),
-    providers_spec={"safety": {"provider_type": "meta-reference"}},
-)
-
-DUMMY_MODELS_SPEC_2 = ModelSpec(
-    llama_model_metadata=resolve_model("Meta-Llama3.1-8B-Instruct"),
-    providers_spec={"inference": {"provider_type": "meta-reference"}},
-)
-
 
 class BuiltinModelsImpl(Models):
     def __init__(
@@ -35,19 +25,21 @@ class BuiltinModelsImpl(Models):
         config: BuiltinImplConfig,
     ) -> None:
         self.config = config
-
-        self.models = {
-            x.llama_model_metadata.core_model_id.value: x
-            for x in [DUMMY_MODELS_SPEC_1, DUMMY_MODELS_SPEC_2]
-        }
-
         cprint(self.config, "red")
+        self.models = {
+            entry.core_model_id: ModelSpec(
+                llama_model_metadata=resolve_model(entry.core_model_id),
+                provider_id=entry.provider_id,
+                api=entry.api,
+                provider_config=entry.config,
+            )
+            for entry in self.config.models_config
+        }
 
     async def initialize(self) -> None:
         pass
 
     async def list_models(self) -> ModelsListResponse:
-        print(self.config, "hihihi")
         return ModelsListResponse(models_list=list(self.models.values()))
 
     async def get_model(self, core_model_id: str) -> ModelsGetResponse: