From eb60f04f86a1b9da20e64f6d466f8445416a0c95 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 17 Jan 2025 15:26:53 -0800 Subject: [PATCH] optional api dependencies (#793) Co-authored-by: Dinesh Yeduguru --- llama_stack/distribution/resolver.py | 9 +++++++-- llama_stack/providers/datatypes.py | 3 +++ .../inline/telemetry/meta_reference/telemetry.py | 2 +- llama_stack/providers/registry/telemetry.py | 2 +- llama_stack/providers/utils/telemetry/dataset_mixin.py | 3 +++ 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index d7e947a46..204555b16 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -145,7 +145,9 @@ async def resolve_impls( log.warning( f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", ) - p.deps__ = [a.value for a in p.api_dependencies] + p.deps__ = [a.value for a in p.api_dependencies] + [ + a.value for a in p.optional_api_dependencies + ] spec = ProviderWithSpec( spec=p, **(provider.model_dump()), @@ -229,6 +231,9 @@ async def resolve_impls( inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis} for api_str, provider in sorted_providers: deps = {a: impls[a] for a in provider.spec.api_dependencies} + for a in provider.spec.optional_api_dependencies: + if a in impls: + deps[a] = impls[a] inner_impls = {} if isinstance(provider.spec, RoutingTableProviderSpec): @@ -265,7 +270,7 @@ def topological_sort( deps.append(dep) for dep in deps: - if dep not in visited: + if dep not in visited and dep in providers_with_specs: dfs((dep, providers_with_specs[dep]), visited, stack) stack.append(api_str) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index ce0c9f52e..3e64a62a1 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -96,6 +96,9 @@ class ProviderSpec(BaseModel): default_factory=list, description="Higher-level API surfaces may depend on other providers to provide their functionality", ) + optional_api_dependencies: List[Api] = Field( + default_factory=list, + ) deprecation_warning: Optional[str] = Field( default=None, description="If this provider is deprecated, specify the warning message here", diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 4875f8cf0..aeeed1ac0 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -72,7 +72,7 @@ def is_tracing_enabled(tracer): class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None: self.config = config - self.datasetio_api = deps[Api.datasetio] + self.datasetio_api = deps.get(Api.datasetio) resource = Resource.create( { diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index ba7e2f806..f3b41374c 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -24,7 +24,7 @@ def available_providers() -> List[ProviderSpec]: "opentelemetry-sdk", "opentelemetry-exporter-otlp-proto-http", ], - api_dependencies=[Api.datasetio], + optional_api_dependencies=[Api.datasetio], module="llama_stack.providers.inline.telemetry.meta_reference", config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig", ), diff --git a/llama_stack/providers/utils/telemetry/dataset_mixin.py b/llama_stack/providers/utils/telemetry/dataset_mixin.py index 6806f39aa..a2bfdcb87 100644 --- a/llama_stack/providers/utils/telemetry/dataset_mixin.py +++ b/llama_stack/providers/utils/telemetry/dataset_mixin.py @@ -22,6 +22,9 @@ class TelemetryDatasetMixin: dataset_id: str, max_depth: Optional[int] = None, ) -> None: + if self.datasetio_api is None: + raise RuntimeError("DatasetIO API not available") + spans = await self.query_spans( attribute_filters=attribute_filters, attributes_to_return=attributes_to_save,