Add inference token usage metrics

Dinesh Yeduguru 2025-02-04 10:45:16 -08:00
parent 0762c61402
commit a72cdafac0
3 changed files with 57 additions and 7 deletions


@@ -17,12 +17,13 @@ from typing import (
     runtime_checkable,
 )
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
-from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
+from llama_stack.apis.telemetry.telemetry import MetricsMixin
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     SamplingParams,
@@ -285,7 +286,7 @@ class CompletionRequest(BaseModel):
 @json_schema_type
-class CompletionResponse(BaseModel):
+class CompletionResponse(MetricsMixin, BaseModel):
     """Response from a completion request.

     :param content: The generated completion text
@@ -299,7 +300,7 @@ class CompletionResponse(BaseModel):
 @json_schema_type
-class CompletionResponseStreamChunk(BaseModel):
+class CompletionResponseStreamChunk(MetricsMixin, BaseModel):
     """A chunk of a streamed completion response.

     :param delta: New content generated since last chunk. This can be one or more tokens.
@@ -368,7 +369,7 @@ class ChatCompletionRequest(BaseModel):
 @json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
+class ChatCompletionResponseStreamChunk(MetricsMixin, BaseModel):
     """A chunk of a streamed chat completion response.

     :param event: The event containing the new content
@@ -378,7 +379,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
 @json_schema_type
-class ChatCompletionResponse(MetricResponseMixin, BaseModel):
+class ChatCompletionResponse(MetricsMixin, BaseModel):
     """Response from a chat completion request.

     :param completion_message: The complete response message
@@ -390,7 +391,7 @@ class ChatCompletionResponse(MetricResponseMixin, BaseModel):
 @json_schema_type
-class EmbeddingsResponse(BaseModel):
+class EmbeddingsResponse(MetricsMixin, BaseModel):
     """Response containing generated embeddings.

     :param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}


@@ -211,6 +211,28 @@ class QuerySpanTreeResponse(BaseModel):
     data: Dict[str, SpanWithStatus]


+@json_schema_type
+class TokenUsage(BaseModel):
+    type: Literal["token_usage"] = "token_usage"
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+Metric = register_schema(
+    Annotated[
+        Union[TokenUsage],
+        Field(discriminator="type"),
+    ],
+    name="Metric",
+)
+
+
+@json_schema_type
+class MetricsMixin(BaseModel):
+    metrics: List[Metric] = Field(default_factory=list)
+
+
 @runtime_checkable
 class Telemetry(Protocol):
     @webmethod(route="/telemetry/events", method="POST")
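
The new telemetry types are plain pydantic models: TokenUsage is one concrete metric, Metric is a discriminated union registered under the name "Metric" with TokenUsage as its only member for now, and MetricsMixin contributes a metrics list to any response model that inherits it. Below is a minimal standalone sketch of how the pieces compose, simplified by aliasing Metric to TokenUsage and using a stand-in response model so it runs without the llama_stack package:

from typing import List, Literal

from pydantic import BaseModel, Field


class TokenUsage(BaseModel):
    """Token accounting for a single inference call (mirrors the diff)."""

    type: Literal["token_usage"] = "token_usage"
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


# In the commit, Metric is a discriminated union registered via register_schema;
# TokenUsage is its only member today, so a plain alias keeps this sketch
# self-contained.
Metric = TokenUsage


class MetricsMixin(BaseModel):
    """Adds a metrics list to any response model that mixes it in."""

    metrics: List[Metric] = Field(default_factory=list)


class ChatCompletionResponse(MetricsMixin, BaseModel):
    """Stand-in for the real response model, which carries more fields."""

    completion_text: str


resp = ChatCompletionResponse(completion_text="hello")
resp.metrics.append(TokenUsage(prompt_tokens=12, completion_tokens=3, total_tokens=15))
print(resp.model_dump_json(indent=2))

This is the same pattern the inference API file above applies to CompletionResponse, CompletionResponseStreamChunk, ChatCompletionResponse, ChatCompletionResponseStreamChunk, and EmbeddingsResponse.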


@@ -6,6 +6,9 @@
 from typing import Any, AsyncGenerator, Dict, List, Optional

+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
+
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -42,6 +45,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.apis.telemetry import TokenUsage
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -111,6 +115,8 @@ class InferenceRouter(Inference):
         routing_table: RoutingTable,
     ) -> None:
         self.routing_table = routing_table
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)

     async def initialize(self) -> None:
         pass
@@ -190,7 +196,28 @@ class InferenceRouter(Inference):
         if stream:
             return (chunk async for chunk in await provider.chat_completion(**params))
         else:
-            return await provider.chat_completion(**params)
+            response = await provider.chat_completion(**params)
+            model_input = self.formatter.encode_dialog_prompt(
+                messages,
+                tool_config.tool_prompt_format,
+            )
+            model_output = self.formatter.encode_dialog_prompt(
+                [response.completion_message],
+                tool_config.tool_prompt_format,
+            )
+            prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
+            completion_tokens = len(model_output.tokens) if model_output.tokens else 0
+            total_tokens = prompt_tokens + completion_tokens
+            if response.metrics is None:
+                response.metrics = []
+            response.metrics.append(
+                TokenUsage(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=total_tokens,
+                )
+            )
+            return response

     async def completion(
         self,
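
For the non-streaming path, the router now re-encodes both the input dialog and the returned completion message with the same ChatFormat the model itself uses, counts the resulting tokens, and appends a TokenUsage metric to the response. The sketch below factors that accounting step into helpers; the helper names and signatures are illustrative, while the encode_dialog_prompt calls, the .tokens length checks, and the TokenUsage import mirror the diff above.

# Sketch only: helper names and signatures are illustrative; the
# token-counting logic mirrors the router change above.
from llama_stack.apis.telemetry import TokenUsage


def _count_tokens(formatter, dialog, tool_prompt_format) -> int:
    """Count tokens by encoding a dialog exactly as the model would see it."""
    encoded = formatter.encode_dialog_prompt(dialog, tool_prompt_format)
    return len(encoded.tokens) if encoded.tokens else 0


def attach_token_usage(response, formatter, messages, tool_prompt_format):
    """Append a TokenUsage metric to a non-streaming chat completion response."""
    prompt_tokens = _count_tokens(formatter, messages, tool_prompt_format)
    completion_tokens = _count_tokens(
        formatter, [response.completion_message], tool_prompt_format
    )
    if response.metrics is None:  # defensive check, as in the router code
        response.metrics = []
    response.metrics.append(
        TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
    )
    return response

Note that the streaming branch is left unchanged here, so streamed chunks do not yet carry these metrics even though the stream-chunk response classes now accept a metrics field.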