From bce79617bf81c3fe27f12974d0c2e846f44c2653 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 19 Sep 2024 22:11:59 -0700 Subject: [PATCH] refactor & cleanup --- llama_stack/distribution/server/server.py | 13 +++++-------- llama_stack/distribution/utils/dynamic.py | 6 +++++- .../providers/routers/inference/inference.py | 4 +++- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 4149d19cc..7a2ed874f 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -48,7 +48,10 @@ from typing_extensions import Annotated from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.distribution import api_endpoints, api_providers -from llama_stack.distribution.utils.dynamic import instantiate_provider +from llama_stack.distribution.utils.dynamic import ( + instantiate_provider, + is_models_router_provider, +) def is_async_iterator_type(typ): @@ -272,10 +275,6 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]: return [by_id[x] for x in stack] -def snake_to_camel(snake_str): - return "".join(word.capitalize() for word in snake_str.split("_")) - - async def resolve_impls( provider_map: Dict[str, ProviderMapEntry], ) -> Dict[Api, Any]: @@ -297,7 +296,7 @@ async def resolve_impls( f"Unknown provider `{provider_id}` is not available for API `{api}`" ) specs[api] = providers[item.provider_id] - elif isinstance(item, str) and item == "models-router": + elif is_models_router_provider(item): specs[api] = RouterProviderSpec( api=api, module=f"llama_stack.providers.routers.{api.value.lower()}", @@ -323,8 +322,6 @@ async def resolve_impls( sorted_specs = topological_sort(specs.values()) - cprint(f"sorted_specs={sorted_specs}", "red") - impls = {} for spec in sorted_specs: api = spec.api diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py index b2e29b3e7..b70a538c2 100644 --- a/llama_stack/distribution/utils/dynamic.py +++ b/llama_stack/distribution/utils/dynamic.py @@ -16,6 +16,10 @@ def instantiate_class_type(fully_qualified_name): return getattr(module, class_name) +def is_models_router_provider(item: Any): + return isinstance(item, str) and item == "models-router" + + # returns a class implementing the protocol corresponding to the Api async def instantiate_provider( provider_spec: ProviderSpec, @@ -51,7 +55,7 @@ async def instantiate_provider( config = None args = [inner_impls, deps] - elif isinstance(provider_config, str) and provider_config == "models-router": + elif is_models_router_provider(provider_config): config = None assert len(deps) == 1 and Api.models in deps args = [deps[Api.models]] diff --git a/llama_stack/providers/routers/inference/inference.py b/llama_stack/providers/routers/inference/inference.py index be3d2e434..a58f67af4 100644 --- a/llama_stack/providers/routers/inference/inference.py +++ b/llama_stack/providers/routers/inference/inference.py @@ -55,7 +55,9 @@ class InferenceRouterImpl(Inference): core_model_id = model_spec.llama_model_metadata.core_model_id self.model2providers[core_model_id.value] = impl - cprint(self.model2providers, "blue") + cprint(f"Initialized inference router: ", "blue") + for m, p in self.model2providers.items(): + cprint(f"- {m} serving using {p}", "blue") async def shutdown(self) -> None: for p in self.model2providers.values():