mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
refactor & cleanup
This commit is contained in:
parent
5d3c02d0fb
commit
bce79617bf
3 changed files with 13 additions and 10 deletions
|
@ -48,7 +48,10 @@ from typing_extensions import Annotated
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
from llama_stack.distribution.utils.dynamic import (
|
||||||
|
instantiate_provider,
|
||||||
|
is_models_router_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_async_iterator_type(typ):
|
def is_async_iterator_type(typ):
|
||||||
|
@ -272,10 +275,6 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||||
return [by_id[x] for x in stack]
|
return [by_id[x] for x in stack]
|
||||||
|
|
||||||
|
|
||||||
def snake_to_camel(snake_str):
|
|
||||||
return "".join(word.capitalize() for word in snake_str.split("_"))
|
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls(
|
async def resolve_impls(
|
||||||
provider_map: Dict[str, ProviderMapEntry],
|
provider_map: Dict[str, ProviderMapEntry],
|
||||||
) -> Dict[Api, Any]:
|
) -> Dict[Api, Any]:
|
||||||
|
@ -297,7 +296,7 @@ async def resolve_impls(
|
||||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||||
)
|
)
|
||||||
specs[api] = providers[item.provider_id]
|
specs[api] = providers[item.provider_id]
|
||||||
elif isinstance(item, str) and item == "models-router":
|
elif is_models_router_provider(item):
|
||||||
specs[api] = RouterProviderSpec(
|
specs[api] = RouterProviderSpec(
|
||||||
api=api,
|
api=api,
|
||||||
module=f"llama_stack.providers.routers.{api.value.lower()}",
|
module=f"llama_stack.providers.routers.{api.value.lower()}",
|
||||||
|
@ -323,8 +322,6 @@ async def resolve_impls(
|
||||||
|
|
||||||
sorted_specs = topological_sort(specs.values())
|
sorted_specs = topological_sort(specs.values())
|
||||||
|
|
||||||
cprint(f"sorted_specs={sorted_specs}", "red")
|
|
||||||
|
|
||||||
impls = {}
|
impls = {}
|
||||||
for spec in sorted_specs:
|
for spec in sorted_specs:
|
||||||
api = spec.api
|
api = spec.api
|
||||||
|
|
|
@ -16,6 +16,10 @@ def instantiate_class_type(fully_qualified_name):
|
||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
|
def is_models_router_provider(item: Any):
|
||||||
|
return isinstance(item, str) and item == "models-router"
|
||||||
|
|
||||||
|
|
||||||
# returns a class implementing the protocol corresponding to the Api
|
# returns a class implementing the protocol corresponding to the Api
|
||||||
async def instantiate_provider(
|
async def instantiate_provider(
|
||||||
provider_spec: ProviderSpec,
|
provider_spec: ProviderSpec,
|
||||||
|
@ -51,7 +55,7 @@ async def instantiate_provider(
|
||||||
|
|
||||||
config = None
|
config = None
|
||||||
args = [inner_impls, deps]
|
args = [inner_impls, deps]
|
||||||
elif isinstance(provider_config, str) and provider_config == "models-router":
|
elif is_models_router_provider(provider_config):
|
||||||
config = None
|
config = None
|
||||||
assert len(deps) == 1 and Api.models in deps
|
assert len(deps) == 1 and Api.models in deps
|
||||||
args = [deps[Api.models]]
|
args = [deps[Api.models]]
|
||||||
|
|
|
@ -55,7 +55,9 @@ class InferenceRouterImpl(Inference):
|
||||||
core_model_id = model_spec.llama_model_metadata.core_model_id
|
core_model_id = model_spec.llama_model_metadata.core_model_id
|
||||||
self.model2providers[core_model_id.value] = impl
|
self.model2providers[core_model_id.value] = impl
|
||||||
|
|
||||||
cprint(self.model2providers, "blue")
|
cprint(f"Initialized inference router: ", "blue")
|
||||||
|
for m, p in self.model2providers.items():
|
||||||
|
cprint(f"- {m} serving using {p}", "blue")
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
for p in self.model2providers.values():
|
for p in self.model2providers.values():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue