From 37b73900798792bd22b7b15ecf703cc54caf4ba1 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Tue, 4 Feb 2025 16:20:59 -0800
Subject: [PATCH] add metrics for streaming

---
 llama_stack/apis/telemetry/telemetry.py     |  2 +-
 llama_stack/distribution/routers/routers.py | 93 ++++++++++++++-------
 2 files changed, 65 insertions(+), 30 deletions(-)

diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py
index 6160aba52..e6073c429 100644
--- a/llama_stack/apis/telemetry/telemetry.py
+++ b/llama_stack/apis/telemetry/telemetry.py
@@ -230,7 +230,7 @@ Metric = register_schema(

 @json_schema_type
 class MetricsMixin(BaseModel):
-    metrics: List[Metric] = Field(default_factory=list)
+    metrics: Optional[List[Metric]] = None


 @runtime_checkable
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 1af5d43f9..51560ad5c 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -5,9 +5,10 @@
 # the root directory of this source tree.

 import time
-from typing import Any, AsyncGenerator, Dict, List, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import RawMessage
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import (
@@ -24,6 +25,9 @@ from llama_stack.apis.eval import (
     JobStatus,
 )
 from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
@@ -37,7 +41,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import ModelType
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
@@ -138,6 +142,31 @@ class InferenceRouter(Inference):
     ) -> None:
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

+    async def _log_token_usage(
+        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
+    ) -> None:
+        span = get_current_span()
+        metrics = [
+            ("prompt_tokens", prompt_tokens),
+            ("completion_tokens", completion_tokens),
+            ("total_tokens", total_tokens),
+        ]
+        for metric_name, value in metrics:
+            await self.telemetry.log_event(
+                MetricEvent(
+                    trace_id=span.trace_id,
+                    span_id=span.span_id,
+                    metric=metric_name,
+                    value=value,
+                    timestamp=time.time(),
+                    unit="tokens",
+                    attributes={
+                        "model_id": model.model_id,
+                        "provider_id": model.provider_id,
+                    },
+                )
+            )
+
     async def chat_completion(
         self,
         model_id: str,
@@ -150,7 +179,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -198,7 +227,37 @@ class InferenceRouter(Inference):
         )
         provider = self.routing_table.get_provider_impl(model_id)
         if stream:
-            return (chunk async for chunk in await provider.chat_completion(**params))
+
+            async def stream_generator():
+                model_input = self.formatter.encode_dialog_prompt(
+                    messages,
+                    tool_config.tool_prompt_format,
+                )
+                prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
+                completion_text = ""
+                async for chunk in await provider.chat_completion(**params):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                        completion_text += chunk.event.delta.text
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                        model_output = self.formatter.encode_dialog_prompt(
+                            [RawMessage(role="assistant", content=completion_text)],
+                            tool_config.tool_prompt_format,
+                        )
+                        completion_tokens = len(model_output.tokens) if model_output.tokens else 0
+                        total_tokens = prompt_tokens + completion_tokens
+                        if chunk.metrics is None:
+                            chunk.metrics = []
+                        chunk.metrics.append(
+                            TokenUsage(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=total_tokens,
+                            )
+                        )
+                        await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+                    yield chunk
+
+            return stream_generator()
         else:
             response = await provider.chat_completion(**params)
             model_input = self.formatter.encode_dialog_prompt(
@@ -221,31 +280,7 @@ class InferenceRouter(Inference):
                     total_tokens=total_tokens,
                 )
             )
-            # Log token usage metrics
-            metrics = [
-                ("prompt_tokens", prompt_tokens),
-                ("completion_tokens", completion_tokens),
-                ("total_tokens", total_tokens),
-            ]
-
-            span = get_current_span()
-            if span:
-                breakpoint()
-                for metric_name, value in metrics:
-                    await self.telemetry.log_event(
-                        MetricEvent(
-                            trace_id=span.trace_id,
-                            span_id=span.span_id,
-                            metric=metric_name,
-                            value=value,
-                            timestamp=time.time(),
-                            unit="tokens",
-                            attributes={
-                                "model_id": model_id,
-                                "provider_id": model.provider_id,
-                            },
-                        )
-                    )
+            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
             return response

     async def completion(
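
Note on the streaming change: the new path wraps the provider's async iterator in a local generator so the prompt is tokenized once up front, progress deltas are accumulated, and token-usage metrics are attached only when the final chunk arrives. The following is a minimal, self-contained sketch of that pattern; Chunk, fake_provider_stream, and the whitespace-split token count are illustrative stand-ins, not llama_stack APIs.

import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, List, Optional


@dataclass
class Chunk:
    delta: str
    is_last: bool = False
    metrics: Optional[List[dict]] = None  # mirrors MetricsMixin.metrics defaulting to None


async def fake_provider_stream() -> AsyncIterator[Chunk]:
    # Stand-in for the provider's chat_completion stream.
    for piece in ["Hello", " world", "!"]:
        yield Chunk(delta=piece)
    yield Chunk(delta="", is_last=True)


async def stream_with_usage(prompt: str) -> AsyncIterator[Chunk]:
    # Count the prompt once before iterating; accumulate completion text as chunks arrive.
    prompt_tokens = len(prompt.split())
    completion_text = ""
    async for chunk in fake_provider_stream():
        completion_text += chunk.delta
        if chunk.is_last:
            completion_tokens = len(completion_text.split())
            if chunk.metrics is None:  # metrics defaults to None, so create the list lazily
                chunk.metrics = []
            chunk.metrics.append(
                {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                }
            )
        yield chunk


async def main() -> None:
    async for chunk in stream_with_usage("What is the capital of France?"):
        if chunk.is_last:
            print(chunk.metrics)


asyncio.run(main())

The stream_generator() added in the patch has the same shape: tokenize the prompt before iterating, accumulate progress deltas, and materialize usage metrics only on the complete event.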
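
The new _log_token_usage helper emits one MetricEvent per counter against the current trace span. Below is a dependency-free sketch of that loop with stand-in Span, get_current_span, and log_event definitions; the explicit None guard is an addition in the sketch (the helper in the patch dereferences the span directly), included because a current span may not exist outside a traced request.

import time
from dataclasses import dataclass
from typing import Optional


@dataclass
class Span:
    trace_id: str
    span_id: str


def get_current_span() -> Optional[Span]:
    # Stand-in: a real implementation would read the active tracing context.
    return Span(trace_id="trace-123", span_id="span-456")


def log_event(event: dict) -> None:
    # Stand-in for the telemetry sink.
    print(event)


def log_token_usage(prompt_tokens: int, completion_tokens: int, total_tokens: int, model_id: str) -> None:
    span = get_current_span()
    if span is None:  # skip metric emission when no span is active (guard added in this sketch)
        return
    for metric_name, value in [
        ("prompt_tokens", prompt_tokens),
        ("completion_tokens", completion_tokens),
        ("total_tokens", total_tokens),
    ]:
        log_event(
            {
                "trace_id": span.trace_id,
                "span_id": span.span_id,
                "metric": metric_name,
                "value": value,
                "timestamp": time.time(),
                "unit": "tokens",
                "attributes": {"model_id": model_id},
            }
        )


log_token_usage(12, 34, 46, "example-model")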