refactor & cleanup

This commit is contained in:
Xi Yan 2024-09-19 22:11:59 -07:00
parent 5d3c02d0fb
commit bce79617bf
3 changed files with 13 additions and 10 deletions

View file

@ -48,7 +48,10 @@ from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.utils.dynamic import instantiate_provider
from llama_stack.distribution.utils.dynamic import (
instantiate_provider,
is_models_router_provider,
)
def is_async_iterator_type(typ):
@ -272,10 +275,6 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack]
def snake_to_camel(snake_str):
    """Convert an underscore-separated name into a capitalized, joined form.

    The input is split on ``"_"``; each piece is passed through
    ``str.capitalize`` (first character upper-cased, the rest lower-cased)
    and the pieces are concatenated with no separator, so ``"foo_bar"``
    becomes ``"FooBar"``. Empty pieces produced by consecutive, leading, or
    trailing underscores contribute nothing to the result.
    """
    pieces = snake_str.split("_")
    return "".join(map(str.capitalize, pieces))
async def resolve_impls(
provider_map: Dict[str, ProviderMapEntry],
) -> Dict[Api, Any]:
@ -297,7 +296,7 @@ async def resolve_impls(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[item.provider_id]
elif isinstance(item, str) and item == "models-router":
elif is_models_router_provider(item):
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.providers.routers.{api.value.lower()}",
@ -323,8 +322,6 @@ async def resolve_impls(
sorted_specs = topological_sort(specs.values())
cprint(f"sorted_specs={sorted_specs}", "red")
impls = {}
for spec in sorted_specs:
api = spec.api

View file

@ -16,6 +16,10 @@ def instantiate_class_type(fully_qualified_name):
return getattr(module, class_name)
def is_models_router_provider(item: Any):
    """Return True when *item* is the string ``"models-router"``.

    Any non-string value, and any other string, yields False.
    """
    # Reject non-strings first so the equality check only ever sees a str.
    if not isinstance(item, str):
        return False
    return item == "models-router"
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
@ -51,7 +55,7 @@ async def instantiate_provider(
config = None
args = [inner_impls, deps]
elif isinstance(provider_config, str) and provider_config == "models-router":
elif is_models_router_provider(provider_config):
config = None
assert len(deps) == 1 and Api.models in deps
args = [deps[Api.models]]

View file

@ -55,7 +55,9 @@ class InferenceRouterImpl(Inference):
core_model_id = model_spec.llama_model_metadata.core_model_id
self.model2providers[core_model_id.value] = impl
cprint(self.model2providers, "blue")
cprint(f"Initialized inference router: ", "blue")
for m, p in self.model2providers.items():
cprint(f"- {m} serving using {p}", "blue")
async def shutdown(self) -> None:
for p in self.model2providers.values():