make router call telemetry

Dinesh Yeduguru 2025-02-04 15:42:33 -08:00
parent a72cdafac0
commit 38f1337afa
5 changed files with 50 additions and 31 deletions

View file

@@ -182,7 +182,9 @@ async def resolve_impls(
                 module="llama_stack.distribution.routers",
                 routing_table_api=info.routing_table_api,
                 api_dependencies=[info.routing_table_api],
-                deps__=([info.routing_table_api.value]),
+                # Add telemetry as an optional dependency to all auto-routed providers
+                optional_api_dependencies=[Api.telemetry],
+                deps__=([info.routing_table_api.value, Api.telemetry.value]),
             ),
         )
     }
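The resolver change above only declares the relationship: Api.telemetry becomes an optional dependency of every auto-routed provider, alongside the required routing-table API. A minimal, self-contained sketch of the required-versus-optional distinction (the Api enum and build_deps helper here are illustrative stand-ins, not the actual resolver code):

    from enum import Enum
    from typing import Any, Dict, List


    class Api(Enum):  # illustrative stand-in for llama_stack's Api enum
        models = "models"
        inference = "inference"
        telemetry = "telemetry"


    def build_deps(
        impls: Dict[Api, Any],
        api_dependencies: List[Api],
        optional_api_dependencies: List[Api],
    ) -> Dict[Api, Any]:
        # Required dependencies must already be resolved; a missing one is an error.
        deps = {api: impls[api] for api in api_dependencies}
        # Optional dependencies (e.g. telemetry) are forwarded only when configured.
        for api in optional_api_dependencies:
            if api in impls:
                deps[api] = impls[api]
        return deps


    # With a telemetry provider configured, the inference auto-router sees both deps.
    impls = {Api.models: object(), Api.telemetry: object()}
    deps = build_deps(impls, api_dependencies=[Api.models], optional_api_dependencies=[Api.telemetry])
    assert Api.models in deps and Api.telemetry in deps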

View file

@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl


-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -65,9 +65,18 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
         "eval": EvalRouter,
         "tool_runtime": ToolRuntimeRouter,
     }
+    api_to_deps = {
+        "inference": [Api.telemetry],
+    }

     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")
-    impl = api_to_routers[api.value](routing_table)
+    deps = []
+    for dep in api_to_deps.get(api.value, []):
+        if dep not in _deps:
+            raise ValueError(f"Dependency {dep} not found in _deps")
+        deps.append(_deps[dep])
+
+    impl = api_to_routers[api.value](routing_table, *deps)
     await impl.initialize()
     return impl
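With api_to_deps in place, only the inference router receives an extra constructor argument; the telemetry implementation is appended positionally after the routing table via *deps. A usage sketch, with the caveat that the import paths and the key type of _deps are assumptions based on this diff rather than something it shows:

    # Sketch: resolving the inference auto-router once a telemetry impl exists.
    from typing import Any

    from llama_stack.distribution.routers import get_auto_router_impl  # path assumed
    from llama_stack.providers.datatypes import Api, RoutingTable  # path assumed


    async def build_inference_router(routing_table: RoutingTable, telemetry_impl: Any):
        router = await get_auto_router_impl(
            api=Api.inference,
            routing_table=routing_table,
            _deps={Api.telemetry: telemetry_impl},  # consulted via api_to_deps["inference"]
        )
        # Internally this becomes InferenceRouter(routing_table, telemetry_impl),
        # so every request routed through it can emit token-usage MetricEvents.
        return router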

View file

@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import time
 from typing import Any, AsyncGenerator, Dict, List, Optional

 from llama_models.llama3.api.chat_format import ChatFormat
@@ -45,7 +46,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.telemetry import TokenUsage
+from llama_stack.apis.telemetry import MetricEvent, Telemetry, TokenUsage
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -57,6 +58,7 @@ from llama_stack.apis.tools import (
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
 from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
+from llama_stack.providers.utils.telemetry.tracing import get_current_span


 class VectorIORouter(VectorIO):
@@ -113,8 +115,10 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
+        telemetry: Telemetry,
     ) -> None:
         self.routing_table = routing_table
+        self.telemetry = telemetry
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
@@ -217,6 +221,31 @@ class InferenceRouter(Inference):
                     total_tokens=total_tokens,
                 )
             )
+
+        # Log token usage metrics
+        metrics = [
+            ("prompt_tokens", prompt_tokens),
+            ("completion_tokens", completion_tokens),
+            ("total_tokens", total_tokens),
+        ]
+        span = get_current_span()
+        if span:
+            breakpoint()
+            for metric_name, value in metrics:
+                await self.telemetry.log_event(
+                    MetricEvent(
+                        trace_id=span.trace_id,
+                        span_id=span.span_id,
+                        metric=metric_name,
+                        value=value,
+                        timestamp=time.time(),
+                        unit="tokens",
+                        attributes={
+                            "model_id": model_id,
+                            "provider_id": model.provider_id,
+                        },
+                    )
+                )
         return response

     async def completion(
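The router now emits three MetricEvents per chat completion, one each for prompt, completion, and total tokens, all tagged with the model and provider and tied to the current trace. Because get_current_span() returns None outside an active trace, untraced calls log nothing. Note that the breakpoint() guarding the loop looks like a leftover debug statement; as committed, every traced request will stop in the debugger. A minimal, self-contained sketch of the emission pattern, using FakeMetricEvent and FakeTelemetry as illustrative stand-ins for the real llama_stack.apis.telemetry types:

    import asyncio
    import time
    from dataclasses import dataclass, field
    from typing import Any, Dict, List


    @dataclass
    class FakeMetricEvent:  # stand-in for llama_stack.apis.telemetry.MetricEvent
        metric: str
        value: int
        unit: str
        timestamp: float
        attributes: Dict[str, Any] = field(default_factory=dict)


    class FakeTelemetry:  # stand-in for a Telemetry provider
        def __init__(self) -> None:
            self.events: List[FakeMetricEvent] = []

        async def log_event(self, event: FakeMetricEvent) -> None:
            self.events.append(event)


    async def emit_token_metrics(telemetry: FakeTelemetry, prompt_tokens: int, completion_tokens: int) -> None:
        # Mirrors the router's loop: one event per metric, shared attributes.
        total_tokens = prompt_tokens + completion_tokens
        for name, value in [
            ("prompt_tokens", prompt_tokens),
            ("completion_tokens", completion_tokens),
            ("total_tokens", total_tokens),
        ]:
            await telemetry.log_event(
                FakeMetricEvent(
                    metric=name,
                    value=value,
                    unit="tokens",
                    timestamp=time.time(),
                    attributes={"model_id": "example-model", "provider_id": "example-provider"},
                )
            )


    telemetry = FakeTelemetry()
    asyncio.run(emit_token_metrics(telemetry, prompt_tokens=12, completion_tokens=34))
    assert [e.metric for e in telemetry.events] == ["prompt_tokens", "completion_tokens", "total_tokens"]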

View file

@@ -20,7 +20,7 @@ class TelemetrySink(str, Enum):
 class TelemetryConfig(BaseModel):
     otel_endpoint: str = Field(
-        default="http://localhost:4318/v1/traces",
+        default="http://localhost:4318",
         description="The OpenTelemetry collector endpoint URL",
     )
     service_name: str = Field(

View file

@@ -6,6 +6,7 @@
 import threading
 from typing import Any, Dict, List, Optional
+from urllib.parse import urljoin

 from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -92,13 +93,13 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             _TRACER_PROVIDER = provider
             if TelemetrySink.OTEL in self.config.sinks:
                 otlp_exporter = OTLPSpanExporter(
-                    endpoint=self.config.otel_endpoint,
+                    endpoint=urljoin(self.config.otel_endpoint, "v1/traces"),
                 )
                 span_processor = BatchSpanProcessor(otlp_exporter)
                 trace.get_tracer_provider().add_span_processor(span_processor)
                 metric_reader = PeriodicExportingMetricReader(
                     OTLPMetricExporter(
-                        endpoint=self.config.otel_endpoint,
+                        endpoint=urljoin(self.config.otel_endpoint, "v1/metrics"),
                     )
                 )
                 metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
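Because otel_endpoint is now a base URL (default http://localhost:4318, per the config change above), the adapter appends the standard OTLP/HTTP signal paths itself. One urljoin detail matters if the collector is mounted under a sub-path: without a trailing slash, the last path segment of the base URL is dropped. A quick runnable check:

    from urllib.parse import urljoin

    # With the new default, the per-signal paths resolve as expected.
    assert urljoin("http://localhost:4318", "v1/traces") == "http://localhost:4318/v1/traces"
    assert urljoin("http://localhost:4318", "v1/metrics") == "http://localhost:4318/v1/metrics"

    # urljoin drops the last segment of a base URL that lacks a trailing slash,
    # so a collector behind a sub-path needs the trailing "/" in the config value.
    assert urljoin("http://proxy:4318/otel", "v1/traces") == "http://proxy:4318/v1/traces"
    assert urljoin("http://proxy:4318/otel/", "v1/traces") == "http://proxy:4318/otel/v1/traces"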
@@ -161,31 +162,9 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             )
         return _GLOBAL_STORAGE["counters"][name]

-    def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
-        if name not in _GLOBAL_STORAGE["gauges"]:
-            _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
-                name=name,
-                unit=unit,
-                description=f"Gauge for {name}",
-            )
-        return _GLOBAL_STORAGE["gauges"][name]
-
     def _log_metric(self, event: MetricEvent) -> None:
-        if isinstance(event.value, int):
-            counter = self._get_or_create_counter(event.metric, event.unit)
-            counter.add(event.value, attributes=event.attributes)
-        elif isinstance(event.value, float):
-            up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
-            up_down_counter.add(event.value, attributes=event.attributes)
-
-    def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
-        if name not in _GLOBAL_STORAGE["up_down_counters"]:
-            _GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
-                name=name,
-                unit=unit,
-                description=f"UpDownCounter for {name}",
-            )
-        return _GLOBAL_STORAGE["up_down_counters"][name]
+        counter = self._get_or_create_counter(event.metric, event.unit)
+        counter.add(event.value, attributes=event.attributes)

     def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
         with self._lock:
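With the gauge and up/down-counter paths removed, every MetricEvent is recorded through a monotonic counter. That fits this use: token counts only ever increase, and an OpenTelemetry Counter accepts float as well as int increments, which is why the old int/float branch is unnecessary. A small sketch against the OpenTelemetry SDK (the metric name and attributes are illustrative, and a ConsoleMetricExporter stands in for the OTLP exporter so it runs without a collector):

    from opentelemetry import metrics
    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader

    reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
    metrics.set_meter_provider(MeterProvider(metric_readers=[reader]))

    meter = metrics.get_meter("token-usage-example")
    counter = meter.create_counter(name="total_tokens", unit="tokens", description="Counter for total_tokens")

    # Counters take both int and float increments, so one code path covers every MetricEvent.
    counter.add(46, attributes={"model_id": "example-model", "provider_id": "example-provider"})
    counter.add(12.0, attributes={"model_id": "example-model", "provider_id": "example-provider"})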