refactor: restructure resolver logic and improve type safety (#1323)

# What does this PR do?

- Modularized `resolve_impls` by extracting helper functions for
validation, sorting, and instantiation.
- Improved readability by introducing `validate_and_prepare_providers`,
`sort_providers_by_dependency`, and `instantiate_providers`.
- Enhanced type safety with explicit type hints (`Tuple`, `Dict`, `Set`,
etc.).
- Fixed potential issues with provider module imports and added error
handling.
- Updated `pyproject.toml` to enforce type checking on `resolver.py`
using `mypy`.

Signed-off-by: Sébastien Han <seb@redhat.com>

- [//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

Run the server.

[//]: # (## Documentation)

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-03-03 19:45:12 +01:00 committed by GitHub
parent cae6c00d8a
commit f86154dff5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 119 additions and 84 deletions

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import importlib import importlib
import inspect import inspect
from typing import Any, Dict, List, Set from typing import Any, Dict, List, Set, Tuple
from llama_stack import logcat from llama_stack import logcat
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
@ -102,21 +102,79 @@ class ProviderWithSpec(Provider):
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]] ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls( async def resolve_impls(
run_config: StackRunConfig, run_config: StackRunConfig,
provider_registry: ProviderRegistry, provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
) -> Dict[Api, Any]: ) -> Dict[Api, Any]:
""" """
Does two things: Resolves provider implementations by:
- flatmaps, sorts and resolves the providers in dependency order 1. Validating and organizing providers.
- for each API, produces either a (local, passthrough or router) implementation 2. Sorting them in dependency order.
3. Instantiating them with required dependencies.
""" """
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()} routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()} router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = {} providers_with_specs = validate_and_prepare_providers(
run_config, provider_registry, routing_table_apis, router_apis
)
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
)
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
"""Generates specifications for automatically routed APIs."""
specs = {}
for info in builtin_automatically_routed_apis():
if info.router_api.value not in apis_to_serve:
continue
specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__routing_table__",
provider_type="__routing_table__",
config={},
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
deps__=[f"inner-{info.router_api.value}"],
),
)
}
specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__autorouted__",
provider_type="__autorouted__",
config={},
spec=AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=[info.routing_table_api.value],
),
)
}
return specs
def validate_and_prepare_providers(
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
) -> Dict[str, Dict[str, ProviderWithSpec]]:
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
for api_str, providers in run_config.providers.items(): for api_str, providers in run_config.providers.items():
api = Api(api_str) api = Api(api_str)
@ -129,6 +187,20 @@ async def resolve_impls(
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled") logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue continue
validate_provider(provider, api, provider_registry)
p = provider_registry[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(spec=p, **provider.model_dump())
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
return providers_with_specs
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
"""Validates if the provider is allowed and handles deprecations."""
if provider.provider_type not in provider_registry[api]: if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`") raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
@ -136,61 +208,22 @@ async def resolve_impls(
if p.deprecation_error: if p.deprecation_error:
logcat.error("core", p.deprecation_error) logcat.error("core", p.deprecation_error)
raise InvalidProviderError(p.deprecation_error) raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning: elif p.deprecation_warning:
logcat.warning( logcat.warning(
"core", "core",
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
) )
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.model_dump()),
)
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set( def sort_providers_by_deps(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis] providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> List[Tuple[str, ProviderWithSpec]]:
"""Sorts providers based on their dependencies."""
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
{k: list(v.values()) for k, v in providers_with_specs.items()}
) )
for info in builtin_automatically_routed_apis(): # Append built-in "inspect" provider
if info.router_api.value not in apis_to_serve:
continue
providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__routing_table__",
provider_type="__routing_table__",
config={},
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
deps__=([f"inner-{info.router_api.value}"]),
),
)
}
providers_with_specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__autorouted__",
provider_type="__autorouted__",
config={},
spec=AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=([info.routing_table_api.value]),
),
)
}
sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
apis = [x[1].spec.api for x in sorted_providers] apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append( sorted_providers.append(
( (
@ -198,16 +231,14 @@ async def resolve_impls(
ProviderWithSpec( ProviderWithSpec(
provider_id="__builtin__", provider_id="__builtin__",
provider_type="__builtin__", provider_type="__builtin__",
config={ config={"run_config": run_config.model_dump()},
"run_config": run_config.dict(),
},
spec=InlineProviderSpec( spec=InlineProviderSpec(
api=Api.inspect, api=Api.inspect,
provider_type="__builtin__", provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig", config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect", module="llama_stack.distribution.inspect",
api_dependencies=apis, api_dependencies=apis,
deps__=([x.value for x in apis]), deps__=[x.value for x in apis],
), ),
), ),
) )
@ -217,9 +248,15 @@ async def resolve_impls(
for api_str, provider in sorted_providers: for api_str, provider in sorted_providers:
logcat.debug("core", f" {api_str} => {provider.provider_id}") logcat.debug("core", f" {api_str} => {provider.provider_id}")
logcat.debug("core", "") logcat.debug("core", "")
return sorted_providers
impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis} async def instantiate_providers(
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
) -> Dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: Dict[Api, Any] = {}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers: for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies} deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies: for a in provider.spec.optional_api_dependencies:
@ -230,14 +267,9 @@ async def resolve_impls(
if isinstance(provider.spec, RoutingTableProviderSpec): if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider( impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
provider,
deps, if api_str.startswith("inner-"):
inner_impls,
dist_registry,
)
# TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str:
inner_impls_by_provider_id[api_str][provider.provider_id] = impl inner_impls_by_provider_id[api_str][provider.provider_id] = impl
else: else:
api = Api(api_str) api = Api(api_str)
@ -248,7 +280,7 @@ async def resolve_impls(
def topological_sort( def topological_sort(
providers_with_specs: Dict[str, List[ProviderWithSpec]], providers_with_specs: Dict[str, List[ProviderWithSpec]],
) -> List[ProviderWithSpec]: ) -> List[Tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: Set[str], stack: List[str]): def dfs(kv, visited: Set[str], stack: List[str]):
api_str, providers = kv api_str, providers = kv
visited.add(api_str) visited.add(api_str)
@ -264,8 +296,8 @@ def topological_sort(
stack.append(api_str) stack.append(api_str)
visited = set() visited: Set[str] = set()
stack = [] stack: List[str] = []
for api_str, providers in providers_with_specs.items(): for api_str, providers in providers_with_specs.items():
if api_str not in visited: if api_str not in visited:
@ -275,13 +307,14 @@ def topological_sort(
for api_str in stack: for api_str in stack:
for provider in providers_with_specs[api_str]: for provider in providers_with_specs[api_str]:
flattened.append((api_str, provider)) flattened.append((api_str, provider))
return flattened return flattened
# returns a class implementing the protocol corresponding to the Api # returns a class implementing the protocol corresponding to the Api
async def instantiate_provider( async def instantiate_provider(
provider: ProviderWithSpec, provider: ProviderWithSpec,
deps: Dict[str, Any], deps: Dict[Api, Any],
inner_impls: Dict[str, Any], inner_impls: Dict[str, Any],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
): ):
@ -289,8 +322,10 @@ async def instantiate_provider(
additional_protocols = additional_protocols_map() additional_protocols = additional_protocols_map()
provider_spec = provider.spec provider_spec = provider.spec
module = importlib.import_module(provider_spec.module) if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
module = importlib.import_module(provider_spec.module)
args = [] args = []
if isinstance(provider_spec, RemoteProviderSpec): if isinstance(provider_spec, RemoteProviderSpec):
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)

View file

@ -75,11 +75,7 @@ docs = [
"sphinxcontrib.mermaid", "sphinxcontrib.mermaid",
"tomli", "tomli",
] ]
codegen = [ codegen = ["rich", "pydantic", "jinja2"]
"rich",
"pydantic",
"jinja2",
]
[project.urls] [project.urls]
Homepage = "https://github.com/meta-llama/llama-stack" Homepage = "https://github.com/meta-llama/llama-stack"
@ -163,3 +159,7 @@ exclude = [
# packages that lack typing annotations, do not have stubs, or are unavailable. # packages that lack typing annotations, do not have stubs, or are unavailable.
module = ["yaml", "fire"] module = ["yaml", "fire"]
ignore_missing_imports = true ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "llama_stack.distribution.resolver"
follow_imports = "normal" # This will force type checking on this module