mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
make router call telemetry
This commit is contained in:
parent
a72cdafac0
commit
38f1337afa
5 changed files with 50 additions and 31 deletions
|
@ -182,7 +182,9 @@ async def resolve_impls(
|
||||||
module="llama_stack.distribution.routers",
|
module="llama_stack.distribution.routers",
|
||||||
routing_table_api=info.routing_table_api,
|
routing_table_api=info.routing_table_api,
|
||||||
api_dependencies=[info.routing_table_api],
|
api_dependencies=[info.routing_table_api],
|
||||||
deps__=([info.routing_table_api.value]),
|
# Add telemetry as an optional dependency to all auto-routed providers
|
||||||
|
optional_api_dependencies=[Api.telemetry],
|
||||||
|
deps__=([info.routing_table_api.value, Api.telemetry.value]),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ async def get_routing_table_impl(
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
||||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dict[str, Any]) -> Any:
|
||||||
from .routers import (
|
from .routers import (
|
||||||
DatasetIORouter,
|
DatasetIORouter,
|
||||||
EvalRouter,
|
EvalRouter,
|
||||||
|
@ -65,9 +65,18 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
|
||||||
"eval": EvalRouter,
|
"eval": EvalRouter,
|
||||||
"tool_runtime": ToolRuntimeRouter,
|
"tool_runtime": ToolRuntimeRouter,
|
||||||
}
|
}
|
||||||
|
api_to_deps = {
|
||||||
|
"inference": [Api.telemetry],
|
||||||
|
}
|
||||||
if api.value not in api_to_routers:
|
if api.value not in api_to_routers:
|
||||||
raise ValueError(f"API {api.value} not found in router map")
|
raise ValueError(f"API {api.value} not found in router map")
|
||||||
|
|
||||||
impl = api_to_routers[api.value](routing_table)
|
deps = []
|
||||||
|
for dep in api_to_deps.get(api.value, []):
|
||||||
|
if dep not in _deps:
|
||||||
|
raise ValueError(f"Dependency {dep} not found in _deps")
|
||||||
|
deps.append(_deps[dep])
|
||||||
|
|
||||||
|
impl = api_to_routers[api.value](routing_table, *deps)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
|
@ -45,7 +46,7 @@ from llama_stack.apis.scoring import (
|
||||||
ScoringFnParams,
|
ScoringFnParams,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.apis.telemetry import TokenUsage
|
from llama_stack.apis.telemetry import MetricEvent, Telemetry, TokenUsage
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
RAGDocument,
|
RAGDocument,
|
||||||
RAGQueryConfig,
|
RAGQueryConfig,
|
||||||
|
@ -57,6 +58,7 @@ from llama_stack.apis.tools import (
|
||||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
from llama_stack.providers.datatypes import RoutingTable
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
|
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
|
||||||
|
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||||
|
|
||||||
|
|
||||||
class VectorIORouter(VectorIO):
|
class VectorIORouter(VectorIO):
|
||||||
|
@ -113,8 +115,10 @@ class InferenceRouter(Inference):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
|
telemetry: Telemetry,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
self.telemetry = telemetry
|
||||||
self.tokenizer = Tokenizer.get_instance()
|
self.tokenizer = Tokenizer.get_instance()
|
||||||
self.formatter = ChatFormat(self.tokenizer)
|
self.formatter = ChatFormat(self.tokenizer)
|
||||||
|
|
||||||
|
@ -217,6 +221,31 @@ class InferenceRouter(Inference):
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# Log token usage metrics
|
||||||
|
metrics = [
|
||||||
|
("prompt_tokens", prompt_tokens),
|
||||||
|
("completion_tokens", completion_tokens),
|
||||||
|
("total_tokens", total_tokens),
|
||||||
|
]
|
||||||
|
|
||||||
|
span = get_current_span()
|
||||||
|
if span:
|
||||||
|
breakpoint()
|
||||||
|
for metric_name, value in metrics:
|
||||||
|
await self.telemetry.log_event(
|
||||||
|
MetricEvent(
|
||||||
|
trace_id=span.trace_id,
|
||||||
|
span_id=span.span_id,
|
||||||
|
metric=metric_name,
|
||||||
|
value=value,
|
||||||
|
timestamp=time.time(),
|
||||||
|
unit="tokens",
|
||||||
|
attributes={
|
||||||
|
"model_id": model_id,
|
||||||
|
"provider_id": model.provider_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
|
|
|
@ -20,7 +20,7 @@ class TelemetrySink(str, Enum):
|
||||||
|
|
||||||
class TelemetryConfig(BaseModel):
|
class TelemetryConfig(BaseModel):
|
||||||
otel_endpoint: str = Field(
|
otel_endpoint: str = Field(
|
||||||
default="http://localhost:4318/v1/traces",
|
default="http://localhost:4318",
|
||||||
description="The OpenTelemetry collector endpoint URL",
|
description="The OpenTelemetry collector endpoint URL",
|
||||||
)
|
)
|
||||||
service_name: str = Field(
|
service_name: str = Field(
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
from opentelemetry import metrics, trace
|
from opentelemetry import metrics, trace
|
||||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||||
|
@ -92,13 +93,13 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||||
_TRACER_PROVIDER = provider
|
_TRACER_PROVIDER = provider
|
||||||
if TelemetrySink.OTEL in self.config.sinks:
|
if TelemetrySink.OTEL in self.config.sinks:
|
||||||
otlp_exporter = OTLPSpanExporter(
|
otlp_exporter = OTLPSpanExporter(
|
||||||
endpoint=self.config.otel_endpoint,
|
endpoint=urljoin(self.config.otel_endpoint, "v1/traces"),
|
||||||
)
|
)
|
||||||
span_processor = BatchSpanProcessor(otlp_exporter)
|
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||||
metric_reader = PeriodicExportingMetricReader(
|
metric_reader = PeriodicExportingMetricReader(
|
||||||
OTLPMetricExporter(
|
OTLPMetricExporter(
|
||||||
endpoint=self.config.otel_endpoint,
|
endpoint=urljoin(self.config.otel_endpoint, "v1/metrics"),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
||||||
|
@ -161,31 +162,9 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||||
)
|
)
|
||||||
return _GLOBAL_STORAGE["counters"][name]
|
return _GLOBAL_STORAGE["counters"][name]
|
||||||
|
|
||||||
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
|
||||||
if name not in _GLOBAL_STORAGE["gauges"]:
|
|
||||||
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
|
||||||
name=name,
|
|
||||||
unit=unit,
|
|
||||||
description=f"Gauge for {name}",
|
|
||||||
)
|
|
||||||
return _GLOBAL_STORAGE["gauges"][name]
|
|
||||||
|
|
||||||
def _log_metric(self, event: MetricEvent) -> None:
|
def _log_metric(self, event: MetricEvent) -> None:
|
||||||
if isinstance(event.value, int):
|
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
counter.add(event.value, attributes=event.attributes)
|
||||||
counter.add(event.value, attributes=event.attributes)
|
|
||||||
elif isinstance(event.value, float):
|
|
||||||
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
|
|
||||||
up_down_counter.add(event.value, attributes=event.attributes)
|
|
||||||
|
|
||||||
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
|
|
||||||
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
|
||||||
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
|
|
||||||
name=name,
|
|
||||||
unit=unit,
|
|
||||||
description=f"UpDownCounter for {name}",
|
|
||||||
)
|
|
||||||
return _GLOBAL_STORAGE["up_down_counters"][name]
|
|
||||||
|
|
||||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue