llama-stack-mirror/llama_stack/core/resolver.py
Charlie Doern b11bcfde11
refactor(build): rework CLI commands and build process (1/2) (#2974)
# What does this PR do?

This PR does a few things outlined in #2878, namely:
1. Adds `llama stack list-deps`, a command that reuses the build
logic but, instead of executing one of the `build_...` scripts,
displays all of the providers' dependencies using the `module` and `uv`.
2. Deprecates `llama stack build` in favor of `llama stack list-deps`.
3. Updates all tests to use `list-deps` alongside `build`.

PR 2/2 will migrate `llama stack run`'s default behavior to be `llama
stack build --run` and use the new `list-deps` command under the hood
before running the server.

Example of `llama stack list-deps starter`:

```
llama stack list-deps starter --format json
{
  "name": "starter",
  "description": "Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.",
  "apis": [
    {
      "api": "inference",
      "provider": "remote::cerebras"
    },
    {
      "api": "inference",
      "provider": "remote::ollama"
    },
    {
      "api": "inference",
      "provider": "remote::vllm"
    },
    {
      "api": "inference",
      "provider": "remote::tgi"
    },
    {
      "api": "inference",
      "provider": "remote::fireworks"
    },
    {
      "api": "inference",
      "provider": "remote::together"
    },
    {
      "api": "inference",
      "provider": "remote::bedrock"
    },
    {
      "api": "inference",
      "provider": "remote::nvidia"
    },
    {
      "api": "inference",
      "provider": "remote::openai"
    },
    {
      "api": "inference",
      "provider": "remote::anthropic"
    },
    {
      "api": "inference",
      "provider": "remote::gemini"
    },
    {
      "api": "inference",
      "provider": "remote::vertexai"
    },
    {
      "api": "inference",
      "provider": "remote::groq"
    },
    {
      "api": "inference",
      "provider": "remote::sambanova"
    },
    {
      "api": "inference",
      "provider": "remote::azure"
    },
    {
      "api": "inference",
      "provider": "inline::sentence-transformers"
    },
    {
      "api": "vector_io",
      "provider": "inline::faiss"
    },
    {
      "api": "vector_io",
      "provider": "inline::sqlite-vec"
    },
    {
      "api": "vector_io",
      "provider": "inline::milvus"
    },
    {
      "api": "vector_io",
      "provider": "remote::chromadb"
    },
    {
      "api": "vector_io",
      "provider": "remote::pgvector"
    },
    {
      "api": "files",
      "provider": "inline::localfs"
    },
    {
      "api": "safety",
      "provider": "inline::llama-guard"
    },
    {
      "api": "safety",
      "provider": "inline::code-scanner"
    },
    {
      "api": "agents",
      "provider": "inline::meta-reference"
    },
    {
      "api": "telemetry",
      "provider": "inline::meta-reference"
    },
    {
      "api": "post_training",
      "provider": "inline::torchtune-cpu"
    },
    {
      "api": "eval",
      "provider": "inline::meta-reference"
    },
    {
      "api": "datasetio",
      "provider": "remote::huggingface"
    },
    {
      "api": "datasetio",
      "provider": "inline::localfs"
    },
    {
      "api": "scoring",
      "provider": "inline::basic"
    },
    {
      "api": "scoring",
      "provider": "inline::llm-as-judge"
    },
    {
      "api": "scoring",
      "provider": "inline::braintrust"
    },
    {
      "api": "tool_runtime",
      "provider": "remote::brave-search"
    },
    {
      "api": "tool_runtime",
      "provider": "remote::tavily-search"
    },
    {
      "api": "tool_runtime",
      "provider": "inline::rag-runtime"
    },
    {
      "api": "tool_runtime",
      "provider": "remote::model-context-protocol"
    },
    {
      "api": "batches",
      "provider": "inline::reference"
    }
  ],
  "pip_dependencies": [
    "pandas",
    "opentelemetry-exporter-otlp-proto-http",
    "matplotlib",
    "opentelemetry-sdk",
    "sentence-transformers",
    "datasets",
    "pymilvus[milvus-lite]>=2.4.10",
    "codeshield",
    "scipy",
    "torchvision",
    "tree_sitter",
    "h11>=0.16.0",
    "aiohttp",
    "pymongo",
    "tqdm",
    "pythainlp",
    "pillow",
    "torch",
    "emoji",
    "grpcio>=1.67.1,<1.71.0",
    "fireworks-ai",
    "langdetect",
    "psycopg2-binary",
    "asyncpg",
    "redis",
    "together",
    "torchao>=0.12.0",
    "openai",
    "sentencepiece",
    "aiosqlite",
    "google-cloud-aiplatform",
    "faiss-cpu",
    "numpy",
    "sqlite-vec",
    "nltk",
    "scikit-learn",
    "mcp>=1.8.1",
    "transformers",
    "boto3",
    "huggingface_hub",
    "ollama",
    "autoevals",
    "sqlalchemy[asyncio]",
    "torchtune>=0.5.0",
    "chromadb-client",
    "pypdf",
    "requests",
    "anthropic",
    "chardet",
    "aiosqlite",
    "fastapi",
    "fire",
    "httpx",
    "uvicorn",
    "opentelemetry-sdk",
    "opentelemetry-exporter-otlp-proto-http"
  ]
}
```
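
The JSON above lists one entry per API/provider pair, and its `pip_dependencies` array contains repeats (for example `aiosqlite` and `opentelemetry-sdk` each appear twice). A minimal sketch of how a consumer might post-process that output is shown below. It assumes only the `apis` and `pip_dependencies` fields visible in the example; the `summarize_list_deps` helper and the trimmed `SAMPLE` string are hypothetical and are not part of the CLI.

```python
import json
from collections import defaultdict

# Trimmed, hand-written sample with the same shape as the output shown above.
SAMPLE = """
{
  "name": "starter",
  "apis": [
    {"api": "inference", "provider": "remote::ollama"},
    {"api": "inference", "provider": "remote::vllm"},
    {"api": "vector_io", "provider": "inline::faiss"}
  ],
  "pip_dependencies": [
    "aiosqlite",
    "opentelemetry-sdk",
    "faiss-cpu",
    "aiosqlite",
    "opentelemetry-sdk"
  ]
}
"""


def summarize_list_deps(raw: str) -> tuple[dict[str, list[str]], list[str]]:
    """Group providers by API and drop repeated pip dependencies (hypothetical helper)."""
    data = json.loads(raw)

    providers_by_api: dict[str, list[str]] = defaultdict(list)
    for entry in data["apis"]:
        providers_by_api[entry["api"]].append(entry["provider"])

    # dict.fromkeys keeps the first occurrence of each item and preserves order.
    unique_deps = list(dict.fromkeys(data["pip_dependencies"]))
    return dict(providers_by_api), unique_deps


providers_by_api, unique_deps = summarize_list_deps(SAMPLE)
print(providers_by_api)
print(unique_deps)
```

In practice the raw string would come from capturing the command's standard output rather than from an embedded sample.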


---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
2025-10-17 19:52:14 -07:00

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import importlib
import importlib.metadata
import inspect
from typing import Any

from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
from llama_stack.core.datatypes import (
    AccessRule,
    AutoRoutedProviderSpec,
    Provider,
    RoutingTableProviderSpec,
    StackRunConfig,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import load_external_apis
from llama_stack.core.store import DistributionRegistry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    Api,
    BenchmarksProtocolPrivate,
    DatasetsProtocolPrivate,
    InlineProviderSpec,
    ModelsProtocolPrivate,
    ProviderSpec,
    RemoteProviderConfig,
    RemoteProviderSpec,
    ScoringFunctionsProtocolPrivate,
    ShieldsProtocolPrivate,
    ToolGroupsProtocolPrivate,
)

logger = get_logger(name=__name__, category="core")


class InvalidProviderError(Exception):
    pass


def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
    """Get a mapping of API types to their protocol classes.

    Args:
        external_apis: Optional dictionary of external API specifications

    Returns:
        Dictionary mapping API types to their protocol classes
    """
    protocols = {
        Api.providers: ProvidersAPI,
        Api.agents: Agents,
        Api.inference: Inference,
        Api.inspect: Inspect,
        Api.batches: Batches,
        Api.vector_io: VectorIO,
        Api.models: Models,
        Api.safety: Safety,
        Api.shields: Shields,
        Api.datasetio: DatasetIO,
        Api.datasets: Datasets,
        Api.scoring: Scoring,
        Api.scoring_functions: ScoringFunctions,
        Api.eval: Eval,
        Api.benchmarks: Benchmarks,
        Api.post_training: PostTraining,
        Api.tool_groups: ToolGroups,
        Api.tool_runtime: ToolRuntime,
        Api.files: Files,
        Api.prompts: Prompts,
        Api.conversations: Conversations,
        Api.telemetry: Telemetry,
    }

    if external_apis:
        for api, api_spec in external_apis.items():
            try:
                module = importlib.import_module(api_spec.module)
                api_class = getattr(module, api_spec.protocol)

                protocols[api] = api_class
            except (ImportError, AttributeError):
                logger.exception(f"Failed to load external API {api_spec.name}")

    return protocols


def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
    external_apis = load_external_apis(config)
    return {
        **api_protocol_map(external_apis),
        Api.inference: InferenceProvider,
    }


def additional_protocols_map() -> dict[Api, Any]:
    return {
        Api.inference: (ModelsProtocolPrivate, Models, Api.models),
        Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
        Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
        Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
        Api.scoring: (
            ScoringFunctionsProtocolPrivate,
            ScoringFunctions,
            Api.scoring_functions,
        ),
        Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
    }


# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
class ProviderWithSpec(Provider):
    spec: ProviderSpec


ProviderRegistry = dict[Api, dict[str, ProviderSpec]]


async def resolve_impls(
    run_config: StackRunConfig,
    provider_registry: ProviderRegistry,
    dist_registry: DistributionRegistry,
    policy: list[AccessRule],
    internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
    """
    Resolves provider implementations by:
    1. Validating and organizing providers.
    2. Sorting them in dependency order.
    3. Instantiating them with required dependencies.
    """
    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
    router_apis = {x.router_api for x in builtin_automatically_routed_apis()}

    providers_with_specs = validate_and_prepare_providers(
        run_config, provider_registry, routing_table_apis, router_apis
    )

    apis_to_serve = run_config.apis or set(
        list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
    )

    providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))

    sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy, internal_impls)


def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
    """Generates specifications for automatically routed APIs."""
    specs = {}
    for info in builtin_automatically_routed_apis():
        if info.router_api.value not in apis_to_serve:
            continue

        specs[info.routing_table_api.value] = {
            "__builtin__": ProviderWithSpec(
                provider_id="__routing_table__",
                provider_type="__routing_table__",
                config={},
                spec=RoutingTableProviderSpec(
                    api=info.routing_table_api,
                    router_api=info.router_api,
                    module="llama_stack.core.routers",
                    api_dependencies=[],
                    deps__=[f"inner-{info.router_api.value}"],
                ),
            )
        }

        specs[info.router_api.value] = {
            "__builtin__": ProviderWithSpec(
                provider_id="__autorouted__",
                provider_type="__autorouted__",
                config={},
                spec=AutoRoutedProviderSpec(
                    api=info.router_api,
                    module="llama_stack.core.routers",
                    routing_table_api=info.routing_table_api,
                    api_dependencies=[info.routing_table_api],
                    deps__=([info.routing_table_api.value]),
                ),
            )
        }
    return specs


def validate_and_prepare_providers(
    run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
) -> dict[str, dict[str, ProviderWithSpec]]:
    """Validates providers, handles deprecations, and organizes them into a spec dictionary."""
    providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}

    for api_str, providers in run_config.providers.items():
        api = Api(api_str)
        if api in routing_table_apis:
            raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")

        specs = {}
        for provider in providers:
            if not provider.provider_id or provider.provider_id == "__disabled__":
                logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
                continue

            validate_provider(provider, api, provider_registry)
            p = provider_registry[api][provider.provider_type]
            p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
            spec = ProviderWithSpec(spec=p, **provider.model_dump())
            specs[provider.provider_id] = spec

        key = api_str if api not in router_apis else f"inner-{api_str}"
        providers_with_specs[key] = specs

    # TODO: remove this logic, telemetry should not have providers.
    # if telemetry has been enabled in the config initialize our internal impl
    # telemetry is not an external API so it SHOULD NOT be auto-routed.
    if run_config.telemetry.enabled:
        specs = {}
        p = InlineProviderSpec(
            api=Api.telemetry,
            provider_type="inline::meta-reference",
            pip_packages=[],
            optional_api_dependencies=[Api.datasetio],
            module="llama_stack.providers.inline.telemetry.meta_reference",
            config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
            description="Meta's reference implementation of telemetry and observability using OpenTelemetry.",
        )
        spec = ProviderWithSpec(spec=p, provider_type="inline::meta-reference", provider_id="meta-reference")
        specs["meta-reference"] = spec
        providers_with_specs["telemetry"] = specs

    return providers_with_specs


def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
    """Validates if the provider is allowed and handles deprecations."""
    if provider.provider_type not in provider_registry[api]:
        raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")

    p = provider_registry[api][provider.provider_type]
    if p.deprecation_error:
        logger.error(p.deprecation_error)
        raise InvalidProviderError(p.deprecation_error)
    elif p.deprecation_warning:
        logger.warning(
            f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
        )


def sort_providers_by_deps(
    providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> list[tuple[str, ProviderWithSpec]]:
    """Sorts providers based on their dependencies."""
    sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
        {k: list(v.values()) for k, v in providers_with_specs.items()}
    )

    logger.debug(f"Resolved {len(sorted_providers)} providers")
    for api_str, provider in sorted_providers:
        logger.debug(f" {api_str} => {provider.provider_id}")
    return sorted_providers


async def instantiate_providers(
    sorted_providers: list[tuple[str, ProviderWithSpec]],
    router_apis: set[Api],
    dist_registry: DistributionRegistry,
    run_config: StackRunConfig,
    policy: list[AccessRule],
    internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
    """Instantiates providers asynchronously while managing dependencies."""
    impls: dict[Api, Any] = internal_impls.copy() if internal_impls else {}
    inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
    for api_str, provider in sorted_providers:
        # Skip providers that are not enabled
        if provider.provider_id is None:
            continue

        try:
            deps = {a: impls[a] for a in provider.spec.api_dependencies}
        except KeyError as e:
            missing_api = e.args[0]
            raise RuntimeError(
                f"Failed to resolve '{provider.spec.api.value}' provider '{provider.provider_id}' of type '{provider.spec.provider_type}': "
                f"required dependency '{missing_api.value}' is not available. "
                f"Please add a '{missing_api.value}' provider to your configuration or check if the provider is properly configured."
            ) from e
        for a in provider.spec.optional_api_dependencies:
            if a in impls:
                deps[a] = impls[a]

        inner_impls = {}
        if isinstance(provider.spec, RoutingTableProviderSpec):
            inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)

        if api_str.startswith("inner-"):
            inner_impls_by_provider_id[api_str][provider.provider_id] = impl
        else:
            api = Api(api_str)
            impls[api] = impl

    return impls


def topological_sort(
    providers_with_specs: dict[str, list[ProviderWithSpec]],
) -> list[tuple[str, ProviderWithSpec]]:
    def dfs(kv, visited: set[str], stack: list[str]):
        api_str, providers = kv
        visited.add(api_str)

        deps = []
        for provider in providers:
            for dep in provider.spec.deps__:
                deps.append(dep)

        for dep in deps:
            if dep not in visited and dep in providers_with_specs:
                dfs((dep, providers_with_specs[dep]), visited, stack)

        stack.append(api_str)

    visited: set[str] = set()
    stack: list[str] = []

    for api_str, providers in providers_with_specs.items():
        if api_str not in visited:
            dfs((api_str, providers), visited, stack)

    flattened = []
    for api_str in stack:
        for provider in providers_with_specs[api_str]:
            flattened.append((api_str, provider))

    return flattened


# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
    provider: ProviderWithSpec,
    deps: dict[Api, Any],
    inner_impls: dict[str, Any],
    dist_registry: DistributionRegistry,
    run_config: StackRunConfig,
    policy: list[AccessRule],
):
    provider_spec = provider.spec
    if not hasattr(provider_spec, "module") or provider_spec.module is None:
        raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")

    logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
    module = importlib.import_module(provider_spec.module)
    args = []
    if isinstance(provider_spec, RemoteProviderSpec):
        config_type = instantiate_class_type(provider_spec.config_class)
        config = config_type(**provider.config)

        method = "get_adapter_impl"
        args = [config, deps]

    elif isinstance(provider_spec, AutoRoutedProviderSpec):
        method = "get_auto_router_impl"

        config = None
        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config, policy]
    elif isinstance(provider_spec, RoutingTableProviderSpec):
        method = "get_routing_table_impl"

        config = None
        args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
    else:
        method = "get_provider_impl"

        config_type = instantiate_class_type(provider_spec.config_class)
        config = config_type(**provider.config)
        args = [config, deps]
        if "policy" in inspect.signature(getattr(module, method)).parameters:
            args.append(policy)
        if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
            args.append(run_config.telemetry.enabled)

    fn = getattr(module, method)
    impl = await fn(*args)
    impl.__provider_id__ = provider.provider_id
    impl.__provider_spec__ = provider_spec
    impl.__provider_config__ = config

    protocols = api_protocol_map_for_compliance_check(run_config)
    additional_protocols = additional_protocols_map()
    # TODO: check compliance for special tool groups
    # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
    check_protocol_compliance(impl, protocols[provider_spec.api])
    if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
        additional_api, _, _ = additional_protocols[provider_spec.api]
        check_protocol_compliance(impl, additional_api)

    return impl


def check_protocol_compliance(obj: Any, protocol: Any) -> None:
    missing_methods = []

    mro = type(obj).__mro__
    for name, value in inspect.getmembers(protocol):
        if inspect.isfunction(value) and hasattr(value, "__webmethods__"):
            has_alpha_api = False
            for webmethod in value.__webmethods__:
                if webmethod.level == LLAMA_STACK_API_V1ALPHA:
                    has_alpha_api = True
                    break
            # if this API has multiple webmethods, and one of them is an alpha API, this API should be skipped when checking for missing or not callable routes
            if has_alpha_api:
                continue
            if not hasattr(obj, name):
                missing_methods.append((name, "missing"))
            elif not callable(getattr(obj, name)):
                missing_methods.append((name, "not_callable"))
            else:
                # Check if the method signatures are compatible
                obj_method = getattr(obj, name)
                proto_sig = inspect.signature(value)
                obj_sig = inspect.signature(obj_method)
                proto_params = set(proto_sig.parameters)
                proto_params.discard("self")
                obj_params = set(obj_sig.parameters)
                obj_params.discard("self")
                if not (proto_params <= obj_params):
                    logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
                    missing_methods.append((name, "signature_mismatch"))
                else:
                    # Check if the method has a concrete implementation (not just a protocol stub)
                    # Find all classes in MRO that define this method
                    method_owners = [cls for cls in mro if name in cls.__dict__]

                    # Allow methods from mixins/parents, only reject if ONLY the protocol defines it
                    if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
                        # Only reject if the method is ONLY defined in the protocol itself (abstract stub)
                        missing_methods.append((name, "not_actually_implemented"))

    if missing_methods:
        raise ValueError(
            f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
        )


async def resolve_remote_stack_impls(
    config: RemoteProviderConfig,
    apis: list[str],
) -> dict[Api, Any]:
    protocols = api_protocol_map()
    additional_protocols = additional_protocols_map()

    impls = {}
    for api_str in apis:
        api = Api(api_str)
        impls[api] = await get_client_impl(
            protocols[api],
            config,
            {},
        )
        if api in additional_protocols:
            _, additional_protocol, additional_api = additional_protocols[api]
            impls[additional_api] = await get_client_impl(
                additional_protocol,
                config,
                {},
            )

    return impls
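
The ordering step in `topological_sort` is a depth-first, post-order walk over the `deps__` lists: a key is emitted only after the keys it depends on. The following is a simplified, standalone sketch of that idea, not the resolver's code. It assumes plain string keys instead of `ProviderWithSpec` objects, and the `dependency_order` function and the sample `graph` are hypothetical names introduced for illustration.

```python
def dependency_order(deps_by_key: dict[str, list[str]]) -> list[str]:
    """Return keys so that each key appears after the dependencies it lists."""
    visited: set[str] = set()
    order: list[str] = []

    def dfs(key: str) -> None:
        visited.add(key)
        for dep in deps_by_key[key]:
            # Dependencies with no entry of their own are skipped,
            # mirroring the `dep in providers_with_specs` check above.
            if dep not in visited and dep in deps_by_key:
                dfs(dep)
        # Post-order: emit the key once its dependencies have been emitted.
        order.append(key)

    for key in deps_by_key:
        if key not in visited:
            dfs(key)
    return order


# Hypothetical graph shaped like the auto-routed specs: the router depends on
# its routing table, which in turn depends on the "inner-" providers.
graph = {
    "inference": ["models"],
    "models": ["inner-inference"],
    "inner-inference": [],
    "agents": ["inference", "safety"],  # "safety" is not configured here
}
order = dependency_order(graph)
print(order)  # ['inner-inference', 'models', 'inference', 'agents']
```

Like the function above, this sketch does not detect cycles; the `visited` set only prevents infinite recursion, so a cyclic configuration would produce some order without raising an error.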