Allow passing provider_registry to resolve_impls()

This commit is contained in:
Ashwin Bharambe 2024-10-28 11:58:16 -07:00
parent 8a3b64d1be
commit b7d2b83d55
3 changed files with 13 additions and 12 deletions

View file

@ -25,10 +25,7 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -65,14 +62,14 @@ class ProviderWithSpec(Provider):
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
async def resolve_impls(
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_api_providers = get_provider_registry()
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
@ -89,12 +86,12 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
specs = {}
for provider in providers:
if provider.provider_type not in all_api_providers[api]:
if provider.provider_type not in provider_registry[api]:
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
p = all_api_providers[api][provider.provider_type]
p = provider_registry[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
spec=p,

View file

@ -26,7 +26,10 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@ -276,7 +279,7 @@ def main(
app = FastAPI()
impls = asyncio.run(resolve_impls(config))
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])

View file

@ -13,6 +13,7 @@ import yaml
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
@ -36,7 +37,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
providers=chosen,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls(run_config)
impls = await resolve_impls(run_config, get_provider_registry())
if "provider_data" in config_dict:
provider_id = chosen[api.value][0].provider_id