make router call telemetry

This commit is contained in:
Dinesh Yeduguru 2025-02-04 15:42:33 -08:00
parent a72cdafac0
commit 38f1337afa
5 changed files with 50 additions and 31 deletions

View file

@ -182,7 +182,9 @@ async def resolve_impls(
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=([info.routing_table_api.value]),
# Add telemetry as an optional dependency to all auto-routed providers
optional_api_dependencies=[Api.telemetry],
deps__=([info.routing_table_api.value, Api.telemetry.value]),
),
)
}

View file

@ -45,7 +45,7 @@ async def get_routing_table_impl(
return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dict[str, Any]) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,
@ -65,9 +65,18 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
}
api_to_deps = {
"inference": [Api.telemetry],
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_routers[api.value](routing_table)
deps = []
for dep in api_to_deps.get(api.value, []):
if dep not in _deps:
raise ValueError(f"Dependency {dep} not found in _deps")
deps.append(_deps[dep])
impl = api_to_routers[api.value](routing_table, *deps)
await impl.initialize()
return impl

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
@ -45,7 +46,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import TokenUsage
from llama_stack.apis.telemetry import MetricEvent, Telemetry, TokenUsage
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
@ -57,6 +58,7 @@ from llama_stack.apis.tools import (
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
from llama_stack.providers.utils.telemetry.tracing import get_current_span
class VectorIORouter(VectorIO):
@ -113,8 +115,10 @@ class InferenceRouter(Inference):
def __init__(
self,
routing_table: RoutingTable,
telemetry: Telemetry,
) -> None:
self.routing_table = routing_table
self.telemetry = telemetry
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
@ -217,6 +221,31 @@ class InferenceRouter(Inference):
total_tokens=total_tokens,
)
)
# Log token usage metrics
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
span = get_current_span()
if span:
breakpoint()
for metric_name, value in metrics:
await self.telemetry.log_event(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=time.time(),
unit="tokens",
attributes={
"model_id": model_id,
"provider_id": model.provider_id,
},
)
)
return response
async def completion(

View file

@ -20,7 +20,7 @@ class TelemetrySink(str, Enum):
class TelemetryConfig(BaseModel):
otel_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
default="http://localhost:4318",
description="The OpenTelemetry collector endpoint URL",
)
service_name: str = Field(

View file

@ -6,6 +6,7 @@
import threading
from typing import Any, Dict, List, Optional
from urllib.parse import urljoin
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@ -92,13 +93,13 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
_TRACER_PROVIDER = provider
if TelemetrySink.OTEL in self.config.sinks:
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
endpoint=urljoin(self.config.otel_endpoint, "v1/traces"),
)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
endpoint=urljoin(self.config.otel_endpoint, "v1/metrics"),
)
)
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
@ -161,31 +162,9 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
)
return _GLOBAL_STORAGE["counters"][name]
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
unit=unit,
description=f"Gauge for {name}",
)
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
up_down_counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
name=name,
unit=unit,
description=f"UpDownCounter for {name}",
)
return _GLOBAL_STORAGE["up_down_counters"][name]
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock: