From 7071c46422efcb88bffdfd291a134308d367a107 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 19 Sep 2024 21:33:35 -0700
Subject: [PATCH] instantiate inference models

---
 llama_stack/distribution/server/server.py    |  6 +--
 llama_stack/distribution/utils/dynamic.py    |  3 +-
 .../impls/builtin/models/__init__.py         |  2 -
 .../providers/impls/builtin/models/models.py |  1 -
 .../providers/routers/inference/__init__.py  |  4 +-
 .../providers/routers/inference/inference.py | 44 ++++++++++++++-----
 6 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 5cf299bff..4149d19cc 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -321,12 +321,10 @@ async def resolve_impls(
         inner_specs=inner_specs,
     )
 
-    for k, v in specs.items():
-        cprint(k, "blue")
-        cprint(v, "blue")
-
     sorted_specs = topological_sort(specs.values())
 
+    cprint(f"sorted_specs={sorted_specs}", "red")
+
     impls = {}
     for spec in sorted_specs:
         api = spec.api
diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py
index 048a418d4..b2e29b3e7 100644
--- a/llama_stack/distribution/utils/dynamic.py
+++ b/llama_stack/distribution/utils/dynamic.py
@@ -53,7 +53,8 @@ async def instantiate_provider(
            args = [inner_impls, deps]
        elif isinstance(provider_config, str) and provider_config == "models-router":
            config = None
-           args = [[], deps]
+           assert len(deps) == 1 and Api.models in deps
+           args = [deps[Api.models]]
        else:
            raise ValueError(f"provider_config {provider_config} is not valid")
    else:
diff --git a/llama_stack/providers/impls/builtin/models/__init__.py b/llama_stack/providers/impls/builtin/models/__init__.py
index bb06e828f..439f2be61 100644
--- a/llama_stack/providers/impls/builtin/models/__init__.py
+++ b/llama_stack/providers/impls/builtin/models/__init__.py
@@ -18,8 +18,6 @@ async def get_provider_impl(config: BuiltinImplConfig, deps: Dict[Api, ProviderS
         config, BuiltinImplConfig
     ), f"Unexpected config type: {type(config)}"
 
-    print(config)
-
     impl = BuiltinModelsImpl(config)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/impls/builtin/models/models.py b/llama_stack/providers/impls/builtin/models/models.py
index 66afb118f..9a399c67d 100644
--- a/llama_stack/providers/impls/builtin/models/models.py
+++ b/llama_stack/providers/impls/builtin/models/models.py
@@ -25,7 +25,6 @@ class BuiltinModelsImpl(Models):
         config: BuiltinImplConfig,
     ) -> None:
         self.config = config
-        cprint(self.config, "red")
         self.models = {
             entry.core_model_id: ModelSpec(
                 llama_model_metadata=resolve_model(entry.core_model_id),
diff --git a/llama_stack/providers/routers/inference/__init__.py b/llama_stack/providers/routers/inference/__init__.py
index c6619ffc9..c0f0498b7 100644
--- a/llama_stack/providers/routers/inference/__init__.py
+++ b/llama_stack/providers/routers/inference/__init__.py
@@ -9,9 +9,9 @@ from typing import Any, List, Tuple
 
 from llama_stack.distribution.datatypes import Api
 
-async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
+async def get_router_impl(models_api: Api):
     from .inference import InferenceRouterImpl
 
-    impl = InferenceRouterImpl(inner_impls, deps)
+    impl = InferenceRouterImpl(models_api)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/routers/inference/inference.py b/llama_stack/providers/routers/inference/inference.py
index f029892b0..dbf2f3952 100644
--- a/llama_stack/providers/routers/inference/inference.py
+++ b/llama_stack/providers/routers/inference/inference.py
@@ -8,7 +8,14 @@ from typing import Any, AsyncGenerator, Dict, List, Tuple
 
 from llama_stack.distribution.datatypes import Api
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.models import Models
+
+from llama_stack.distribution.datatypes import GenericProviderConfig
+from llama_stack.distribution.distribution import api_providers
+from llama_stack.distribution.utils.dynamic import instantiate_provider
+from llama_stack.providers.impls.builtin.models.models import BuiltinModelsImpl
 from llama_stack.providers.registry.inference import available_providers
+from termcolor import cprint
 
 
 class InferenceRouterImpl(Inference):
@@ -16,19 +23,36 @@ class InferenceRouterImpl(Inference):
 
     def __init__(
         self,
-        inner_impls: List[Tuple[str, Any]],
-        deps: List[Api],
+        models_api: Models,
     ) -> None:
-        self.inner_impls = inner_impls
-        self.deps = deps
-        print("INIT INFERENCE ROUTER!")
-
-        # self.providers = {}
-        # for routing_key, provider_impl in inner_impls:
-        #     self.providers[routing_key] = provider_impl
+        # map of model_id to provider impl
+        self.providers = {}
+        self.models_api = models_api
 
     async def initialize(self) -> None:
-        pass
+        inference_providers = api_providers()[Api.inference]
+
+        models_list_response = await self.models_api.list_models()
+        for model_spec in models_list_response.models_list:
+
+            if model_spec.api != Api.inference.value:
+                continue
+
+            if model_spec.provider_id not in inference_providers:
+                raise ValueError(
+                    f"provider_id {model_spec.provider_id} is not available for inference. Please check run.yaml config spec to define a valid provider"
+                )
+            impl = await instantiate_provider(
+                inference_providers[model_spec.provider_id],
+                deps=[],
+                provider_config=GenericProviderConfig(
+                    provider_id=model_spec.provider_id,
+                    config=model_spec.provider_config,
+                ),
+            )
+            cprint(f"impl={impl}", "blue")
+            # look up and initialize provider implementations for each model
+            core_model_id = model_spec.llama_model_metadata.core_model_id
 
     async def shutdown(self) -> None:
         pass
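
Aside (not part of the patch): InferenceRouterImpl declares a model_id -> provider-impl map, but initialize() stops after computing core_model_id and never populates self.providers, and no dispatch path exists yet. Below is a minimal, self-contained sketch of the routing idea the patch's comments describe -- build the map once, then route each request by model id. All names here (MiniInferenceRouter, _EchoImpl, the chat_completion signature) are hypothetical illustrations, not the actual llama_stack Inference API.

    # Hypothetical sketch only; assumes a provider impl exposes an async
    # chat_completion generator keyed by the model it serves.
    import asyncio
    from typing import Any, AsyncGenerator, Dict


    class _EchoImpl:
        """Stand-in for a provider impl returned by instantiate_provider()."""

        def __init__(self, name: str) -> None:
            self.name = name

        async def chat_completion(self, model: str, message: str) -> AsyncGenerator[str, None]:
            yield f"[{self.name}] {model}: {message}"


    class MiniInferenceRouter:
        def __init__(self, providers: Dict[str, Any]) -> None:
            # map of model_id to provider impl, as in InferenceRouterImpl
            self.providers = providers

        async def chat_completion(self, model: str, message: str) -> AsyncGenerator[str, None]:
            # dispatch to the impl that was instantiated for this model id
            if model not in self.providers:
                raise ValueError(f"no provider registered for model {model}")
            async for chunk in self.providers[model].chat_completion(model, message):
                yield chunk


    async def _demo() -> None:
        router = MiniInferenceRouter({"Meta-Llama3.1-8B-Instruct": _EchoImpl("meta-reference")})
        async for chunk in router.chat_completion("Meta-Llama3.1-8B-Instruct", "hello"):
            print(chunk)


    asyncio.run(_demo())

In the patch's terms, the map would be filled at the end of the initialize() loop, e.g. self.providers[core_model_id] = impl, so that each inference request can be delegated to the provider instantiated for that model.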