forked from phoenix-oss/llama-stack-mirror
refactor: restructure resolver logic and improve type safety (#1323)
# What does this PR do? - Modularized `resolve_impls` by extracting helper functions for validation, sorting, and instantiation. - Improved readability by introducing `validate_and_prepare_providers`, `sort_providers_by_dependency`, and `instantiate_providers`. - Enhanced type safety with explicit type hints (`Tuple`, `Dict`, `Set`, etc.). - Fixed potential issues with provider module imports and added error handling. - Updated `pyproject.toml` to enforce type checking on `resolver.py` using `mypy`. Signed-off-by: Sébastien Han <seb@redhat.com> - [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan Run the server. [//]: # (## Documentation) Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
cae6c00d8a
commit
f86154dff5
2 changed files with 119 additions and 84 deletions
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Dict, List, Set
|
from typing import Any, Dict, List, Set, Tuple
|
||||||
|
|
||||||
from llama_stack import logcat
|
from llama_stack import logcat
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
|
@ -102,21 +102,79 @@ class ProviderWithSpec(Provider):
|
||||||
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
|
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
|
||||||
|
|
||||||
|
|
||||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
|
||||||
async def resolve_impls(
|
async def resolve_impls(
|
||||||
run_config: StackRunConfig,
|
run_config: StackRunConfig,
|
||||||
provider_registry: ProviderRegistry,
|
provider_registry: ProviderRegistry,
|
||||||
dist_registry: DistributionRegistry,
|
dist_registry: DistributionRegistry,
|
||||||
) -> Dict[Api, Any]:
|
) -> Dict[Api, Any]:
|
||||||
"""
|
"""
|
||||||
Does two things:
|
Resolves provider implementations by:
|
||||||
- flatmaps, sorts and resolves the providers in dependency order
|
1. Validating and organizing providers.
|
||||||
- for each API, produces either a (local, passthrough or router) implementation
|
2. Sorting them in dependency order.
|
||||||
|
3. Instantiating them with required dependencies.
|
||||||
"""
|
"""
|
||||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||||
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
|
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
|
||||||
|
|
||||||
providers_with_specs = {}
|
providers_with_specs = validate_and_prepare_providers(
|
||||||
|
run_config, provider_registry, routing_table_apis, router_apis
|
||||||
|
)
|
||||||
|
|
||||||
|
apis_to_serve = run_config.apis or set(
|
||||||
|
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
|
||||||
|
)
|
||||||
|
|
||||||
|
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
|
||||||
|
|
||||||
|
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
||||||
|
|
||||||
|
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
|
||||||
|
|
||||||
|
|
||||||
|
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
||||||
|
"""Generates specifications for automatically routed APIs."""
|
||||||
|
specs = {}
|
||||||
|
for info in builtin_automatically_routed_apis():
|
||||||
|
if info.router_api.value not in apis_to_serve:
|
||||||
|
continue
|
||||||
|
|
||||||
|
specs[info.routing_table_api.value] = {
|
||||||
|
"__builtin__": ProviderWithSpec(
|
||||||
|
provider_id="__routing_table__",
|
||||||
|
provider_type="__routing_table__",
|
||||||
|
config={},
|
||||||
|
spec=RoutingTableProviderSpec(
|
||||||
|
api=info.routing_table_api,
|
||||||
|
router_api=info.router_api,
|
||||||
|
module="llama_stack.distribution.routers",
|
||||||
|
api_dependencies=[],
|
||||||
|
deps__=[f"inner-{info.router_api.value}"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
specs[info.router_api.value] = {
|
||||||
|
"__builtin__": ProviderWithSpec(
|
||||||
|
provider_id="__autorouted__",
|
||||||
|
provider_type="__autorouted__",
|
||||||
|
config={},
|
||||||
|
spec=AutoRoutedProviderSpec(
|
||||||
|
api=info.router_api,
|
||||||
|
module="llama_stack.distribution.routers",
|
||||||
|
routing_table_api=info.routing_table_api,
|
||||||
|
api_dependencies=[info.routing_table_api],
|
||||||
|
deps__=[info.routing_table_api.value],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return specs
|
||||||
|
|
||||||
|
|
||||||
|
def validate_and_prepare_providers(
|
||||||
|
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
|
||||||
|
) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
||||||
|
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
|
||||||
|
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
|
||||||
|
|
||||||
for api_str, providers in run_config.providers.items():
|
for api_str, providers in run_config.providers.items():
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
|
@ -129,6 +187,20 @@ async def resolve_impls(
|
||||||
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
validate_provider(provider, api, provider_registry)
|
||||||
|
p = provider_registry[api][provider.provider_type]
|
||||||
|
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
|
||||||
|
spec = ProviderWithSpec(spec=p, **provider.model_dump())
|
||||||
|
specs[provider.provider_id] = spec
|
||||||
|
|
||||||
|
key = api_str if api not in router_apis else f"inner-{api_str}"
|
||||||
|
providers_with_specs[key] = specs
|
||||||
|
|
||||||
|
return providers_with_specs
|
||||||
|
|
||||||
|
|
||||||
|
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
|
||||||
|
"""Validates if the provider is allowed and handles deprecations."""
|
||||||
if provider.provider_type not in provider_registry[api]:
|
if provider.provider_type not in provider_registry[api]:
|
||||||
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
||||||
|
|
||||||
|
@ -136,61 +208,22 @@ async def resolve_impls(
|
||||||
if p.deprecation_error:
|
if p.deprecation_error:
|
||||||
logcat.error("core", p.deprecation_error)
|
logcat.error("core", p.deprecation_error)
|
||||||
raise InvalidProviderError(p.deprecation_error)
|
raise InvalidProviderError(p.deprecation_error)
|
||||||
|
|
||||||
elif p.deprecation_warning:
|
elif p.deprecation_warning:
|
||||||
logcat.warning(
|
logcat.warning(
|
||||||
"core",
|
"core",
|
||||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||||
)
|
)
|
||||||
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
|
|
||||||
spec = ProviderWithSpec(
|
|
||||||
spec=p,
|
|
||||||
**(provider.model_dump()),
|
|
||||||
)
|
|
||||||
specs[provider.provider_id] = spec
|
|
||||||
|
|
||||||
key = api_str if api not in router_apis else f"inner-{api_str}"
|
|
||||||
providers_with_specs[key] = specs
|
|
||||||
|
|
||||||
apis_to_serve = run_config.apis or set(
|
def sort_providers_by_deps(
|
||||||
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
|
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
|
||||||
|
) -> List[Tuple[str, ProviderWithSpec]]:
|
||||||
|
"""Sorts providers based on their dependencies."""
|
||||||
|
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
|
||||||
|
{k: list(v.values()) for k, v in providers_with_specs.items()}
|
||||||
)
|
)
|
||||||
|
|
||||||
for info in builtin_automatically_routed_apis():
|
# Append built-in "inspect" provider
|
||||||
if info.router_api.value not in apis_to_serve:
|
|
||||||
continue
|
|
||||||
|
|
||||||
providers_with_specs[info.routing_table_api.value] = {
|
|
||||||
"__builtin__": ProviderWithSpec(
|
|
||||||
provider_id="__routing_table__",
|
|
||||||
provider_type="__routing_table__",
|
|
||||||
config={},
|
|
||||||
spec=RoutingTableProviderSpec(
|
|
||||||
api=info.routing_table_api,
|
|
||||||
router_api=info.router_api,
|
|
||||||
module="llama_stack.distribution.routers",
|
|
||||||
api_dependencies=[],
|
|
||||||
deps__=([f"inner-{info.router_api.value}"]),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
providers_with_specs[info.router_api.value] = {
|
|
||||||
"__builtin__": ProviderWithSpec(
|
|
||||||
provider_id="__autorouted__",
|
|
||||||
provider_type="__autorouted__",
|
|
||||||
config={},
|
|
||||||
spec=AutoRoutedProviderSpec(
|
|
||||||
api=info.router_api,
|
|
||||||
module="llama_stack.distribution.routers",
|
|
||||||
routing_table_api=info.routing_table_api,
|
|
||||||
api_dependencies=[info.routing_table_api],
|
|
||||||
deps__=([info.routing_table_api.value]),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
|
|
||||||
apis = [x[1].spec.api for x in sorted_providers]
|
apis = [x[1].spec.api for x in sorted_providers]
|
||||||
sorted_providers.append(
|
sorted_providers.append(
|
||||||
(
|
(
|
||||||
|
@ -198,16 +231,14 @@ async def resolve_impls(
|
||||||
ProviderWithSpec(
|
ProviderWithSpec(
|
||||||
provider_id="__builtin__",
|
provider_id="__builtin__",
|
||||||
provider_type="__builtin__",
|
provider_type="__builtin__",
|
||||||
config={
|
config={"run_config": run_config.model_dump()},
|
||||||
"run_config": run_config.dict(),
|
|
||||||
},
|
|
||||||
spec=InlineProviderSpec(
|
spec=InlineProviderSpec(
|
||||||
api=Api.inspect,
|
api=Api.inspect,
|
||||||
provider_type="__builtin__",
|
provider_type="__builtin__",
|
||||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||||
module="llama_stack.distribution.inspect",
|
module="llama_stack.distribution.inspect",
|
||||||
api_dependencies=apis,
|
api_dependencies=apis,
|
||||||
deps__=([x.value for x in apis]),
|
deps__=[x.value for x in apis],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -217,9 +248,15 @@ async def resolve_impls(
|
||||||
for api_str, provider in sorted_providers:
|
for api_str, provider in sorted_providers:
|
||||||
logcat.debug("core", f" {api_str} => {provider.provider_id}")
|
logcat.debug("core", f" {api_str} => {provider.provider_id}")
|
||||||
logcat.debug("core", "")
|
logcat.debug("core", "")
|
||||||
|
return sorted_providers
|
||||||
|
|
||||||
impls = {}
|
|
||||||
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
|
async def instantiate_providers(
|
||||||
|
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
|
||||||
|
) -> Dict:
|
||||||
|
"""Instantiates providers asynchronously while managing dependencies."""
|
||||||
|
impls: Dict[Api, Any] = {}
|
||||||
|
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||||
for api_str, provider in sorted_providers:
|
for api_str, provider in sorted_providers:
|
||||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||||
for a in provider.spec.optional_api_dependencies:
|
for a in provider.spec.optional_api_dependencies:
|
||||||
|
@ -230,14 +267,9 @@ async def resolve_impls(
|
||||||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||||
|
|
||||||
impl = await instantiate_provider(
|
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
|
||||||
provider,
|
|
||||||
deps,
|
if api_str.startswith("inner-"):
|
||||||
inner_impls,
|
|
||||||
dist_registry,
|
|
||||||
)
|
|
||||||
# TODO: ugh slightly redesign this shady looking code
|
|
||||||
if "inner-" in api_str:
|
|
||||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||||
else:
|
else:
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
|
@ -248,7 +280,7 @@ async def resolve_impls(
|
||||||
|
|
||||||
def topological_sort(
|
def topological_sort(
|
||||||
providers_with_specs: Dict[str, List[ProviderWithSpec]],
|
providers_with_specs: Dict[str, List[ProviderWithSpec]],
|
||||||
) -> List[ProviderWithSpec]:
|
) -> List[Tuple[str, ProviderWithSpec]]:
|
||||||
def dfs(kv, visited: Set[str], stack: List[str]):
|
def dfs(kv, visited: Set[str], stack: List[str]):
|
||||||
api_str, providers = kv
|
api_str, providers = kv
|
||||||
visited.add(api_str)
|
visited.add(api_str)
|
||||||
|
@ -264,8 +296,8 @@ def topological_sort(
|
||||||
|
|
||||||
stack.append(api_str)
|
stack.append(api_str)
|
||||||
|
|
||||||
visited = set()
|
visited: Set[str] = set()
|
||||||
stack = []
|
stack: List[str] = []
|
||||||
|
|
||||||
for api_str, providers in providers_with_specs.items():
|
for api_str, providers in providers_with_specs.items():
|
||||||
if api_str not in visited:
|
if api_str not in visited:
|
||||||
|
@ -275,13 +307,14 @@ def topological_sort(
|
||||||
for api_str in stack:
|
for api_str in stack:
|
||||||
for provider in providers_with_specs[api_str]:
|
for provider in providers_with_specs[api_str]:
|
||||||
flattened.append((api_str, provider))
|
flattened.append((api_str, provider))
|
||||||
|
|
||||||
return flattened
|
return flattened
|
||||||
|
|
||||||
|
|
||||||
# returns a class implementing the protocol corresponding to the Api
|
# returns a class implementing the protocol corresponding to the Api
|
||||||
async def instantiate_provider(
|
async def instantiate_provider(
|
||||||
provider: ProviderWithSpec,
|
provider: ProviderWithSpec,
|
||||||
deps: Dict[str, Any],
|
deps: Dict[Api, Any],
|
||||||
inner_impls: Dict[str, Any],
|
inner_impls: Dict[str, Any],
|
||||||
dist_registry: DistributionRegistry,
|
dist_registry: DistributionRegistry,
|
||||||
):
|
):
|
||||||
|
@ -289,8 +322,10 @@ async def instantiate_provider(
|
||||||
additional_protocols = additional_protocols_map()
|
additional_protocols = additional_protocols_map()
|
||||||
|
|
||||||
provider_spec = provider.spec
|
provider_spec = provider.spec
|
||||||
module = importlib.import_module(provider_spec.module)
|
if not hasattr(provider_spec, "module"):
|
||||||
|
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||||
|
|
||||||
|
module = importlib.import_module(provider_spec.module)
|
||||||
args = []
|
args = []
|
||||||
if isinstance(provider_spec, RemoteProviderSpec):
|
if isinstance(provider_spec, RemoteProviderSpec):
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
|
|
|
@ -75,11 +75,7 @@ docs = [
|
||||||
"sphinxcontrib.mermaid",
|
"sphinxcontrib.mermaid",
|
||||||
"tomli",
|
"tomli",
|
||||||
]
|
]
|
||||||
codegen = [
|
codegen = ["rich", "pydantic", "jinja2"]
|
||||||
"rich",
|
|
||||||
"pydantic",
|
|
||||||
"jinja2",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Homepage = "https://github.com/meta-llama/llama-stack"
|
Homepage = "https://github.com/meta-llama/llama-stack"
|
||||||
|
@ -163,3 +159,7 @@ exclude = [
|
||||||
# packages that lack typing annotations, do not have stubs, or are unavailable.
|
# packages that lack typing annotations, do not have stubs, or are unavailable.
|
||||||
module = ["yaml", "fire"]
|
module = ["yaml", "fire"]
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "llama_stack.distribution.resolver"
|
||||||
|
follow_imports = "normal" # This will force type checking on this module
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue