diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index b3a24bb7b..acbbee09d 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import importlib import inspect -from typing import Any, Dict, List, Set +from typing import Any, Dict, List, Set, Tuple from llama_stack import logcat from llama_stack.apis.agents import Agents @@ -102,21 +102,79 @@ class ProviderWithSpec(Provider): ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]] -# TODO: this code is not very straightforward to follow and needs one more round of refactoring async def resolve_impls( run_config: StackRunConfig, provider_registry: ProviderRegistry, dist_registry: DistributionRegistry, ) -> Dict[Api, Any]: """ - Does two things: - - flatmaps, sorts and resolves the providers in dependency order - - for each API, produces either a (local, passthrough or router) implementation + Resolves provider implementations by: + 1. Validating and organizing providers. + 2. Sorting them in dependency order. + 3. Instantiating them with required dependencies. """ routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()} router_apis = {x.router_api for x in builtin_automatically_routed_apis()} - providers_with_specs = {} + providers_with_specs = validate_and_prepare_providers( + run_config, provider_registry, routing_table_apis, router_apis + ) + + apis_to_serve = run_config.apis or set( + list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis] + ) + + providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve)) + + sorted_providers = sort_providers_by_deps(providers_with_specs, run_config) + + return await instantiate_providers(sorted_providers, router_apis, dist_registry) + + +def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]: + """Generates specifications for automatically routed APIs.""" + specs = {} + for info in builtin_automatically_routed_apis(): + if info.router_api.value not in apis_to_serve: + continue + + specs[info.routing_table_api.value] = { + "__builtin__": ProviderWithSpec( + provider_id="__routing_table__", + provider_type="__routing_table__", + config={}, + spec=RoutingTableProviderSpec( + api=info.routing_table_api, + router_api=info.router_api, + module="llama_stack.distribution.routers", + api_dependencies=[], + deps__=[f"inner-{info.router_api.value}"], + ), + ) + } + + specs[info.router_api.value] = { + "__builtin__": ProviderWithSpec( + provider_id="__autorouted__", + provider_type="__autorouted__", + config={}, + spec=AutoRoutedProviderSpec( + api=info.router_api, + module="llama_stack.distribution.routers", + routing_table_api=info.routing_table_api, + api_dependencies=[info.routing_table_api], + deps__=[info.routing_table_api.value], + ), + ) + } + return specs + + +def validate_and_prepare_providers( + run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api] +) -> Dict[str, Dict[str, ProviderWithSpec]]: + """Validates providers, handles deprecations, and organizes them into a spec dictionary.""" + providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {} for api_str, providers in run_config.providers.items(): api = Api(api_str) @@ -129,68 +187,43 @@ async def resolve_impls( logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled") continue - if provider.provider_type not in provider_registry[api]: - raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`") - + validate_provider(provider, api, provider_registry) p = provider_registry[api][provider.provider_type] - if p.deprecation_error: - logcat.error("core", p.deprecation_error) - raise InvalidProviderError(p.deprecation_error) - - elif p.deprecation_warning: - logcat.warning( - "core", - f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", - ) p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies] - spec = ProviderWithSpec( - spec=p, - **(provider.model_dump()), - ) + spec = ProviderWithSpec(spec=p, **provider.model_dump()) specs[provider.provider_id] = spec key = api_str if api not in router_apis else f"inner-{api_str}" providers_with_specs[key] = specs - apis_to_serve = run_config.apis or set( - list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis] + return providers_with_specs + + +def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry): + """Validates if the provider is allowed and handles deprecations.""" + if provider.provider_type not in provider_registry[api]: + raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`") + + p = provider_registry[api][provider.provider_type] + if p.deprecation_error: + logcat.error("core", p.deprecation_error) + raise InvalidProviderError(p.deprecation_error) + elif p.deprecation_warning: + logcat.warning( + "core", + f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", + ) + + +def sort_providers_by_deps( + providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig +) -> List[Tuple[str, ProviderWithSpec]]: + """Sorts providers based on their dependencies.""" + sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort( + {k: list(v.values()) for k, v in providers_with_specs.items()} ) - for info in builtin_automatically_routed_apis(): - if info.router_api.value not in apis_to_serve: - continue - - providers_with_specs[info.routing_table_api.value] = { - "__builtin__": ProviderWithSpec( - provider_id="__routing_table__", - provider_type="__routing_table__", - config={}, - spec=RoutingTableProviderSpec( - api=info.routing_table_api, - router_api=info.router_api, - module="llama_stack.distribution.routers", - api_dependencies=[], - deps__=([f"inner-{info.router_api.value}"]), - ), - ) - } - - providers_with_specs[info.router_api.value] = { - "__builtin__": ProviderWithSpec( - provider_id="__autorouted__", - provider_type="__autorouted__", - config={}, - spec=AutoRoutedProviderSpec( - api=info.router_api, - module="llama_stack.distribution.routers", - routing_table_api=info.routing_table_api, - api_dependencies=[info.routing_table_api], - deps__=([info.routing_table_api.value]), - ), - ) - } - - sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()}) + # Append built-in "inspect" provider apis = [x[1].spec.api for x in sorted_providers] sorted_providers.append( ( @@ -198,16 +231,14 @@ async def resolve_impls( ProviderWithSpec( provider_id="__builtin__", provider_type="__builtin__", - config={ - "run_config": run_config.dict(), - }, + config={"run_config": run_config.model_dump()}, spec=InlineProviderSpec( api=Api.inspect, provider_type="__builtin__", config_class="llama_stack.distribution.inspect.DistributionInspectConfig", module="llama_stack.distribution.inspect", api_dependencies=apis, - deps__=([x.value for x in apis]), + deps__=[x.value for x in apis], ), ), ) @@ -216,10 +247,16 @@ async def resolve_impls( logcat.debug("core", f"Resolved {len(sorted_providers)} providers") for api_str, provider in sorted_providers: logcat.debug("core", f" {api_str} => {provider.provider_id}") - logcat.debug("core", "") + logcat.debug("core", "") + return sorted_providers - impls = {} - inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis} + +async def instantiate_providers( + sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry +) -> Dict: + """Instantiates providers asynchronously while managing dependencies.""" + impls: Dict[Api, Any] = {} + inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis} for api_str, provider in sorted_providers: deps = {a: impls[a] for a in provider.spec.api_dependencies} for a in provider.spec.optional_api_dependencies: @@ -230,14 +267,9 @@ async def resolve_impls( if isinstance(provider.spec, RoutingTableProviderSpec): inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] - impl = await instantiate_provider( - provider, - deps, - inner_impls, - dist_registry, - ) - # TODO: ugh slightly redesign this shady looking code - if "inner-" in api_str: + impl = await instantiate_provider(provider, deps, inner_impls, dist_registry) + + if api_str.startswith("inner-"): inner_impls_by_provider_id[api_str][provider.provider_id] = impl else: api = Api(api_str) @@ -248,7 +280,7 @@ async def resolve_impls( def topological_sort( providers_with_specs: Dict[str, List[ProviderWithSpec]], -) -> List[ProviderWithSpec]: +) -> List[Tuple[str, ProviderWithSpec]]: def dfs(kv, visited: Set[str], stack: List[str]): api_str, providers = kv visited.add(api_str) @@ -264,8 +296,8 @@ def topological_sort( stack.append(api_str) - visited = set() - stack = [] + visited: Set[str] = set() + stack: List[str] = [] for api_str, providers in providers_with_specs.items(): if api_str not in visited: @@ -275,13 +307,14 @@ def topological_sort( for api_str in stack: for provider in providers_with_specs[api_str]: flattened.append((api_str, provider)) + return flattened # returns a class implementing the protocol corresponding to the Api async def instantiate_provider( provider: ProviderWithSpec, - deps: Dict[str, Any], + deps: Dict[Api, Any], inner_impls: Dict[str, Any], dist_registry: DistributionRegistry, ): @@ -289,8 +322,10 @@ async def instantiate_provider( additional_protocols = additional_protocols_map() provider_spec = provider.spec - module = importlib.import_module(provider_spec.module) + if not hasattr(provider_spec, "module"): + raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute") + module = importlib.import_module(provider_spec.module) args = [] if isinstance(provider_spec, RemoteProviderSpec): config_type = instantiate_class_type(provider_spec.config_class) diff --git a/pyproject.toml b/pyproject.toml index 730af5888..8b604e6d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,11 +75,7 @@ docs = [ "sphinxcontrib.mermaid", "tomli", ] -codegen = [ - "rich", - "pydantic", - "jinja2", -] +codegen = ["rich", "pydantic", "jinja2"] [project.urls] Homepage = "https://github.com/meta-llama/llama-stack" @@ -163,3 +159,7 @@ exclude = [ # packages that lack typing annotations, do not have stubs, or are unavailable. module = ["yaml", "fire"] ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "llama_stack.distribution.resolver" +follow_imports = "normal" # This will force type checking on this module