diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index c637a5b23..1a95ad45b 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -190,23 +190,18 @@ class InferenceRouter(Inference):
         )
         return metric_events

-    async def _add_token_metrics(
+    async def _compute_and_log_token_usage(
         self,
         prompt_tokens: int,
         completion_tokens: int,
         total_tokens: int,
         model: Model,
-        target: Any,
-    ) -> None:
-        metrics = getattr(target, "metrics", None)
-        if metrics is None:
-            target.metrics = []
-
+    ) -> List[MetricEvent]:
         metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
-        target.metrics.extend(metrics)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
+        return metrics

     async def _count_tokens(
         self,
@@ -296,13 +291,13 @@ class InferenceRouter(Inference):
                         tool_config.tool_prompt_format,
                     )
                     total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                    await self._add_token_metrics(
+                    metrics = await self._compute_and_log_token_usage(
                         prompt_tokens or 0,
                         completion_tokens or 0,
                         total_tokens,
                         model,
-                        chunk,
                     )
+                    chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk

             return stream_generator()
@@ -313,13 +308,13 @@ class InferenceRouter(Inference):
                 tool_config.tool_prompt_format,
             )
             total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            await self._add_token_metrics(
+            metrics = await self._compute_and_log_token_usage(
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
                 model,
-                response,
             )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
             return response

     async def completion(
@@ -362,13 +357,13 @@ class InferenceRouter(Inference):
                     if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                         completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        await self._add_token_metrics(
+                        metrics = await self._compute_and_log_token_usage(
                             prompt_tokens or 0,
                             completion_tokens or 0,
                             total_tokens,
                             model,
-                            chunk,
                         )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk

             return stream_generator()
@@ -376,13 +371,13 @@ class InferenceRouter(Inference):
             response = await provider.completion(**params)
             completion_tokens = await self._count_tokens(response.content)
             total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            await self._add_token_metrics(
+            metrics = await self._compute_and_log_token_usage(
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
                 model,
-                response,
             )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
             return response

     async def embeddings(
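
For context, a minimal caller-side sketch of what this refactor enables: the helper now returns the List[MetricEvent] instead of mutating a passed-in target, and each call site attaches the metrics to its own response or chunk. The function below, the router instance, and the model id are all illustrative rather than part of the diff, and the metric/value attributes are assumed from the MetricEvent schema used by _construct_metrics.

    # Hypothetical usage sketch (not part of this diff): reading the
    # token-usage metrics that responses now carry after this change.
    from llama_stack.apis.inference import UserMessage

    async def show_token_usage(inference_router, model_id: str) -> None:
        # `inference_router` is an InferenceRouter instance; how it is
        # constructed is outside the scope of this diff.
        response = await inference_router.chat_completion(
            model_id=model_id,
            messages=[UserMessage(content="Hello")],
        )
        # Non-streaming responses now expose the MetricEvents directly on
        # `response.metrics` (streaming chunks get the same treatment).
        for metric in response.metrics or []:
            # Assumes MetricEvent exposes `metric` (name) and `value`,
            # e.g. prompt_tokens / completion_tokens / total_tokens.
            print(metric.metric, metric.value)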