mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
refactor & cleanup
This commit is contained in:
parent
5d3c02d0fb
commit
bce79617bf
3 changed files with 13 additions and 10 deletions
|
@ -48,7 +48,10 @@ from typing_extensions import Annotated
|
|||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
from llama_stack.distribution.utils.dynamic import (
|
||||
instantiate_provider,
|
||||
is_models_router_provider,
|
||||
)
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
|
@ -272,10 +275,6 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
|||
return [by_id[x] for x in stack]
|
||||
|
||||
|
||||
def snake_to_camel(snake_str):
    """Return *snake_str* converted from ``snake_case`` to ``CamelCase``.

    The input is split on underscores and every resulting piece is passed
    through ``str.capitalize`` (first character upper-cased, the remainder
    lower-cased) before the pieces are concatenated with no separator.
    Empty pieces produced by leading, trailing or doubled underscores
    contribute nothing to the result.
    """
    pieces = snake_str.split("_")
    return "".join(map(str.capitalize, pieces))
|
||||
|
||||
|
||||
async def resolve_impls(
|
||||
provider_map: Dict[str, ProviderMapEntry],
|
||||
) -> Dict[Api, Any]:
|
||||
|
@ -297,7 +296,7 @@ async def resolve_impls(
|
|||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[item.provider_id]
|
||||
elif isinstance(item, str) and item == "models-router":
|
||||
elif is_models_router_provider(item):
|
||||
specs[api] = RouterProviderSpec(
|
||||
api=api,
|
||||
module=f"llama_stack.providers.routers.{api.value.lower()}",
|
||||
|
@ -323,8 +322,6 @@ async def resolve_impls(
|
|||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
|
||||
cprint(f"sorted_specs={sorted_specs}", "red")
|
||||
|
||||
impls = {}
|
||||
for spec in sorted_specs:
|
||||
api = spec.api
|
||||
|
|
|
@ -16,6 +16,10 @@ def instantiate_class_type(fully_qualified_name):
|
|||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def is_models_router_provider(item: Any):
    """Return True when *item* is the string ``"models-router"``.

    Any non-string value yields False, so callers can pass either a
    provider entry object or a plain string marker without checking the
    type first.
    """
    # Non-strings can never be the router marker.
    if not isinstance(item, str):
        return False
    return item == "models-router"
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
|
@ -51,7 +55,7 @@ async def instantiate_provider(
|
|||
|
||||
config = None
|
||||
args = [inner_impls, deps]
|
||||
elif isinstance(provider_config, str) and provider_config == "models-router":
|
||||
elif is_models_router_provider(provider_config):
|
||||
config = None
|
||||
assert len(deps) == 1 and Api.models in deps
|
||||
args = [deps[Api.models]]
|
||||
|
|
|
@ -55,7 +55,9 @@ class InferenceRouterImpl(Inference):
|
|||
core_model_id = model_spec.llama_model_metadata.core_model_id
|
||||
self.model2providers[core_model_id.value] = impl
|
||||
|
||||
cprint(self.model2providers, "blue")
|
||||
cprint(f"Initialized inference router: ", "blue")
|
||||
for m, p in self.model2providers.items():
|
||||
cprint(f"- {m} serving using {p}", "blue")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.model2providers.values():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue