refactor & cleanup

This commit is contained in:
Xi Yan 2024-09-19 22:11:59 -07:00
parent 5d3c02d0fb
commit bce79617bf
3 changed files with 13 additions and 10 deletions

View file

@ -48,7 +48,10 @@ from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.utils.dynamic import instantiate_provider
from llama_stack.distribution.utils.dynamic import (
instantiate_provider,
is_models_router_provider,
)
def is_async_iterator_type(typ):
@ -272,10 +275,6 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack]
def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))
async def resolve_impls(
provider_map: Dict[str, ProviderMapEntry],
) -> Dict[Api, Any]:
@ -297,7 +296,7 @@ async def resolve_impls(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[item.provider_id]
elif isinstance(item, str) and item == "models-router":
elif is_models_router_provider(item):
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.providers.routers.{api.value.lower()}",
@ -323,8 +322,6 @@ async def resolve_impls(
sorted_specs = topological_sort(specs.values())
cprint(f"sorted_specs={sorted_specs}", "red")
impls = {}
for spec in sorted_specs:
api = spec.api