mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
# What does this PR do? Fixes: https://github.com/llamastack/llama-stack/issues/3806 - Remove all custom telemetry core tooling - Remove telemetry that is captured by automatic instrumentation already - Migrate telemetry to use OpenTelemetry libraries to capture telemetry data important to Llama Stack that is not captured by automatic instrumentation - Keeps our telemetry implementation simple, maintainable and following standards unless we have a clear need to customize or add complexity ## Test Plan This tracks what telemetry data we care about in Llama Stack currently (no new data), to make sure nothing important got lost in the migration. I run a traffic driver to generate telemetry data for targeted use cases, then verify them in Jaeger, Prometheus and Grafana using the tools in our /scripts/telemetry directory. ### Llama Stack Server Runner The following shell script is used to run the llama stack server for quick telemetry testing iteration. ```sh export OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4318" export OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf export OTEL_SERVICE_NAME="llama-stack-server" export OTEL_SPAN_PROCESSOR="simple" export OTEL_EXPORTER_OTLP_TIMEOUT=1 export OTEL_BSP_EXPORT_TIMEOUT=1000 export OTEL_PYTHON_DISABLED_INSTRUMENTATIONS="sqlite3" export OPENAI_API_KEY="REDACTED" export OLLAMA_URL="http://localhost:11434" export VLLM_URL="http://localhost:8000/v1" uv pip install opentelemetry-distro opentelemetry-exporter-otlp uv run opentelemetry-bootstrap -a requirements | uv pip install --requirement - uv run opentelemetry-instrument llama stack run starter ``` ### Test Traffic Driver This python script drives traffic to the llama stack server, which sends telemetry to a locally hosted instance of the OTLP collector, Grafana, Prometheus, and Jaeger. 
```sh export OTEL_SERVICE_NAME="openai-client" export OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf export OTEL_EXPORTER_OTLP_ENDPOINT="http://127.0.0.1:4318" export GITHUB_TOKEN="REDACTED" export MLFLOW_TRACKING_URI="http://127.0.0.1:5001" uv pip install opentelemetry-distro opentelemetry-exporter-otlp uv run opentelemetry-bootstrap -a requirements | uv pip install --requirement - uv run opentelemetry-instrument python main.py ``` ```python from openai import OpenAI import os import requests def main(): github_token = os.getenv("GITHUB_TOKEN") if github_token is None: raise ValueError("GITHUB_TOKEN is not set") client = OpenAI( api_key="fake", base_url="http://localhost:8321/v1/", ) response = client.chat.completions.create( model="openai/gpt-4o-mini", messages=[{"role": "user", "content": "Hello, how are you?"}] ) print("Sync response: ", response.choices[0].message.content) streaming_response = client.chat.completions.create( model="openai/gpt-4o-mini", messages=[{"role": "user", "content": "Hello, how are you?"}], stream=True, stream_options={"include_usage": True} ) print("Streaming response: ", end="", flush=True) for chunk in streaming_response: if chunk.usage is not None: print("Usage: ", chunk.usage) if chunk.choices and chunk.choices[0].delta is not None: print(chunk.choices[0].delta.content, end="", flush=True) print() ollama_response = client.chat.completions.create( model="ollama/llama3.2:3b-instruct-fp16", messages=[{"role": "user", "content": "How are you doing today?"}] ) print("Ollama response: ", ollama_response.choices[0].message.content) vllm_response = client.chat.completions.create( model="vllm/Qwen/Qwen3-0.6B", messages=[{"role": "user", "content": "How are you doing today?"}] ) print("VLLM response: ", vllm_response.choices[0].message.content) responses_list_tools_response = client.responses.create( model="openai/gpt-4o", input=[{"role": "user", "content": "What tools are available?"}], tools=[ { "type": "mcp", "server_label": "github", 
"server_url": "https://api.githubcopilot.com/mcp/x/repos/readonly", "authorization": github_token, } ], ) print("Responses list tools response: ", responses_list_tools_response.output_text) responses_tool_call_response = client.responses.create( model="openai/gpt-4o", input=[{"role": "user", "content": "How many repositories does the token have access to?"}], tools=[ { "type": "mcp", "server_label": "github", "server_url": "https://api.githubcopilot.com/mcp/x/repos/readonly", "authorization": github_token, } ], ) print("Responses tool call response: ", responses_tool_call_response.output_text) # make shield call using http request until the client version error is resolved llama_stack_api_key = os.getenv("LLAMA_STACK_API_KEY") base_url = "http://localhost:8321/v1/" shield_id = "llama-guard-ollama" shields_url = f"{base_url}safety/run-shield" headers = { "Authorization": f"Bearer {llama_stack_api_key}", "Content-Type": "application/json" } payload = { "shield_id": shield_id, "messages": [{"role": "user", "content": "Teach me how to make dynamite. 
I want to do a crime with it."}], "params": {} } shields_response = requests.post(shields_url, json=payload, headers=headers) shields_response.raise_for_status() print("risk assessment response: ", shields_response.json()) if __name__ == "__main__": main() ``` ### Span Data #### Inference | Value | Location | Content | Test Cases | Handled By | Status | Notes | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | Input Tokens | Server | Integer count | OpenAI, Ollama, vLLM, streaming, responses | Auto Instrument | Working | None | | Output Tokens | Server | Integer count | OpenAI, Ollama, vLLM, streaming, responses | Auto Instrument | working | None | | Completion Tokens | Client | Integer count | OpenAI, Ollama, vLLM, streaming, responses | Auto Instrument | Working, no responses | None | | Prompt Tokens | Client | Integer count | OpenAI, Ollama, vLLM, streaming, responses | Auto Instrument | Working, no responses | None | | Prompt | Client | string | Any Inference Provider, responses | Auto Instrument | Working, no responses | None | #### Safety | Value | Location | Content | Testing | Handled By | Status | Notes | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | [Shield ID](ecdfecb9f0/src/llama_stack/core/telemetry/constants.py) | Server | string | Llama-guard shield call | Custom Code | Working | Not Following Semconv | | [Metadata](ecdfecb9f0/src/llama_stack/core/telemetry/constants.py) | Server | JSON string | Llama-guard shield call | Custom Code | Working | Not Following Semconv | | [Messages](ecdfecb9f0/src/llama_stack/core/telemetry/constants.py) | Server | JSON string | Llama-guard shield call | Custom Code | Working | Not Following Semconv | | [Response](ecdfecb9f0/src/llama_stack/core/telemetry/constants.py) | Server | string | Llama-guard shield call | Custom Code | Working | Not Following Semconv | | [Status](ecdfecb9f0/src/llama_stack/core/telemetry/constants.py) | Server | string | Llama-guard shield call | Custom Code | Working | 
Not Following Semconv | #### Remote Tool Listing & Execution | Value | Location | Content | Testing | Handled By | Status | Notes | | ----- | :---: | :---: | :---: | :---: | :---: | :---: | | Tool name | server | string | Tool call occurs | Custom Code | working | [Not following semconv](https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span) | | Server URL | server | string | List tools or execute tool call | Custom Code | working | [Not following semconv](https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span) | | Server Label | server | string | List tools or execute tool call | Custom code | working | [Not following semconv](https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span) | | mcp\_list\_tools\_id | server | string | List tools | Custom code | working | [Not following semconv](https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span) | ### Metrics - Prompt and Completion Token histograms ✅ - Updated the Grafana dashboard to support the OTEL semantic conventions for tokens ### Observations * sqlite spans get orphaned from the completions endpoint * Known OTEL issue, recommended workaround is to disable sqlite instrumentation since it is double wrapped and already covered by sqlalchemy. This is covered in documentation. ```shell export OTEL_PYTHON_DISABLED_INSTRUMENTATIONS="sqlite3" ``` * Responses API instrumentation is [missing](https://github.com/open-telemetry/opentelemetry-python-contrib/issues/3436) in OpenTelemetry for OpenAI clients, even with traceloop or openllmetry * Upstream issues in opentelemetry-python-contrib * Span created for each streaming response, so each chunk → very large spans get created, which is not ideal, but it’s the intended behavior * MCP telemetry needs to be updated to follow semantic conventions. We can probably use a library for this and handle it in a separate issue. 
### Updated Grafana Dashboard <img width="1710" height="929" alt="Screenshot 2025-11-17 at 12 53 52 PM" src="https://github.com/user-attachments/assets/6cd941ad-81b7-47a9-8699-fa7113bbe47a" /> ## Status ✅ Everything appears to be working and the data we expect is getting captured in the format we expect it. ## Follow Ups 1. Make tool calling spans follow semconv and capture more data 1. Consider using existing tracing library 2. Make shield spans follow semconv 3. Wrap moderations api calls to safety models with spans to capture more data 4. Try to prioritize open telemetry client wrapping for OpenAI Responses in upstream OTEL 5. This would break the telemetry tests, and they are currently disabled. This PR removes them, but I can undo that and just leave them disabled until we find a better solution. 6. Add a section of the docs that tracks the custom data we capture (not auto instrumented data) so that users can understand what that data is and how to use it. Commit those changes to the OTEL-gen_ai SIG if possible as well. Here is an [example](https://opentelemetry.io/docs/specs/semconv/gen-ai/aws-bedrock/) of how bedrock handles it.
484 lines
18 KiB
Python
484 lines
18 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
import importlib
|
|
import importlib.metadata
|
|
import inspect
|
|
from typing import Any
|
|
|
|
from llama_stack.core.client import get_client_impl
|
|
from llama_stack.core.datatypes import (
|
|
AccessRule,
|
|
AutoRoutedProviderSpec,
|
|
Provider,
|
|
RoutingTableProviderSpec,
|
|
StackRunConfig,
|
|
)
|
|
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
|
from llama_stack.core.external import load_external_apis
|
|
from llama_stack.core.store import DistributionRegistry
|
|
from llama_stack.core.utils.dynamic import instantiate_class_type
|
|
from llama_stack.log import get_logger
|
|
from llama_stack_api import (
|
|
LLAMA_STACK_API_V1ALPHA,
|
|
Agents,
|
|
Api,
|
|
Batches,
|
|
Benchmarks,
|
|
BenchmarksProtocolPrivate,
|
|
Conversations,
|
|
DatasetIO,
|
|
Datasets,
|
|
DatasetsProtocolPrivate,
|
|
Eval,
|
|
ExternalApiSpec,
|
|
Files,
|
|
Inference,
|
|
InferenceProvider,
|
|
Inspect,
|
|
Models,
|
|
ModelsProtocolPrivate,
|
|
PostTraining,
|
|
Prompts,
|
|
ProviderSpec,
|
|
RemoteProviderConfig,
|
|
RemoteProviderSpec,
|
|
Safety,
|
|
Scoring,
|
|
ScoringFunctions,
|
|
ScoringFunctionsProtocolPrivate,
|
|
Shields,
|
|
ShieldsProtocolPrivate,
|
|
ToolGroups,
|
|
ToolGroupsProtocolPrivate,
|
|
ToolRuntime,
|
|
VectorIO,
|
|
VectorStore,
|
|
)
|
|
from llama_stack_api import (
|
|
Providers as ProvidersAPI,
|
|
)
|
|
|
|
# Module-level logger; everything logged from this module is emitted under the "core" category.
logger = get_logger(name=__name__, category="core")
|
|
|
|
|
|
class InvalidProviderError(Exception):
    """Raised by `validate_provider` when a provider's spec carries a deprecation error."""
|
|
|
|
|
|
def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
    """Build the table that maps every known API to its protocol class.

    Args:
        external_apis: Optional mapping of external APIs to their specs. For each
            entry, ``spec.module`` is imported and the attribute named by
            ``spec.protocol`` is registered as that API's protocol class.

    Returns:
        A new dictionary mapping each API to its protocol class. External APIs
        whose module or protocol attribute cannot be loaded are logged and left out.
    """
    protocol_by_api = {
        Api.providers: ProvidersAPI,
        Api.agents: Agents,
        Api.inference: Inference,
        Api.inspect: Inspect,
        Api.batches: Batches,
        Api.vector_io: VectorIO,
        Api.vector_stores: VectorStore,
        Api.models: Models,
        Api.safety: Safety,
        Api.shields: Shields,
        Api.datasetio: DatasetIO,
        Api.datasets: Datasets,
        Api.scoring: Scoring,
        Api.scoring_functions: ScoringFunctions,
        Api.eval: Eval,
        Api.benchmarks: Benchmarks,
        Api.post_training: PostTraining,
        Api.tool_groups: ToolGroups,
        Api.tool_runtime: ToolRuntime,
        Api.files: Files,
        Api.prompts: Prompts,
        Api.conversations: Conversations,
    }

    # A missing or empty mapping simply contributes nothing.
    for ext_api, ext_spec in (external_apis or {}).items():
        try:
            ext_module = importlib.import_module(ext_spec.module)
            protocol_by_api[ext_api] = getattr(ext_module, ext_spec.protocol)
        except (ImportError, AttributeError):
            # Best effort: one broken external API must not prevent the rest from loading.
            logger.exception(f"Failed to load external API {ext_spec.name}")

    return protocol_by_api
|
|
|
|
|
|
def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
    """Return the protocol map used when checking provider implementations.

    This is the regular `api_protocol_map` (including the external APIs loaded
    from ``config``), except that ``Api.inference`` is mapped to
    ``InferenceProvider`` instead of ``Inference``.
    """
    protocols = api_protocol_map(load_external_apis(config))
    # api_protocol_map builds a fresh dict on every call, so overriding in place is safe.
    protocols[Api.inference] = InferenceProvider
    return protocols
|
|
|
|
|
|
def additional_protocols_map() -> dict[Api, Any]:
    """Map an API to the extra protocol triple associated with it.

    Each value is a 3-tuple. Within this module, `instantiate_provider` checks a
    provider implementation against the first element, while
    `resolve_remote_stack_impls` builds a client for the second element and
    registers it under the third element (an `Api`).
    """
    entries = [
        (Api.inference, (ModelsProtocolPrivate, Models, Api.models)),
        (Api.tool_groups, (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups)),
        (Api.safety, (ShieldsProtocolPrivate, Shields, Api.shields)),
        (Api.datasetio, (DatasetsProtocolPrivate, Datasets, Api.datasets)),
        (Api.scoring, (ScoringFunctionsProtocolPrivate, ScoringFunctions, Api.scoring_functions)),
        (Api.eval, (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks)),
    ]
    return dict(entries)
|
|
|
|
|
|
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
# A `Provider` entry paired with the `ProviderSpec` it was matched to; built in
# `validate_and_prepare_providers` and `specs_for_autorouted_apis`.
class ProviderWithSpec(Provider):
    # The spec for this provider; the resolver reads its `api`, `module`,
    # `api_dependencies` and `deps__` fields.
    spec: ProviderSpec
|
|
|
|
|
|
# Available provider specs, indexed first by API and then by provider type
# (looked up as `provider_registry[api][provider.provider_type]`).
ProviderRegistry = dict[Api, dict[str, ProviderSpec]]
|
|
|
|
|
|
async def resolve_impls(
    run_config: StackRunConfig,
    provider_registry: ProviderRegistry,
    dist_registry: DistributionRegistry,
    policy: list[AccessRule],
    internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
    """
    Resolve the provider implementations for a run configuration.

    The work happens in three stages:
    1. Validate the configured providers and pair them with their specs.
    2. Order them so that dependencies come before their dependents.
    3. Instantiate them, handing each one the dependencies it asked for.
    """
    table_apis = {info.routing_table_api for info in builtin_automatically_routed_apis()}
    routers = {info.router_api for info in builtin_automatically_routed_apis()}

    prepared = validate_and_prepare_providers(run_config, provider_registry, table_apis, routers)

    # With no explicit API list, serve everything configured plus all auto-routed APIs.
    apis_to_serve = run_config.apis
    if not apis_to_serve:
        apis_to_serve = set(
            [*prepared.keys(), *(api.value for api in table_apis), *(api.value for api in routers)]
        )

    prepared.update(specs_for_autorouted_apis(apis_to_serve))

    ordered = sort_providers_by_deps(prepared, run_config)

    return await instantiate_providers(ordered, routers, dist_registry, run_config, policy, internal_impls)
|
|
|
|
|
|
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
    """Create the built-in routing-table and router entries for auto-routed APIs.

    For every automatically routed API whose router API is listed in
    ``apis_to_serve``, two entries are produced: one keyed by the routing table
    API's value and one keyed by the router API's value. Each holds a single
    ``"__builtin__"`` provider.
    """
    routed_specs = {}
    for routed in builtin_automatically_routed_apis():
        router_name = routed.router_api.value
        if router_name not in apis_to_serve:
            continue

        # The routing table depends on the "inner" (actual) providers of the routed API.
        table_spec = RoutingTableProviderSpec(
            api=routed.routing_table_api,
            router_api=routed.router_api,
            module="llama_stack.core.routers",
            api_dependencies=[],
            deps__=[f"inner-{router_name}"],
        )
        routed_specs[routed.routing_table_api.value] = {
            "__builtin__": ProviderWithSpec(
                provider_id="__routing_table__",
                provider_type="__routing_table__",
                config={},
                spec=table_spec,
            )
        }

        # The router itself depends on its routing table.
        router_spec = AutoRoutedProviderSpec(
            api=routed.router_api,
            module="llama_stack.core.routers",
            routing_table_api=routed.routing_table_api,
            api_dependencies=[routed.routing_table_api],
            deps__=([routed.routing_table_api.value]),
        )
        routed_specs[router_name] = {
            "__builtin__": ProviderWithSpec(
                provider_id="__autorouted__",
                provider_type="__autorouted__",
                config={},
                spec=router_spec,
            )
        }
    return routed_specs
|
|
|
|
|
|
def validate_and_prepare_providers(
    run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
) -> dict[str, dict[str, ProviderWithSpec]]:
    """Validate the configured providers and group them, with their specs, by API.

    Providers without a provider id, or with the id ``"__disabled__"``, are skipped.
    Providers of a router API are stored under the key ``inner-<api>``; all
    others are stored under the API name itself.

    Raises:
        ValueError: If providers are configured for a routing table API, or a
            provider fails `validate_provider`.
    """
    prepared: dict[str, dict[str, ProviderWithSpec]] = {}

    for api_name, configured in run_config.providers.items():
        api = Api(api_name)
        if api in routing_table_apis:
            raise ValueError(f"Provider for `{api_name}` is automatically provided and cannot be overridden")

        enabled = {}
        for entry in configured:
            is_disabled = not entry.provider_id or entry.provider_id == "__disabled__"
            if is_disabled:
                logger.debug(f"Provider `{entry.provider_type}` for API `{api}` is disabled")
                continue

            validate_provider(entry, api, provider_registry)
            registry_spec = provider_registry[api][entry.provider_type]
            # NOTE: this writes the combined dependency names onto the spec object held by the registry.
            required = [dep.value for dep in registry_spec.api_dependencies]
            optional = [dep.value for dep in registry_spec.optional_api_dependencies]
            registry_spec.deps__ = required + optional
            enabled[entry.provider_id] = ProviderWithSpec(spec=registry_spec, **entry.model_dump())

        group_key = f"inner-{api_name}" if api in router_apis else api_name
        prepared[group_key] = enabled

    return prepared
|
|
|
|
|
|
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
    """Validates if the provider is allowed and handles deprecations.

    Args:
        provider: The configured provider; only its ``provider_type`` is read.
        api: The API the provider is configured for.
        provider_registry: Available provider specs, keyed by API and then by provider type.

    Raises:
        ValueError: If the provider type is not registered for ``api``. This also
            covers the case where the registry has no entry for ``api`` at all.
        InvalidProviderError: If the matching spec carries a ``deprecation_error``.
    """
    # Use .get() so that an API missing from the registry yields the descriptive
    # ValueError below rather than a bare KeyError.
    available = provider_registry.get(api, {})
    if provider.provider_type not in available:
        raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")

    p = available[provider.provider_type]
    if p.deprecation_error:
        logger.error(p.deprecation_error)
        raise InvalidProviderError(p.deprecation_error)
    elif p.deprecation_warning:
        logger.warning(
            f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
        )
|
|
|
|
|
|
def sort_providers_by_deps(
    providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> list[tuple[str, ProviderWithSpec]]:
    """Flatten the per-API provider groups into a dependency-ordered list.

    Returns ``(api key, provider)`` pairs in the order produced by
    `topological_sort`. ``run_config`` is accepted but not read here.
    """
    grouped = {api_key: list(by_id.values()) for api_key, by_id in providers_with_specs.items()}
    ordered: list[tuple[str, ProviderWithSpec]] = topological_sort(grouped)

    logger.debug(f"Resolved {len(ordered)} providers")
    for api_key, entry in ordered:
        logger.debug(f" {api_key} => {entry.provider_id}")
    return ordered
|
|
|
|
|
|
async def instantiate_providers(
    sorted_providers: list[tuple[str, ProviderWithSpec]],
    router_apis: set[Api],
    dist_registry: DistributionRegistry,
    run_config: StackRunConfig,
    policy: list[AccessRule],
    internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
    """Instantiate providers one by one, in the given (dependency) order.

    Implementations listed under an ``inner-<api>`` key are collected per
    provider id and later handed to the matching routing table; all other
    implementations are stored in the returned mapping under their `Api`.

    Raises:
        RuntimeError: If a required API dependency has no implementation yet.
    """
    resolved: dict[Api, Any] = internal_impls.copy() if internal_impls else {}
    inner_by_router: dict[str, dict[str, Any]] = {f"inner-{api.value}": {} for api in router_apis}

    for api_key, entry in sorted_providers:
        # Skip providers that are not enabled
        if entry.provider_id is None:
            continue

        try:
            wired = {dep: resolved[dep] for dep in entry.spec.api_dependencies}
        except KeyError as e:
            missing_api = e.args[0]
            raise RuntimeError(
                f"Failed to resolve '{entry.spec.api.value}' provider '{entry.provider_id}' of type '{entry.spec.provider_type}': "
                f"required dependency '{missing_api.value}' is not available. "
                f"Please add a '{missing_api.value}' provider to your configuration or check if the provider is properly configured."
            ) from e

        # Optional dependencies are wired in only when they already exist.
        wired.update({dep: resolved[dep] for dep in entry.spec.optional_api_dependencies if dep in resolved})

        # Routing tables receive the inner implementations of the API they route for.
        if isinstance(entry.spec, RoutingTableProviderSpec):
            inner = inner_by_router[f"inner-{entry.spec.router_api.value}"]
        else:
            inner = {}

        impl = await instantiate_provider(entry, wired, inner, dist_registry, run_config, policy)

        if not api_key.startswith("inner-"):
            resolved[Api(api_key)] = impl
        else:
            inner_by_router[api_key][entry.provider_id] = impl

    return resolved
|
|
|
|
|
|
def topological_sort(
    providers_with_specs: dict[str, list[ProviderWithSpec]],
) -> list[tuple[str, ProviderWithSpec]]:
    """Order providers so that each API key comes after the keys it depends on.

    Dependencies are read from ``provider.spec.deps__`` of every provider under a
    key. Dependency names that are not keys of ``providers_with_specs`` are
    ignored, and keys already visited are not visited again.

    Returns:
        ``(api key, provider)`` pairs, grouped by key in post-order of a
        depth-first walk.
    """
    seen: set[str] = set()
    order: list[str] = []

    def visit(api_key: str) -> None:
        seen.add(api_key)

        # Collect all dependency names up front, then walk them in that order.
        wanted = [dep for entry in providers_with_specs[api_key] for dep in entry.spec.deps__]
        for dep in wanted:
            if dep in seen or dep not in providers_with_specs:
                continue
            visit(dep)

        # Post-order: a key is emitted only after everything it depends on.
        order.append(api_key)

    for api_key in providers_with_specs:
        if api_key not in seen:
            visit(api_key)

    return [(api_key, entry) for api_key in order for entry in providers_with_specs[api_key]]
|
|
|
|
|
|
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
    provider: ProviderWithSpec,
    deps: dict[Api, Any],
    inner_impls: dict[str, Any],
    dist_registry: DistributionRegistry,
    run_config: StackRunConfig,
    policy: list[AccessRule],
):
    """Import a provider's module, call its factory function and verify the result.

    The factory name and its arguments depend on the kind of spec:
    remote specs use ``get_adapter_impl``, auto-routed specs use
    ``get_auto_router_impl``, routing-table specs use ``get_routing_table_impl``
    and everything else uses ``get_provider_impl``.

    The returned implementation is tagged with ``__provider_id__``,
    ``__provider_spec__`` and ``__provider_config__`` and checked against the
    protocol(s) registered for its API.

    Raises:
        AttributeError: If the spec has no ``module`` (or it is ``None``).
        ValueError: If the implementation fails `check_protocol_compliance`.
    """
    spec = provider.spec
    if getattr(spec, "module", None) is None:
        raise AttributeError(f"ProviderSpec of type {type(spec)} does not have a 'module' attribute")

    logger.debug(f"Instantiating provider {provider.provider_id} from {spec.module}")
    module = importlib.import_module(spec.module)

    # Routers and routing tables have no provider config of their own.
    config = None
    if isinstance(spec, RemoteProviderSpec):
        factory_name = "get_adapter_impl"
        config = instantiate_class_type(spec.config_class)(**provider.config)
        call_args = [config, deps]
    elif isinstance(spec, AutoRoutedProviderSpec):
        factory_name = "get_auto_router_impl"
        call_args = [spec.api, deps[spec.routing_table_api], deps, run_config, policy]
    elif isinstance(spec, RoutingTableProviderSpec):
        factory_name = "get_routing_table_impl"
        call_args = [spec.api, inner_impls, deps, dist_registry, policy]
    else:
        factory_name = "get_provider_impl"
        config = instantiate_class_type(spec.config_class)(**provider.config)
        call_args = [config, deps]
        # Only pass the policy to factories that declare a `policy` parameter.
        if "policy" in inspect.signature(getattr(module, factory_name)).parameters:
            call_args.append(policy)

    factory = getattr(module, factory_name)
    impl = await factory(*call_args)
    impl.__provider_id__ = provider.provider_id
    impl.__provider_spec__ = spec
    impl.__provider_config__ = config

    protocols = api_protocol_map_for_compliance_check(run_config)
    additional_protocols = additional_protocols_map()
    # TODO: check compliance for special tool groups
    # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
    check_protocol_compliance(impl, protocols[spec.api])
    if not isinstance(spec, AutoRoutedProviderSpec) and spec.api in additional_protocols:
        extra_protocol, _, _ = additional_protocols[spec.api]
        check_protocol_compliance(impl, extra_protocol)

    return impl
|
|
|
|
|
|
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
    """Verify that ``obj`` implements the web methods declared on ``protocol``.

    Only protocol functions carrying a ``__webmethods__`` attribute are checked,
    and any function with at least one alpha-level web method is skipped. For
    each remaining function, ``obj`` must have a callable attribute of the same
    name whose parameters include all of the protocol's parameters, and that
    attribute must not be defined solely by the protocol class itself.

    Raises:
        ValueError: If any method is missing, not callable, has an incompatible
            signature, or is only the protocol's own stub.
    """
    problems = []
    obj_mro = type(obj).__mro__

    for name, member in inspect.getmembers(protocol):
        if not (inspect.isfunction(member) and hasattr(member, "__webmethods__")):
            continue

        # if this API has multiple webmethods, and one of them is an alpha API, this API should be skipped when checking for missing or not callable routes
        if any(webmethod.level == LLAMA_STACK_API_V1ALPHA for webmethod in member.__webmethods__):
            continue

        if not hasattr(obj, name):
            problems.append((name, "missing"))
            continue

        implementation = getattr(obj, name)
        if not callable(implementation):
            problems.append((name, "not_callable"))
            continue

        # Check if the method signatures are compatible
        expected = set(inspect.signature(member).parameters)
        expected.discard("self")
        actual = set(inspect.signature(implementation).parameters)
        actual.discard("self")
        if not expected.issubset(actual):
            logger.error(f"Method {name} incompatible proto: {expected} vs. obj: {actual}")
            problems.append((name, "signature_mismatch"))
            continue

        # Check if the method has a concrete implementation (not just a protocol stub):
        # find all classes in the MRO that define this method.
        owners = [cls for cls in obj_mro if name in cls.__dict__]

        # Allow methods from mixins/parents, only reject if ONLY the protocol defines it
        if len(owners) == 1 and owners[0].__name__ == protocol.__name__:
            problems.append((name, "not_actually_implemented"))

    if problems:
        raise ValueError(
            f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{problems}"
        )
|
|
|
|
|
|
async def resolve_remote_stack_impls(
    config: RemoteProviderConfig,
    apis: list[str],
) -> dict[Api, Any]:
    """Create client implementations for the requested APIs.

    For each API name a client is built from that API's protocol. When the API
    also has an entry in `additional_protocols_map`, a second client is built
    from the entry's protocol and stored under the entry's API.
    """
    protocol_by_api = api_protocol_map()
    extras = additional_protocols_map()

    clients = {}
    for api_name in apis:
        api = Api(api_name)
        clients[api] = await get_client_impl(protocol_by_api[api], config, {})

        if api not in extras:
            continue
        _, extra_protocol, extra_api = extras[api]
        clients[extra_api] = await get_client_impl(extra_protocol, config, {})

    return clients
|