From a72cdafac0f6d6acdf009f8406c8fd169c2c68a4 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Tue, 4 Feb 2025 10:45:16 -0800
Subject: [PATCH] Add inference token usage metrics

---
 llama_stack/apis/inference/inference.py     | 13 ++++-----
 llama_stack/apis/telemetry/telemetry.py     | 22 ++++++++++++++++
 llama_stack/distribution/routers/routers.py | 29 ++++++++++++++++++++-
 3 files changed, 57 insertions(+), 7 deletions(-)

diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index e517d9c3c..c5c61e764 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -17,12 +17,13 @@ from typing import (
     runtime_checkable,
 )
 
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
-from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
+from llama_stack.apis.telemetry.telemetry import MetricsMixin
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     SamplingParams,
@@ -285,7 +286,7 @@ class CompletionRequest(BaseModel):
 
 
 @json_schema_type
-class CompletionResponse(BaseModel):
+class CompletionResponse(MetricsMixin, BaseModel):
     """Response from a completion request.
 
     :param content: The generated completion text
@@ -299,7 +300,7 @@
 
 
 @json_schema_type
-class CompletionResponseStreamChunk(BaseModel):
+class CompletionResponseStreamChunk(MetricsMixin, BaseModel):
     """A chunk of a streamed completion response.
 
     :param delta: New content generated since last chunk. This can be one or more tokens.
@@ -368,7 +369,7 @@ class ChatCompletionRequest(BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
+class ChatCompletionResponseStreamChunk(MetricsMixin, BaseModel):
     """A chunk of a streamed chat completion response.
 
     :param event: The event containing the new content
@@ -378,7 +379,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponse(MetricResponseMixin, BaseModel):
+class ChatCompletionResponse(MetricsMixin, BaseModel):
     """Response from a chat completion request.
 
     :param completion_message: The complete response message
@@ -390,7 +391,7 @@
 
 
 @json_schema_type
-class EmbeddingsResponse(BaseModel):
+class EmbeddingsResponse(MetricsMixin, BaseModel):
     """Response containing generated embeddings.
 
     :param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py
index fe75677e7..6160aba52 100644
--- a/llama_stack/apis/telemetry/telemetry.py
+++ b/llama_stack/apis/telemetry/telemetry.py
@@ -211,6 +211,28 @@ class QuerySpanTreeResponse(BaseModel):
     data: Dict[str, SpanWithStatus]
 
 
+@json_schema_type
+class TokenUsage(BaseModel):
+    type: Literal["token_usage"] = "token_usage"
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+Metric = register_schema(
+    Annotated[
+        Union[TokenUsage],
+        Field(discriminator="type"),
+    ],
+    name="Metric",
+)
+
+
+@json_schema_type
+class MetricsMixin(BaseModel):
+    metrics: List[Metric] = Field(default_factory=list)
+
+
 @runtime_checkable
 class Telemetry(Protocol):
     @webmethod(route="/telemetry/events", method="POST")
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index b0cb50e42..9047b1c33 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -6,6 +6,9 @@
 
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
+
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -42,6 +45,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.apis.telemetry import TokenUsage
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -111,6 +115,8 @@
         routing_table: RoutingTable,
     ) -> None:
         self.routing_table = routing_table
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)
 
     async def initialize(self) -> None:
         pass
@@ -190,7 +196,28 @@
         if stream:
             return (chunk async for chunk in await provider.chat_completion(**params))
         else:
-            return await provider.chat_completion(**params)
+            response = await provider.chat_completion(**params)
+            model_input = self.formatter.encode_dialog_prompt(
+                messages,
+                tool_config.tool_prompt_format,
+            )
+            model_output = self.formatter.encode_dialog_prompt(
+                [response.completion_message],
+                tool_config.tool_prompt_format,
+            )
+            prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
+            completion_tokens = len(model_output.tokens) if model_output.tokens else 0
+            total_tokens = prompt_tokens + completion_tokens
+            if response.metrics is None:
+                response.metrics = []
+            response.metrics.append(
+                TokenUsage(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=total_tokens,
+                )
+            )
+            return response
 
     async def completion(
         self,
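
Reviewer note, not part of the patch: below is a minimal standalone sketch of how the new
token-usage metrics surface on a response object. It mirrors the TokenUsage and MetricsMixin
shapes added in llama_stack/apis/telemetry/telemetry.py, but uses simplified stand-ins
(a plain-string completion_message, a List[TokenUsage] field instead of the discriminated
Metric union, and hard-coded token counts in place of len(model_input.tokens)), so it runs
with only pydantic installed.

    from typing import List, Literal

    from pydantic import BaseModel, Field


    class TokenUsage(BaseModel):
        # Same fields as the TokenUsage metric introduced in telemetry.py.
        type: Literal["token_usage"] = "token_usage"
        prompt_tokens: int
        completion_tokens: int
        total_tokens: int


    class MetricsMixin(BaseModel):
        # Simplified: a plain list of TokenUsage rather than the Metric union.
        metrics: List[TokenUsage] = Field(default_factory=list)


    class ChatCompletionResponse(MetricsMixin, BaseModel):
        # Illustrative stand-in; the real class carries a CompletionMessage,
        # logprobs, etc.
        completion_message: str


    # The router computes prompt/completion token counts from the encoded
    # dialogs (len(model_input.tokens), len(model_output.tokens)) and appends
    # a single TokenUsage entry to the response's metrics list.
    response = ChatCompletionResponse(completion_message="Hello!")
    prompt_tokens, completion_tokens = 12, 4  # hard-coded stand-in counts
    response.metrics.append(
        TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
    )
    print(response.metrics[0].total_tokens)  # -> 16

Callers of the non-streaming response types that now inherit MetricsMixin can read
response.metrics the same way; the streaming chunk types gain the field as well, but the
router changes in this patch only populate it on the non-streaming chat_completion path.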