optional api dependencies

This commit is contained in:
Ashwin Bharambe 2025-01-16 15:13:42 -08:00 committed by Dinesh Yeduguru
parent 1f60c0286d
commit 65e64f6877
5 changed files with 11 additions and 2 deletions

View file

@ -229,6 +229,9 @@ async def resolve_impls(
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis} inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers: for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies} deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies:
if a in impls:
deps[a] = impls[a]
inner_impls = {} inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec): if isinstance(provider.spec, RoutingTableProviderSpec):

View file

@ -96,6 +96,9 @@ class ProviderSpec(BaseModel):
default_factory=list, default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality", description="Higher-level API surfaces may depend on other providers to provide their functionality",
) )
optional_api_dependencies: List[Api] = Field(
default_factory=list,
)
deprecation_warning: Optional[str] = Field( deprecation_warning: Optional[str] = Field(
default=None, default=None,
description="If this provider is deprecated, specify the warning message here", description="If this provider is deprecated, specify the warning message here",

View file

@ -72,7 +72,7 @@ def is_tracing_enabled(tracer):
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None: def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
self.config = config self.config = config
self.datasetio_api = deps[Api.datasetio] self.datasetio_api = deps.get(Api.datasetio)
resource = Resource.create( resource = Resource.create(
{ {

View file

@ -24,7 +24,7 @@ def available_providers() -> List[ProviderSpec]:
"opentelemetry-sdk", "opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http", "opentelemetry-exporter-otlp-proto-http",
], ],
api_dependencies=[Api.datasetio], optional_api_dependencies=[Api.datasetio],
module="llama_stack.providers.inline.telemetry.meta_reference", module="llama_stack.providers.inline.telemetry.meta_reference",
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig", config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
), ),

View file

@ -22,6 +22,9 @@ class TelemetryDatasetMixin:
dataset_id: str, dataset_id: str,
max_depth: Optional[int] = None, max_depth: Optional[int] = None,
) -> None: ) -> None:
if self.datasetio_api is None:
raise RuntimeError("DatasetIO API not available")
spans = await self.query_spans( spans = await self.query_spans(
attribute_filters=attribute_filters, attribute_filters=attribute_filters,
attributes_to_return=attributes_to_save, attributes_to_return=attributes_to_save,