diff --git a/llama_stack/providers/impls/builtin/models/__init__.py b/llama_stack/providers/impls/builtin/models/__init__.py index 439f2be61..bb06e828f 100644 --- a/llama_stack/providers/impls/builtin/models/__init__.py +++ b/llama_stack/providers/impls/builtin/models/__init__.py @@ -18,6 +18,8 @@ async def get_provider_impl(config: BuiltinImplConfig, deps: Dict[Api, ProviderS config, BuiltinImplConfig ), f"Unexpected config type: {type(config)}" + print(config) + impl = BuiltinModelsImpl(config) await impl.initialize() return impl diff --git a/llama_stack/providers/impls/builtin/models/config.py b/llama_stack/providers/impls/builtin/models/config.py index b24499d4e..d153d6075 100644 --- a/llama_stack/providers/impls/builtin/models/config.py +++ b/llama_stack/providers/impls/builtin/models/config.py @@ -4,15 +4,21 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional - -from llama_models.datatypes import ModelFamily +from typing import Any, List, Optional from llama_models.schema_utils import json_schema_type from llama_models.sku_list import all_registered_models, resolve_model +from llama_stack.distribution.datatypes import GenericProviderConfig from pydantic import BaseModel, Field, field_validator @json_schema_type -class BuiltinImplConfig(BaseModel): ... +class ModelConfigProviderEntry(GenericProviderConfig): + api: str + core_model_id: str + + +@json_schema_type +class BuiltinImplConfig(BaseModel): + models_config: List[ModelConfigProviderEntry] diff --git a/llama_stack/providers/impls/builtin/models/models.py b/llama_stack/providers/impls/builtin/models/models.py index 08dea21e0..50da40cb0 100644 --- a/llama_stack/providers/impls/builtin/models/models.py +++ b/llama_stack/providers/impls/builtin/models/models.py @@ -14,6 +14,7 @@ from llama_stack.apis.models import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.datatypes import CoreModelId, Model from llama_models.sku_list import resolve_model +from termcolor import cprint from .config import BuiltinImplConfig @@ -34,21 +35,25 @@ class BuiltinModelsImpl(Models): config: BuiltinImplConfig, ) -> None: self.config = config + self.models = { x.llama_model_metadata.core_model_id.value: x for x in [DUMMY_MODELS_SPEC_1, DUMMY_MODELS_SPEC_2] } + cprint(self.config, "red") + async def initialize(self) -> None: pass async def list_models(self) -> ModelsListResponse: + print(self.config, "hihihi") return ModelsListResponse(models_list=list(self.models.values())) async def get_model(self, core_model_id: str) -> ModelsGetResponse: if core_model_id in self.models: return ModelsGetResponse(core_model_spec=self.models[core_model_id]) - raise ValueError(f"Cannot find {core_model_id} in model registry") + raise RuntimeError(f"Cannot find {core_model_id} in model registry") async def register_model( self, model_id: str, api: str, provider_spec: Dict[str, str]