Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
fix(metrics): capture token metrics using auto instrumentation
This commit is contained in:
parent: aa2a7dae07
commit: f5455a64d1

6 changed files with 6 additions and 118 deletions
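In summary, the change drops the hand-rolled token-metric plumbing: the telemetry_enabled flag threaded through provider construction, the InferenceRouter._construct_metrics helper that built MetricEvent objects from response.usage, the per-response enqueue_event calls, and the _get_or_create_gauge helper in Telemetry. Token metrics are instead expected to come from auto instrumentation. As a rough, hedged illustration only (the setup and the instrumentation package below are assumptions, not part of this commit), OpenTelemetry-style auto instrumentation is enabled once at startup and then records metrics around client calls with no per-call code in the router:

# Hypothetical setup sketch; the OpenAI instrumentation package/import path is an
# assumption and does not appear in this commit.
from opentelemetry import metrics
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader

reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
metrics.set_meter_provider(MeterProvider(metric_readers=[reader]))

# e.g. (assumed package and import path):
# from opentelemetry.instrumentation.openai_v2 import OpenAIInstrumentor
# OpenAIInstrumentor().instrument()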
@@ -392,8 +392,6 @@ async def instantiate_provider(
     args = [config, deps]
     if "policy" in inspect.signature(getattr(module, method)).parameters:
         args.append(policy)
-    if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
-        args.append(run_config.telemetry.enabled)
 
     fn = getattr(module, method)
     impl = await fn(*args)
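The context kept above shows the pattern this file uses for optional constructor arguments: the provider entry point's signature is inspected and an argument is passed only if the function declares it; with this change telemetry_enabled is no longer one of those arguments. A small self-contained sketch of that dispatch pattern, with names made up for the example (not taken from llama-stack):

# Illustrative sketch of the signature-inspection dispatch used above.
import asyncio
import inspect


async def provider_entry_point(config, deps, policy=None):
    return {"config": config, "deps": deps, "policy": policy}


async def instantiate(fn, config, deps, policy):
    args = [config, deps]
    # Pass `policy` only when the entry point actually declares it.
    if "policy" in inspect.signature(fn).parameters:
        args.append(policy)
    return await fn(*args)


print(asyncio.run(instantiate(provider_entry_point, {"model": "x"}, {}, ["allow-all"])))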
@@ -85,7 +85,6 @@ async def get_auto_router_impl(
         )
         await inference_store.initialize()
         api_to_dep_impl["store"] = inference_store
-        api_to_dep_impl["telemetry_enabled"] = run_config.telemetry.enabled
 
     elif api == Api.vector_io:
         api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
@@ -7,7 +7,6 @@
 import asyncio
 import time
 from collections.abc import AsyncIterator
-from datetime import UTC, datetime
 from typing import Annotated, Any
 
 from fastapi import Body
@@ -15,11 +14,7 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
 from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
 from pydantic import TypeAdapter
 
-from llama_stack.core.telemetry.telemetry import MetricEvent
-from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
 from llama_stack.log import get_logger
-from llama_stack.models.llama.llama3.chat_format import ChatFormat
-from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
 from llama_stack_api import (
     HealthResponse,
@@ -49,6 +44,12 @@ from llama_stack_api import (
     RerankResponse,
     RoutingTable,
 )
+from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
+from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
+from pydantic import TypeAdapter
+
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.inference_store import InferenceStore
 
 logger = get_logger(name=__name__, category="core::routers")
 
@@ -60,15 +61,10 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
         store: InferenceStore | None = None,
-        telemetry_enabled: bool = False,
     ) -> None:
         logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
-        self.telemetry_enabled = telemetry_enabled
         self.store = store
-        if self.telemetry_enabled:
-            self.tokenizer = Tokenizer.get_instance()
-            self.formatter = ChatFormat(self.tokenizer)
 
     async def initialize(self) -> None:
         logger.debug("InferenceRouter.initialize")
@@ -94,54 +90,6 @@ class InferenceRouter(Inference):
         )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
-    def _construct_metrics(
-        self,
-        prompt_tokens: int,
-        completion_tokens: int,
-        total_tokens: int,
-        fully_qualified_model_id: str,
-        provider_id: str,
-    ) -> list[MetricEvent]:
-        """Constructs a list of MetricEvent objects containing token usage metrics.
-
-        Args:
-            prompt_tokens: Number of tokens in the prompt
-            completion_tokens: Number of tokens in the completion
-            total_tokens: Total number of tokens used
-            fully_qualified_model_id:
-            provider_id: The provider identifier
-
-        Returns:
-            List of MetricEvent objects with token usage metrics
-        """
-        span = get_current_span()
-        if span is None:
-            logger.warning("No span found for token usage metrics")
-            return []
-
-        metrics = [
-            ("prompt_tokens", prompt_tokens),
-            ("completion_tokens", completion_tokens),
-            ("total_tokens", total_tokens),
-        ]
-        metric_events = []
-        for metric_name, value in metrics:
-            metric_events.append(
-                MetricEvent(
-                    trace_id=span.trace_id,
-                    span_id=span.span_id,
-                    metric=metric_name,
-                    value=value,
-                    timestamp=datetime.now(UTC),
-                    unit="tokens",
-                    attributes={
-                        "model_id": fully_qualified_model_id,
-                        "provider_id": provider_id,
-                    },
-                )
-            )
-        return metric_events
-
     async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
         model = await self.routing_table.get_object_by_identifier("model", model_id)
         if model:
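The removed _construct_metrics helper attached prompt_tokens, completion_tokens, and total_tokens to the current span as MetricEvent objects. For comparison, a minimal sketch (not from this commit; instrument names and attributes are illustrative assumptions) of recording the same three counts directly through the OpenTelemetry metrics API, which is roughly what an instrumentation layer would do on the router's behalf:

# Minimal sketch, not part of this commit: the same token counts expressed as
# OpenTelemetry counters. Names and attributes are assumptions for illustration.
from opentelemetry import metrics

meter = metrics.get_meter(__name__)
prompt_counter = meter.create_counter("prompt_tokens", unit="tokens")
completion_counter = meter.create_counter("completion_tokens", unit="tokens")
total_counter = meter.create_counter("total_tokens", unit="tokens")


def record_token_usage(prompt_tokens: int, completion_tokens: int, total_tokens: int,
                       model_id: str, provider_id: str) -> None:
    attrs = {"model_id": model_id, "provider_id": provider_id}
    prompt_counter.add(prompt_tokens, attrs)
    completion_counter.add(completion_tokens, attrs)
    total_counter.add(total_tokens, attrs)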
@@ -186,26 +134,9 @@ class InferenceRouter(Inference):
 
         if params.stream:
             return await provider.openai_completion(params)
-        # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
-        # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
 
         response = await provider.openai_completion(params)
         response.model = request_model_id
-        if self.telemetry_enabled and response.usage is not None:
-            metrics = self._construct_metrics(
-                prompt_tokens=response.usage.prompt_tokens,
-                completion_tokens=response.usage.completion_tokens,
-                total_tokens=response.usage.total_tokens,
-                fully_qualified_model_id=request_model_id,
-                provider_id=provider.__provider_id__,
-            )
-            for metric in metrics:
-                enqueue_event(metric)
-
-            # these metrics will show up in the client response.
-            response.metrics = (
-                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
-            )
         return response
 
     async def openai_chat_completion(
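The deleted block above (and the TODO about streaming) enriched the non-streaming completion response with metrics after the provider call returned. With auto instrumentation the expectation is that token usage is recorded on the surrounding span instead of being attached by the router; a hedged sketch of what that recording looks like when done by hand (the attribute names follow the OpenTelemetry GenAI semantic conventions and are an assumption here, not something this commit shows):

# Hypothetical sketch: record token usage on the active span around a provider call.
# Whether llama-stack's instrumentation uses these exact attribute names is an assumption.
from opentelemetry import trace

tracer = trace.get_tracer(__name__)


async def completion_with_span(provider, params):
    with tracer.start_as_current_span("inference.openai_completion") as span:
        response = await provider.openai_completion(params)
        if getattr(response, "usage", None) is not None:
            span.set_attribute("gen_ai.usage.input_tokens", response.usage.prompt_tokens)
            span.set_attribute("gen_ai.usage.output_tokens", response.usage.completion_tokens)
        return response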
@@ -254,20 +185,6 @@ class InferenceRouter(Inference):
         if self.store:
             asyncio.create_task(self.store.store_chat_completion(response, params.messages))
 
-        if self.telemetry_enabled and response.usage is not None:
-            metrics = self._construct_metrics(
-                prompt_tokens=response.usage.prompt_tokens,
-                completion_tokens=response.usage.completion_tokens,
-                total_tokens=response.usage.total_tokens,
-                fully_qualified_model_id=request_model_id,
-                provider_id=provider.__provider_id__,
-            )
-            for metric in metrics:
-                enqueue_event(metric)
-            # these metrics will show up in the client response.
-            response.metrics = (
-                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
-            )
         return response
 
     async def openai_embeddings(
@@ -411,18 +328,6 @@ class InferenceRouter(Inference):
                 for choice_data in choices_data.values():
                     completion_text += "".join(choice_data["content_parts"])
 
-                # Add metrics to the chunk
-                if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
-                    metrics = self._construct_metrics(
-                        prompt_tokens=chunk.usage.prompt_tokens,
-                        completion_tokens=chunk.usage.completion_tokens,
-                        total_tokens=chunk.usage.total_tokens,
-                        fully_qualified_model_id=fully_qualified_model_id,
-                        provider_id=provider_id,
-                    )
-                    for metric in metrics:
-                        enqueue_event(metric)
-
                 yield chunk
         finally:
             # Store the final assembled completion
@@ -490,16 +490,6 @@ class Telemetry:
             )
         return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])
 
-    def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
-        assert self.meter is not None
-        if name not in _GLOBAL_STORAGE["gauges"]:
-            _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
-                name=name,
-                unit=unit,
-                description=f"Gauge for {name}",
-            )
-        return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
-
     def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
         assert self.meter is not None
         if name not in _GLOBAL_STORAGE["histograms"]:
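The dropped _get_or_create_gauge helper mirrored the counter and histogram helpers that remain around it. For reference, a hedged sketch of the same lazy get-or-create pattern using an OpenTelemetry observable gauge (the callback-based instrument that the removed return annotation names); the module-level dictionaries and function names here are illustrative and not part of llama-stack:

# Illustrative sketch of a lazily created observable gauge; not part of this commit.
from opentelemetry import metrics
from opentelemetry.metrics import CallbackOptions, Observation

_gauges: dict[str, metrics.ObservableGauge] = {}
_latest: dict[str, float] = {}


def get_or_create_gauge(meter: metrics.Meter, name: str, unit: str) -> metrics.ObservableGauge:
    if name not in _gauges:
        def _read(_options: CallbackOptions) -> list[Observation]:
            # Report whatever value was last stored for this gauge name.
            return [Observation(_latest.get(name, 0.0))]

        _gauges[name] = meter.create_observable_gauge(
            name=name,
            callbacks=[_read],
            unit=unit,
            description=f"Gauge for {name}",
        )
    return _gauges[name]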
@@ -15,7 +15,6 @@ async def get_provider_impl(
     config: MetaReferenceAgentsImplConfig,
     deps: dict[Api, Any],
     policy: list[AccessRule],
-    telemetry_enabled: bool = False,
 ):
     from .agents import MetaReferenceAgentsImpl
 
@@ -29,7 +28,6 @@ async def get_provider_impl(
         deps[Api.conversations],
         deps[Api.prompts],
         deps[Api.files],
-        telemetry_enabled,
         policy,
     )
     await impl.initialize()
@@ -50,7 +50,6 @@ class MetaReferenceAgentsImpl(Agents):
         prompts_api: Prompts,
         files_api: Files,
         policy: list[AccessRule],
-        telemetry_enabled: bool = False,
     ):
         self.config = config
         self.inference_api = inference_api
@@ -59,7 +58,6 @@ class MetaReferenceAgentsImpl(Agents):
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
         self.conversations_api = conversations_api
-        self.telemetry_enabled = telemetry_enabled
         self.prompts_api = prompts_api
         self.files_api = files_api
         self.in_memory_store = InmemoryKVStoreImpl()