forked from phoenix-oss/llama-stack-mirror
refactor: restructure resolver logic and improve type safety (#1323)
# What does this PR do? - Modularized `resolve_impls` by extracting helper functions for validation, sorting, and instantiation. - Improved readability by introducing `validate_and_prepare_providers`, `sort_providers_by_dependency`, and `instantiate_providers`. - Enhanced type safety with explicit type hints (`Tuple`, `Dict`, `Set`, etc.). - Fixed potential issues with provider module imports and added error handling. - Updated `pyproject.toml` to enforce type checking on `resolver.py` using `mypy`. Signed-off-by: Sébastien Han <seb@redhat.com> - [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan Run the server. [//]: # (## Documentation) Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
cae6c00d8a
commit
f86154dff5
2 changed files with 119 additions and 84 deletions
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Set
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import Agents
|
||||
|
@ -102,21 +102,79 @@ class ProviderWithSpec(Provider):
|
|||
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
|
||||
|
||||
|
||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||
async def resolve_impls(
|
||||
run_config: StackRunConfig,
|
||||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
Resolves provider implementations by:
|
||||
1. Validating and organizing providers.
|
||||
2. Sorting them in dependency order.
|
||||
3. Instantiating them with required dependencies.
|
||||
"""
|
||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
|
||||
|
||||
providers_with_specs = {}
|
||||
providers_with_specs = validate_and_prepare_providers(
|
||||
run_config, provider_registry, routing_table_apis, router_apis
|
||||
)
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
|
||||
)
|
||||
|
||||
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
|
||||
|
||||
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
||||
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
|
||||
|
||||
|
||||
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
||||
"""Generates specifications for automatically routed APIs."""
|
||||
specs = {}
|
||||
for info in builtin_automatically_routed_apis():
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
|
||||
specs[info.routing_table_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__routing_table__",
|
||||
provider_type="__routing_table__",
|
||||
config={},
|
||||
spec=RoutingTableProviderSpec(
|
||||
api=info.routing_table_api,
|
||||
router_api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
deps__=[f"inner-{info.router_api.value}"],
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
specs[info.router_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__autorouted__",
|
||||
provider_type="__autorouted__",
|
||||
config={},
|
||||
spec=AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
deps__=[info.routing_table_api.value],
|
||||
),
|
||||
)
|
||||
}
|
||||
return specs
|
||||
|
||||
|
||||
def validate_and_prepare_providers(
|
||||
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
|
||||
) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
||||
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
|
||||
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
|
||||
|
||||
for api_str, providers in run_config.providers.items():
|
||||
api = Api(api_str)
|
||||
|
@ -129,68 +187,43 @@ async def resolve_impls(
|
|||
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||
continue
|
||||
|
||||
if provider.provider_type not in provider_registry[api]:
|
||||
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
||||
|
||||
validate_provider(provider, api, provider_registry)
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
logcat.error("core", p.deprecation_error)
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
|
||||
elif p.deprecation_warning:
|
||||
logcat.warning(
|
||||
"core",
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
)
|
||||
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
|
||||
spec = ProviderWithSpec(
|
||||
spec=p,
|
||||
**(provider.model_dump()),
|
||||
)
|
||||
spec = ProviderWithSpec(spec=p, **provider.model_dump())
|
||||
specs[provider.provider_id] = spec
|
||||
|
||||
key = api_str if api not in router_apis else f"inner-{api_str}"
|
||||
providers_with_specs[key] = specs
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
|
||||
return providers_with_specs
|
||||
|
||||
|
||||
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
|
||||
"""Validates if the provider is allowed and handles deprecations."""
|
||||
if provider.provider_type not in provider_registry[api]:
|
||||
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
||||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
logcat.error("core", p.deprecation_error)
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
elif p.deprecation_warning:
|
||||
logcat.warning(
|
||||
"core",
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
)
|
||||
|
||||
|
||||
def sort_providers_by_deps(
|
||||
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
|
||||
) -> List[Tuple[str, ProviderWithSpec]]:
|
||||
"""Sorts providers based on their dependencies."""
|
||||
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
|
||||
{k: list(v.values()) for k, v in providers_with_specs.items()}
|
||||
)
|
||||
|
||||
for info in builtin_automatically_routed_apis():
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
|
||||
providers_with_specs[info.routing_table_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__routing_table__",
|
||||
provider_type="__routing_table__",
|
||||
config={},
|
||||
spec=RoutingTableProviderSpec(
|
||||
api=info.routing_table_api,
|
||||
router_api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
deps__=([f"inner-{info.router_api.value}"]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
providers_with_specs[info.router_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__autorouted__",
|
||||
provider_type="__autorouted__",
|
||||
config={},
|
||||
spec=AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
deps__=([info.routing_table_api.value]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
|
||||
# Append built-in "inspect" provider
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
|
@ -198,16 +231,14 @@ async def resolve_impls(
|
|||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={
|
||||
"run_config": run_config.dict(),
|
||||
},
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
deps__=([x.value for x in apis]),
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -216,10 +247,16 @@ async def resolve_impls(
|
|||
logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
logcat.debug("core", f" {api_str} => {provider.provider_id}")
|
||||
logcat.debug("core", "")
|
||||
logcat.debug("core", "")
|
||||
return sorted_providers
|
||||
|
||||
impls = {}
|
||||
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
|
||||
|
||||
async def instantiate_providers(
|
||||
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
|
||||
) -> Dict:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: Dict[Api, Any] = {}
|
||||
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
for api_str, provider in sorted_providers:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
|
@ -230,14 +267,9 @@ async def resolve_impls(
|
|||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||
|
||||
impl = await instantiate_provider(
|
||||
provider,
|
||||
deps,
|
||||
inner_impls,
|
||||
dist_registry,
|
||||
)
|
||||
# TODO: ugh slightly redesign this shady looking code
|
||||
if "inner-" in api_str:
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
|
||||
|
||||
if api_str.startswith("inner-"):
|
||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||
else:
|
||||
api = Api(api_str)
|
||||
|
@ -248,7 +280,7 @@ async def resolve_impls(
|
|||
|
||||
def topological_sort(
|
||||
providers_with_specs: Dict[str, List[ProviderWithSpec]],
|
||||
) -> List[ProviderWithSpec]:
|
||||
) -> List[Tuple[str, ProviderWithSpec]]:
|
||||
def dfs(kv, visited: Set[str], stack: List[str]):
|
||||
api_str, providers = kv
|
||||
visited.add(api_str)
|
||||
|
@ -264,8 +296,8 @@ def topological_sort(
|
|||
|
||||
stack.append(api_str)
|
||||
|
||||
visited = set()
|
||||
stack = []
|
||||
visited: Set[str] = set()
|
||||
stack: List[str] = []
|
||||
|
||||
for api_str, providers in providers_with_specs.items():
|
||||
if api_str not in visited:
|
||||
|
@ -275,13 +307,14 @@ def topological_sort(
|
|||
for api_str in stack:
|
||||
for provider in providers_with_specs[api_str]:
|
||||
flattened.append((api_str, provider))
|
||||
|
||||
return flattened
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider: ProviderWithSpec,
|
||||
deps: Dict[str, Any],
|
||||
deps: Dict[Api, Any],
|
||||
inner_impls: Dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
):
|
||||
|
@ -289,8 +322,10 @@ async def instantiate_provider(
|
|||
additional_protocols = additional_protocols_map()
|
||||
|
||||
provider_spec = provider.spec
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
if not hasattr(provider_spec, "module"):
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
args = []
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
|
|
|
@ -75,11 +75,7 @@ docs = [
|
|||
"sphinxcontrib.mermaid",
|
||||
"tomli",
|
||||
]
|
||||
codegen = [
|
||||
"rich",
|
||||
"pydantic",
|
||||
"jinja2",
|
||||
]
|
||||
codegen = ["rich", "pydantic", "jinja2"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/meta-llama/llama-stack"
|
||||
|
@ -163,3 +159,7 @@ exclude = [
|
|||
# packages that lack typing annotations, do not have stubs, or are unavailable.
|
||||
module = ["yaml", "fire"]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "llama_stack.distribution.resolver"
|
||||
follow_imports = "normal" # This will force type checking on this module
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue