From 77c2418a9c8161b3151d82f59005bfc2ceec7a93 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Wed, 5 Feb 2025 10:54:08 -0800
Subject: [PATCH] metrics for completion API

---
 llama_stack/distribution/routers/routers.py | 118 ++++++++++++++------
 1 file changed, 84 insertions(+), 34 deletions(-)

diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index a9fc13502..54700d33f 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -168,6 +168,37 @@ class InferenceRouter(Inference):
                 )
             )
 
+    async def _add_token_metrics(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        total_tokens: int,
+        model: Model,
+        target: Any,
+    ) -> None:
+        metrics = getattr(target, "metrics", None)
+        if metrics is None:
+            target.metrics = []
+        target.metrics.append(
+            TokenUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+        )
+        if self.telemetry:
+            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+
+    async def _count_tokens(
+        self,
+        messages: Union[List[Message], List[RawMessage]],
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+    ) -> Optional[int]:
+        if not self.telemetry:
+            return None
+        encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        return len(encoded.tokens) if encoded and encoded.tokens else 0
+
     async def chat_completion(
         self,
         model_id: str,
@@ -227,60 +258,46 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        model_input = self.formatter.encode_dialog_prompt(
-            messages,
-            tool_config.tool_prompt_format,
-        )
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+
         if stream:
 
             async def stream_generator():
-                prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
                     if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
                     if chunk.event.event_type == ChatCompletionResponseEventType.complete:
-                        model_output = self.formatter.encode_dialog_prompt(
+                        completion_tokens = await self._count_tokens(
                             [RawMessage(role="assistant", content=completion_text)],
                             tool_config.tool_prompt_format,
                         )
-                        completion_tokens = len(model_output.tokens) if model_output.tokens else 0
-                        total_tokens = prompt_tokens + completion_tokens
-                        if chunk.metrics is None:
-                            chunk.metrics = []
-                        chunk.metrics.append(
-                            TokenUsage(
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                                total_tokens=total_tokens,
-                            )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
                         )
-                        if self.telemetry:
-                            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
                     yield chunk
 
             return stream_generator()
         else:
             response = await provider.chat_completion(**params)
-            model_output = self.formatter.encode_dialog_prompt(
+            completion_tokens = await self._count_tokens(
                 [response.completion_message],
                 tool_config.tool_prompt_format,
             )
-            prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
-            completion_tokens = len(model_output.tokens) if model_output.tokens else 0
-            total_tokens = prompt_tokens + completion_tokens
-            if response.metrics is None:
-                response.metrics = []
-            response.metrics.append(
-                TokenUsage(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=total_tokens,
-                )
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
             )
-            if self.telemetry:
-                await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
             return response
 
     async def completion(
@@ -306,10 +323,43 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
+
+        prompt_tokens = await self._count_tokens([RawMessage(role="user", content=str(content))])
+
         if stream:
-            return (chunk async for chunk in await provider.completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.completion(**params):
+                    if hasattr(chunk, "delta"):
+                        completion_text += chunk.delta
+                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                        completion_tokens = await self._count_tokens(
+                            [RawMessage(role="assistant", content=completion_text)]
+                        )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
+                        )
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.completion(**params)
+            response = await provider.completion(**params)
+            completion_tokens = await self._count_tokens([RawMessage(role="assistant", content=str(response.content))])
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
+            )
+            return response
 
     async def embeddings(
         self,
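
A minimal usage sketch of the behavior this patch adds: reading the TokenUsage metrics that the router now attaches to a non-streaming chat completion response. Illustrative only, not part of the change; the InferenceRouter instance is assumed to be constructed and wired up elsewhere, "my-model" is a placeholder model id, and the UserMessage import path is an assumption.

    # Sketch only: `router` must be a fully initialized InferenceRouter with a
    # model registered under the placeholder id "my-model".
    from llama_stack.apis.inference import UserMessage
    from llama_stack.distribution.routers.routers import InferenceRouter

    async def print_token_usage(router: InferenceRouter) -> None:
        response = await router.chat_completion(
            model_id="my-model",
            messages=[UserMessage(content="Hello!")],
            stream=False,
        )
        # The router appends a TokenUsage entry carrying prompt/completion/total counts.
        for metric in response.metrics or []:
            print(metric.prompt_tokens, metric.completion_tokens, metric.total_tokens)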