Dinesh Yeduguru 2025-03-05 11:26:29 -08:00
parent d2c1162021
commit 306a2d2bff

@@ -190,23 +190,18 @@ class InferenceRouter(Inference):
             )
         return metric_events

-    async def _add_token_metrics(
+    async def _compute_and_log_token_usage(
         self,
         prompt_tokens: int,
         completion_tokens: int,
         total_tokens: int,
         model: Model,
-        target: Any,
-    ) -> None:
-        metrics = getattr(target, "metrics", None)
-        if metrics is None:
-            target.metrics = []
+    ) -> List[MetricEvent]:
         metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
-        target.metrics.extend(metrics)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
+        return metrics

     async def _count_tokens(
         self,
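
The first hunk changes the helper's contract: instead of taking an arbitrary target and mutating target.metrics in place, it now only computes (and logs) the metrics and returns them, leaving attachment to each caller. A minimal before/after sketch of that pattern, using placeholder types (MetricEvent here is a stand-in dict, not the project's real class):

from typing import Any, List

MetricEvent = dict  # stand-in for the real MetricEvent type

def construct_metrics(prompt: int, completion: int, total: int) -> List[MetricEvent]:
    return [
        {"metric": "prompt_tokens", "value": prompt},
        {"metric": "completion_tokens", "value": completion},
        {"metric": "total_tokens", "value": total},
    ]

# Before: the helper mutated whatever object it was handed.
def add_token_metrics(prompt: int, completion: int, total: int, target: Any) -> None:
    if getattr(target, "metrics", None) is None:
        target.metrics = []
    target.metrics.extend(construct_metrics(prompt, completion, total))

# After: the helper is side-effect free with respect to the response;
# callers receive the list and decide how to attach it.
def compute_and_log_token_usage(prompt: int, completion: int, total: int) -> List[MetricEvent]:
    return construct_metrics(prompt, completion, total)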
@@ -296,13 +291,13 @@ class InferenceRouter(Inference):
                             tool_config.tool_prompt_format,
                         )
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        await self._add_token_metrics(
+                        metrics = await self._compute_and_log_token_usage(
                             prompt_tokens or 0,
                             completion_tokens or 0,
                             total_tokens,
                             model,
-                            chunk,
                         )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk

             return stream_generator()
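
In the streaming paths, token usage is only known once the terminal chunk arrives, so the metrics are computed at that point and attached to the last chunk before it is yielded. A hedged sketch of that shape (the upstream generator, Chunk type, and stop_reason handling are simplified assumptions, not the router's real types):

import asyncio
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Chunk:
    delta: str
    stop_reason: Optional[str] = None
    metrics: Optional[List[str]] = None

async def upstream():
    # Stand-in for the provider's streaming response.
    yield Chunk("Hello")
    yield Chunk(" world", stop_reason="end_of_turn")

async def stream_generator():
    completion_text = ""
    async for chunk in upstream():
        completion_text += chunk.delta
        if chunk.stop_reason:
            # Terminal chunk: counts are known, attach usage metrics.
            metrics = [f"completion_chars={len(completion_text)}"]
            chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
        yield chunk

async def main():
    async for chunk in stream_generator():
        print(repr(chunk.delta), chunk.metrics)

asyncio.run(main())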
@@ -313,13 +308,13 @@ class InferenceRouter(Inference):
                 tool_config.tool_prompt_format,
             )
             total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            await self._add_token_metrics(
+            metrics = await self._compute_and_log_token_usage(
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
                 model,
-                response,
             )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
             return response

     async def completion(
@@ -362,13 +357,13 @@ class InferenceRouter(Inference):
                     if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                         completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        await self._add_token_metrics(
+                        metrics = await self._compute_and_log_token_usage(
                             prompt_tokens or 0,
                             completion_tokens or 0,
                             total_tokens,
                             model,
-                            chunk,
                         )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk

             return stream_generator()
@@ -376,13 +371,13 @@ class InferenceRouter(Inference):
             response = await provider.completion(**params)
             completion_tokens = await self._count_tokens(response.content)
             total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            await self._add_token_metrics(
+            metrics = await self._compute_and_log_token_usage(
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
                 model,
-                response,
             )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
             return response

     async def embeddings(
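
One consequence of the return-and-attach pattern used at every call site is that metrics already present on a response are preserved rather than overwritten. A hypothetical check of that property (Response and the string metrics are illustrative stand-ins, not the project's test suite):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Response:
    metrics: Optional[List[str]] = None

def attach(response: Response, metrics: List[str]) -> None:
    # The same conditional concatenation the diff applies at each call site.
    response.metrics = metrics if response.metrics is None else response.metrics + metrics

fresh = Response()
attach(fresh, ["total_tokens=12"])
assert fresh.metrics == ["total_tokens=12"]

existing = Response(metrics=["time_to_first_token=35ms"])
attach(existing, ["total_tokens=12"])
# Earlier metrics survive; token usage is appended.
assert existing.metrics == ["time_to_first_token=35ms", "total_tokens=12"]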