optional api dependencies (#793)

Co-authored-by: Dinesh Yeduguru <yvdinesh@gmail.com>
This commit is contained in:
Ashwin Bharambe 2025-01-17 15:26:53 -08:00 committed by GitHub
parent 1f60c0286d
commit eb60f04f86
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 15 additions and 4 deletions

View file

@ -145,7 +145,9 @@ async def resolve_impls(
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
p.deps__ = [a.value for a in p.api_dependencies]
p.deps__ = [a.value for a in p.api_dependencies] + [
a.value for a in p.optional_api_dependencies
]
spec = ProviderWithSpec(
spec=p,
**(provider.model_dump()),
@ -229,6 +231,9 @@ async def resolve_impls(
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies:
if a in impls:
deps[a] = impls[a]
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
@ -265,7 +270,7 @@ def topological_sort(
deps.append(dep)
for dep in deps:
if dep not in visited:
if dep not in visited and dep in providers_with_specs:
dfs((dep, providers_with_specs[dep]), visited, stack)
stack.append(api_str)

View file

@ -96,6 +96,9 @@ class ProviderSpec(BaseModel):
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
optional_api_dependencies: List[Api] = Field(
default_factory=list,
)
deprecation_warning: Optional[str] = Field(
default=None,
description="If this provider is deprecated, specify the warning message here",

View file

@ -72,7 +72,7 @@ def is_tracing_enabled(tracer):
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
self.config = config
self.datasetio_api = deps[Api.datasetio]
self.datasetio_api = deps.get(Api.datasetio)
resource = Resource.create(
{

View file

@ -24,7 +24,7 @@ def available_providers() -> List[ProviderSpec]:
"opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http",
],
api_dependencies=[Api.datasetio],
optional_api_dependencies=[Api.datasetio],
module="llama_stack.providers.inline.telemetry.meta_reference",
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
),

View file

@ -22,6 +22,9 @@ class TelemetryDatasetMixin:
dataset_id: str,
max_depth: Optional[int] = None,
) -> None:
if self.datasetio_api is None:
raise RuntimeError("DatasetIO API not available")
spans = await self.query_spans(
attribute_filters=attribute_filters,
attributes_to_return=attributes_to_save,