Dinesh Yeduguru 2025-03-05 11:26:29 -08:00
parent d2c1162021
commit 306a2d2bff

@@ -190,23 +190,18 @@ class InferenceRouter(Inference):
         )
         return metric_events

-    async def _add_token_metrics(
+    async def _compute_and_log_token_usage(
         self,
         prompt_tokens: int,
         completion_tokens: int,
         total_tokens: int,
         model: Model,
-        target: Any,
-    ) -> None:
-        metrics = getattr(target, "metrics", None)
-        if metrics is None:
-            target.metrics = []
+    ) -> List[MetricEvent]:
         metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
-        target.metrics.extend(metrics)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
+        return metrics

     async def _count_tokens(
         self,
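
The hunk above changes the helper from mutating a caller-supplied `target` (via `getattr`/`extend`) to a pure compute-log-return shape: it builds the metric events, logs them when telemetry is configured, and hands them back for the caller to attach. A runnable toy sketch of that pattern, with `MetricEvent` and the telemetry sink stubbed out as illustrative stand-ins for the project's real types:

import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class MetricEvent:  # stand-in for the real telemetry event type
    metric: str
    value: int


class FakeTelemetry:  # stand-in sink that just records what it saw
    def __init__(self) -> None:
        self.logged: List[MetricEvent] = []

    async def log_event(self, event: MetricEvent) -> None:
        self.logged.append(event)


async def compute_and_log_token_usage(
    telemetry: Optional[FakeTelemetry],
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
) -> List[MetricEvent]:
    # Build the events, log them if a sink is configured, then return them;
    # the caller decides which response object they get attached to.
    metrics = [
        MetricEvent("prompt_tokens", prompt_tokens),
        MetricEvent("completion_tokens", completion_tokens),
        MetricEvent("total_tokens", total_tokens),
    ]
    if telemetry:
        for event in metrics:
            await telemetry.log_event(event)
    return metrics


async def main() -> None:
    sink = FakeTelemetry()
    metrics = await compute_and_log_token_usage(sink, 12, 30, 42)
    assert [e.value for e in sink.logged] == [12, 30, 42]
    assert len(metrics) == 3


asyncio.run(main())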
@@ -296,13 +291,13 @@ class InferenceRouter(Inference):
                             tool_config.tool_prompt_format,
                         )
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        await self._add_token_metrics(
+                        metrics = await self._compute_and_log_token_usage(
                             prompt_tokens or 0,
                             completion_tokens or 0,
                             total_tokens,
                             model,
-                            chunk,
                         )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk

             return stream_generator()
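
Note that the attach step on the caller side is None-safe and non-mutating: a chunk that has never carried metrics gets the fresh list, while one that already has metrics gets a concatenated copy (`chunk.metrics + metrics` builds a new list rather than extending in place). A minimal check of the idiom, with an illustrative `Chunk` class:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Chunk:  # illustrative stand-in for the streaming chunk type
    metrics: Optional[List[str]] = None


def attach(chunk: Chunk, metrics: List[str]) -> None:
    # Same expression as in the diff: replace when unset, concatenate otherwise.
    chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics


fresh, seeded = Chunk(), Chunk(metrics=["existing"])
attach(fresh, ["prompt_tokens"])
attach(seeded, ["prompt_tokens"])
assert fresh.metrics == ["prompt_tokens"]
assert seeded.metrics == ["existing", "prompt_tokens"]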
@@ -313,13 +308,13 @@ class InferenceRouter(Inference):
                 tool_config.tool_prompt_format,
             )
             total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            await self._add_token_metrics(
+            metrics = await self._compute_and_log_token_usage(
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
                 model,
-                response,
             )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
             return response

     async def completion(
@@ -362,13 +357,13 @@ class InferenceRouter(Inference):
                     if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                         completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        await self._add_token_metrics(
+                        metrics = await self._compute_and_log_token_usage(
                             prompt_tokens or 0,
                             completion_tokens or 0,
                             total_tokens,
                             model,
-                            chunk,
                         )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk

             return stream_generator()
@@ -376,13 +371,13 @@ class InferenceRouter(Inference):
         response = await provider.completion(**params)
         completion_tokens = await self._count_tokens(response.content)
         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-        await self._add_token_metrics(
+        metrics = await self._compute_and_log_token_usage(
             prompt_tokens or 0,
             completion_tokens or 0,
             total_tokens,
             model,
-            response,
         )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response

     async def embeddings(
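
Taken together, the four call sites now follow one shape: count tokens, compute and log the usage metrics once, then attach the returned events to whichever object is flowing back (a streaming chunk or a full response). Separating computation and logging from attachment removes the old `getattr(target, "metrics", None)` probing and makes each call site explicit about where the metrics land.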