refactor: restructure resolver logic and improve type safety (#1323)

# What does this PR do?

- Modularized `resolve_impls` by extracting helper functions for
validation, sorting, and instantiation.
- Improved readability by introducing `validate_and_prepare_providers`,
`sort_providers_by_dependency`, and `instantiate_providers`.
- Enhanced type safety with explicit type hints (`Tuple`, `Dict`, `Set`,
etc.).
- Fixed potential issues with provider module imports and added error
handling.
- Updated `pyproject.toml` to enforce type checking on `resolver.py`
using `mypy`.

Signed-off-by: Sébastien Han <seb@redhat.com>

- [//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

Run the server.

[//]: # (## Documentation)

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-03-03 19:45:12 +01:00 committed by GitHub
parent cae6c00d8a
commit f86154dff5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 119 additions and 84 deletions

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import importlib
import inspect
from typing import Any, Dict, List, Set
from typing import Any, Dict, List, Set, Tuple
from llama_stack import logcat
from llama_stack.apis.agents import Agents
@ -102,21 +102,79 @@ class ProviderWithSpec(Provider):
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
Resolves provider implementations by:
1. Validating and organizing providers.
2. Sorting them in dependency order.
3. Instantiating them with required dependencies.
"""
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = {}
providers_with_specs = validate_and_prepare_providers(
run_config, provider_registry, routing_table_apis, router_apis
)
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
)
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
"""Generates specifications for automatically routed APIs."""
specs = {}
for info in builtin_automatically_routed_apis():
if info.router_api.value not in apis_to_serve:
continue
specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__routing_table__",
provider_type="__routing_table__",
config={},
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
deps__=[f"inner-{info.router_api.value}"],
),
)
}
specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__autorouted__",
provider_type="__autorouted__",
config={},
spec=AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=[info.routing_table_api.value],
),
)
}
return specs
def validate_and_prepare_providers(
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
) -> Dict[str, Dict[str, ProviderWithSpec]]:
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str)
@ -129,68 +187,43 @@ async def resolve_impls(
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
validate_provider(provider, api, provider_registry)
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
logcat.error("core", p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
logcat.warning(
"core",
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.model_dump()),
)
spec = ProviderWithSpec(spec=p, **provider.model_dump())
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
return providers_with_specs
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
"""Validates if the provider is allowed and handles deprecations."""
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
logcat.error("core", p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
logcat.warning(
"core",
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
def sort_providers_by_deps(
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> List[Tuple[str, ProviderWithSpec]]:
"""Sorts providers based on their dependencies."""
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
for info in builtin_automatically_routed_apis():
if info.router_api.value not in apis_to_serve:
continue
providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__routing_table__",
provider_type="__routing_table__",
config={},
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
deps__=([f"inner-{info.router_api.value}"]),
),
)
}
providers_with_specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__autorouted__",
provider_type="__autorouted__",
config={},
spec=AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=([info.routing_table_api.value]),
),
)
}
sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
# Append built-in "inspect" provider
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(
@ -198,16 +231,14 @@ async def resolve_impls(
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={
"run_config": run_config.dict(),
},
config={"run_config": run_config.model_dump()},
spec=InlineProviderSpec(
api=Api.inspect,
provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect",
api_dependencies=apis,
deps__=([x.value for x in apis]),
deps__=[x.value for x in apis],
),
),
)
@ -216,10 +247,16 @@ async def resolve_impls(
logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logcat.debug("core", f" {api_str} => {provider.provider_id}")
logcat.debug("core", "")
logcat.debug("core", "")
return sorted_providers
impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
async def instantiate_providers(
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
) -> Dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: Dict[Api, Any] = {}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies:
@ -230,14 +267,9 @@ async def resolve_impls(
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(
provider,
deps,
inner_impls,
dist_registry,
)
# TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str:
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
else:
api = Api(api_str)
@ -248,7 +280,7 @@ async def resolve_impls(
def topological_sort(
providers_with_specs: Dict[str, List[ProviderWithSpec]],
) -> List[ProviderWithSpec]:
) -> List[Tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: Set[str], stack: List[str]):
api_str, providers = kv
visited.add(api_str)
@ -264,8 +296,8 @@ def topological_sort(
stack.append(api_str)
visited = set()
stack = []
visited: Set[str] = set()
stack: List[str] = []
for api_str, providers in providers_with_specs.items():
if api_str not in visited:
@ -275,13 +307,14 @@ def topological_sort(
for api_str in stack:
for provider in providers_with_specs[api_str]:
flattened.append((api_str, provider))
return flattened
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider: ProviderWithSpec,
deps: Dict[str, Any],
deps: Dict[Api, Any],
inner_impls: Dict[str, Any],
dist_registry: DistributionRegistry,
):
@ -289,8 +322,10 @@ async def instantiate_provider(
additional_protocols = additional_protocols_map()
provider_spec = provider.spec
module = importlib.import_module(provider_spec.module)
if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
config_type = instantiate_class_type(provider_spec.config_class)

View file

@ -75,11 +75,7 @@ docs = [
"sphinxcontrib.mermaid",
"tomli",
]
codegen = [
"rich",
"pydantic",
"jinja2",
]
codegen = ["rich", "pydantic", "jinja2"]
[project.urls]
Homepage = "https://github.com/meta-llama/llama-stack"
@ -163,3 +159,7 @@ exclude = [
# packages that lack typing annotations, do not have stubs, or are unavailable.
module = ["yaml", "fire"]
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "llama_stack.distribution.resolver"
follow_imports = "normal" # This will force type checking on this module