Allow passing provider_registry to resolve_impls()

This commit is contained in:
Ashwin Bharambe 2024-10-28 11:58:16 -07:00
parent 8a3b64d1be
commit b7d2b83d55
3 changed files with 13 additions and 12 deletions

View file

@ -25,10 +25,7 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -65,14 +62,14 @@ class ProviderWithSpec(Provider):
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
async def resolve_impls(
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_api_providers = get_provider_registry()
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
@ -89,12 +86,12 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
specs = {}
for provider in providers:
if provider.provider_type not in all_api_providers[api]:
if provider.provider_type not in provider_registry[api]:
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
p = all_api_providers[api][provider.provider_type]
p = provider_registry[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
spec=p,