From b7d2b83d55e09305094c3cf53008992e7b30a0d1 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 28 Oct 2024 11:58:16 -0700 Subject: [PATCH] Allow passing provider_registry to resolve_impls() --- llama_stack/distribution/resolver.py | 15 ++++++--------- llama_stack/distribution/server/server.py | 7 +++++-- llama_stack/providers/tests/resolver.py | 3 ++- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index cfe31a21d..bab807da9 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -25,10 +25,7 @@ from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry -from llama_stack.distribution.distribution import ( - builtin_automatically_routed_apis, - get_provider_registry, -) +from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.utils.dynamic import instantiate_class_type @@ -65,14 +62,14 @@ class ProviderWithSpec(Provider): # TODO: this code is not very straightforward to follow and needs one more round of refactoring -async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]: +async def resolve_impls( + run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]] +) -> Dict[Api, Any]: """ Does two things: - flatmaps, sorts and resolves the providers in dependency order - for each API, produces either a (local, passthrough or router) implementation """ - all_api_providers = get_provider_registry() - routing_table_apis = set( x.routing_table_api for x in builtin_automatically_routed_apis() ) @@ -89,12 +86,12 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]: specs = {} for provider in providers: - if provider.provider_type not in all_api_providers[api]: + if provider.provider_type not in provider_registry[api]: raise ValueError( f"Provider `{provider.provider_type}` is not available for API `{api}`" ) - p = all_api_providers[api][provider.provider_type] + p = provider_registry[api][provider.provider_type] p.deps__ = [a.value for a in p.api_dependencies] spec = ProviderWithSpec( spec=p, diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index e3d621fd6..b8fe4734e 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -26,7 +26,10 @@ from pydantic import BaseModel, ValidationError from termcolor import cprint from typing_extensions import Annotated -from llama_stack.distribution.distribution import builtin_automatically_routed_apis +from llama_stack.distribution.distribution import ( + builtin_automatically_routed_apis, + get_provider_registry, +) from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -276,7 +279,7 @@ def main( app = FastAPI() - impls = asyncio.run(resolve_impls(config)) + impls = asyncio.run(resolve_impls(config, get_provider_registry())) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index de672b6dc..f211cc7d3 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -13,6 +13,7 @@ import yaml from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config +from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_impls @@ -36,7 +37,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None): providers=chosen, ) run_config = parse_and_maybe_upgrade_config(run_config) - impls = await resolve_impls(run_config) + impls = await resolve_impls(run_config, get_provider_registry()) if "provider_data" in config_dict: provider_id = chosen[api.value][0].provider_id