mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
make router call telemetry
This commit is contained in:
parent
a72cdafac0
commit
38f1337afa
5 changed files with 50 additions and 31 deletions
|
@ -182,7 +182,9 @@ async def resolve_impls(
|
|||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
deps__=([info.routing_table_api.value]),
|
||||
# Add telemetry as an optional dependency to all auto-routed providers
|
||||
optional_api_dependencies=[Api.telemetry],
|
||||
deps__=([info.routing_table_api.value, Api.telemetry.value]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ async def get_routing_table_impl(
|
|||
return impl
|
||||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dict[str, Any]) -> Any:
|
||||
from .routers import (
|
||||
DatasetIORouter,
|
||||
EvalRouter,
|
||||
|
@ -65,9 +65,18 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
|
|||
"eval": EvalRouter,
|
||||
"tool_runtime": ToolRuntimeRouter,
|
||||
}
|
||||
api_to_deps = {
|
||||
"inference": [Api.telemetry],
|
||||
}
|
||||
if api.value not in api_to_routers:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_routers[api.value](routing_table)
|
||||
deps = []
|
||||
for dep in api_to_deps.get(api.value, []):
|
||||
if dep not in _deps:
|
||||
raise ValueError(f"Dependency {dep} not found in _deps")
|
||||
deps.append(_deps[dep])
|
||||
|
||||
impl = api_to_routers[api.value](routing_table, *deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
@ -45,7 +46,7 @@ from llama_stack.apis.scoring import (
|
|||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.telemetry import TokenUsage
|
||||
from llama_stack.apis.telemetry import MetricEvent, Telemetry, TokenUsage
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
|
@ -57,6 +58,7 @@ from llama_stack.apis.tools import (
|
|||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
|
||||
class VectorIORouter(VectorIO):
|
||||
|
@ -113,8 +115,10 @@ class InferenceRouter(Inference):
|
|||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
telemetry: Telemetry,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
self.telemetry = telemetry
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
|
@ -217,6 +221,31 @@ class InferenceRouter(Inference):
|
|||
total_tokens=total_tokens,
|
||||
)
|
||||
)
|
||||
# Log token usage metrics
|
||||
metrics = [
|
||||
("prompt_tokens", prompt_tokens),
|
||||
("completion_tokens", completion_tokens),
|
||||
("total_tokens", total_tokens),
|
||||
]
|
||||
|
||||
span = get_current_span()
|
||||
if span:
|
||||
breakpoint()
|
||||
for metric_name, value in metrics:
|
||||
await self.telemetry.log_event(
|
||||
MetricEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
metric=metric_name,
|
||||
value=value,
|
||||
timestamp=time.time(),
|
||||
unit="tokens",
|
||||
attributes={
|
||||
"model_id": model_id,
|
||||
"provider_id": model.provider_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
return response
|
||||
|
||||
async def completion(
|
||||
|
|
|
@ -20,7 +20,7 @@ class TelemetrySink(str, Enum):
|
|||
|
||||
class TelemetryConfig(BaseModel):
|
||||
otel_endpoint: str = Field(
|
||||
default="http://localhost:4318/v1/traces",
|
||||
default="http://localhost:4318",
|
||||
description="The OpenTelemetry collector endpoint URL",
|
||||
)
|
||||
service_name: str = Field(
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
|
@ -92,13 +93,13 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
_TRACER_PROVIDER = provider
|
||||
if TelemetrySink.OTEL in self.config.sinks:
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
endpoint=urljoin(self.config.otel_endpoint, "v1/traces"),
|
||||
)
|
||||
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
OTLPMetricExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
endpoint=urljoin(self.config.otel_endpoint, "v1/metrics"),
|
||||
)
|
||||
)
|
||||
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
||||
|
@ -161,31 +162,9 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
)
|
||||
return _GLOBAL_STORAGE["counters"][name]
|
||||
|
||||
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||
if name not in _GLOBAL_STORAGE["gauges"]:
|
||||
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Gauge for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
elif isinstance(event.value, float):
|
||||
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
|
||||
up_down_counter.add(event.value, attributes=event.attributes)
|
||||
|
||||
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
|
||||
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
||||
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"UpDownCounter for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["up_down_counters"][name]
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue