diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 0bc2e774c..100987022 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -182,7 +182,9 @@ async def resolve_impls( module="llama_stack.distribution.routers", routing_table_api=info.routing_table_api, api_dependencies=[info.routing_table_api], - deps__=([info.routing_table_api.value]), + # Add telemetry as an optional dependency to all auto-routed providers + optional_api_dependencies=[Api.telemetry], + deps__=([info.routing_table_api.value, Api.telemetry.value]), ), ) } diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index a54f57fb3..6660e180c 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -45,7 +45,7 @@ async def get_routing_table_impl( return impl -async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: +async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dict[str, Any]) -> Any: from .routers import ( DatasetIORouter, EvalRouter, @@ -65,9 +65,18 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> "eval": EvalRouter, "tool_runtime": ToolRuntimeRouter, } + api_to_deps = { + "inference": [Api.telemetry], + } if api.value not in api_to_routers: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_routers[api.value](routing_table) + deps = [] + for dep in api_to_deps.get(api.value, []): + if dep not in _deps: + raise ValueError(f"Dependency {dep} not found in _deps") + deps.append(_deps[dep]) + + impl = api_to_routers[api.value](routing_table, *deps) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 9047b1c33..1af5d43f9 100644 --- a/llama_stack/distribution/routers/routers.py +++ 
b/llama_stack/distribution/routers/routers.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import time from typing import Any, AsyncGenerator, Dict, List, Optional from llama_models.llama3.api.chat_format import ChatFormat @@ -45,7 +46,7 @@ from llama_stack.apis.scoring import ( ScoringFnParams, ) from llama_stack.apis.shields import Shield -from llama_stack.apis.telemetry import TokenUsage +from llama_stack.apis.telemetry import MetricEvent, Telemetry, TokenUsage from llama_stack.apis.tools import ( RAGDocument, RAGQueryConfig, @@ -57,6 +58,7 @@ from llama_stack.apis.tools import ( from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import RoutingTable from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format +from llama_stack.providers.utils.telemetry.tracing import get_current_span class VectorIORouter(VectorIO): @@ -113,8 +115,10 @@ class InferenceRouter(Inference): def __init__( self, routing_table: RoutingTable, + telemetry: Telemetry, ) -> None: self.routing_table = routing_table + self.telemetry = telemetry self.tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(self.tokenizer) @@ -217,6 +221,30 @@ class InferenceRouter(Inference): total_tokens=total_tokens, ) ) + # Log token usage metrics + metrics = [ + ("prompt_tokens", prompt_tokens), + ("completion_tokens", completion_tokens), + ("total_tokens", total_tokens), + ] + + span = get_current_span() + if span: + for metric_name, value in metrics: + await self.telemetry.log_event( + MetricEvent( + trace_id=span.trace_id, + span_id=span.span_id, + metric=metric_name, + value=value, + timestamp=time.time(), + unit="tokens", + attributes={ + "model_id": model_id, + "provider_id": model.provider_id, + }, + ) + ) return response async def completion( diff --git 
a/llama_stack/providers/inline/telemetry/meta_reference/config.py b/llama_stack/providers/inline/telemetry/meta_reference/config.py index f409235d9..97b1c2efc 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/config.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/config.py @@ -20,7 +20,7 @@ class TelemetrySink(str, Enum): class TelemetryConfig(BaseModel): otel_endpoint: str = Field( - default="http://localhost:4318/v1/traces", + default="http://localhost:4318", description="The OpenTelemetry collector endpoint URL", ) service_name: str = Field( diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index e713a057f..ef417ac18 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -6,6 +6,7 @@ import threading from typing import Any, Dict, List, Optional +from urllib.parse import urljoin from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter @@ -92,13 +93,13 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): _TRACER_PROVIDER = provider if TelemetrySink.OTEL in self.config.sinks: otlp_exporter = OTLPSpanExporter( - endpoint=self.config.otel_endpoint, + endpoint=urljoin(self.config.otel_endpoint, "v1/traces"), ) span_processor = BatchSpanProcessor(otlp_exporter) trace.get_tracer_provider().add_span_processor(span_processor) metric_reader = PeriodicExportingMetricReader( OTLPMetricExporter( - endpoint=self.config.otel_endpoint, + endpoint=urljoin(self.config.otel_endpoint, "v1/metrics"), ) ) metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) @@ -161,31 +162,9 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): ) return _GLOBAL_STORAGE["counters"][name] - def _get_or_create_gauge(self, name: str, unit: str) -> 
metrics.ObservableGauge: - if name not in _GLOBAL_STORAGE["gauges"]: - _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge( - name=name, - unit=unit, - description=f"Gauge for {name}", - ) - return _GLOBAL_STORAGE["gauges"][name] - def _log_metric(self, event: MetricEvent) -> None: - if isinstance(event.value, int): - counter = self._get_or_create_counter(event.metric, event.unit) - counter.add(event.value, attributes=event.attributes) - elif isinstance(event.value, float): - up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit) - up_down_counter.add(event.value, attributes=event.attributes) - - def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter: - if name not in _GLOBAL_STORAGE["up_down_counters"]: - _GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter( - name=name, - unit=unit, - description=f"UpDownCounter for {name}", - ) - return _GLOBAL_STORAGE["up_down_counters"][name] + counter = self._get_or_create_counter(event.metric, event.unit) + counter.add(event.value, attributes=event.attributes) def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: with self._lock: