Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)
refactor
parent d2c1162021
commit 306a2d2bff
1 changed file with 11 additions and 16 deletions
@@ -190,23 +190,18 @@ class InferenceRouter(Inference):
             )
         return metric_events
 
-    async def _add_token_metrics(
+    async def _compute_and_log_token_usage(
         self,
         prompt_tokens: int,
         completion_tokens: int,
         total_tokens: int,
         model: Model,
-        target: Any,
-    ) -> None:
-        metrics = getattr(target, "metrics", None)
-        if metrics is None:
-            target.metrics = []
-
+    ) -> List[MetricEvent]:
         metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
-        target.metrics.extend(metrics)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
+        return metrics
 
     async def _count_tokens(
         self,
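Net effect of the hunk above: the helper no longer receives a `target` object whose `metrics` attribute it mutates in place; it only builds the token-usage `MetricEvent`s, logs them when telemetry is configured, and returns the list so each caller decides where to attach it. A minimal sketch of that shape, using hypothetical stand-ins (`TokenMetric`, `RecordingTelemetry`, `construct_metrics`) rather than the real llama-stack types:

# Sketch only; TokenMetric, RecordingTelemetry and construct_metrics are
# hypothetical stand-ins for MetricEvent, the telemetry API and _construct_metrics.
import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TokenMetric:
    name: str
    value: int


class RecordingTelemetry:
    def __init__(self) -> None:
        self.events: List[TokenMetric] = []

    async def log_event(self, event: TokenMetric) -> None:
        self.events.append(event)


def construct_metrics(prompt: int, completion: int, total: int) -> List[TokenMetric]:
    # One event per counter, mirroring the shape of _construct_metrics.
    return [
        TokenMetric("prompt_tokens", prompt),
        TokenMetric("completion_tokens", completion),
        TokenMetric("total_tokens", total),
    ]


async def compute_and_log_token_usage(
    prompt: int, completion: int, total: int, telemetry: Optional[RecordingTelemetry]
) -> List[TokenMetric]:
    # New contract: build, log, and return; no caller-owned object is mutated here.
    metrics = construct_metrics(prompt, completion, total)
    if telemetry:
        for metric in metrics:
            await telemetry.log_event(metric)
    return metrics


telemetry = RecordingTelemetry()
metrics = asyncio.run(compute_and_log_token_usage(3, 5, 8, telemetry))
assert [m.name for m in telemetry.events] == ["prompt_tokens", "completion_tokens", "total_tokens"]
assert metrics == telemetry.events

The old helper mutated whichever object it was handed (a chunk or a response), which coupled it to those types; returning the list keeps the helper side-effect free apart from telemetry logging and moves the attachment to the call sites shown in the remaining hunks.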
@@ -296,13 +291,13 @@ class InferenceRouter(Inference):
                             tool_config.tool_prompt_format,
                         )
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        await self._add_token_metrics(
+                        metrics = await self._compute_and_log_token_usage(
                             prompt_tokens or 0,
                             completion_tokens or 0,
                             total_tokens,
                             model,
-                            chunk,
                         )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk
 
             return stream_generator()
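Each call site now performs the same merge when attaching the returned metrics to the streamed chunk (and, in the non-streaming hunks below, to the response): keep whatever metrics the object already carries and append the token counters. A tiny sketch of just that merge, assuming the metrics attribute is either None or a plain list:

# Merge used by the callers: replace None, otherwise concatenate.
from typing import List, Optional


def merge_metrics(existing: Optional[List[str]], new: List[str]) -> List[str]:
    return new if existing is None else existing + new


assert merge_metrics(None, ["prompt_tokens"]) == ["prompt_tokens"]
assert merge_metrics(["time_to_first_token"], ["prompt_tokens"]) == [
    "time_to_first_token",
    "prompt_tokens",
]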
@@ -313,13 +308,13 @@ class InferenceRouter(Inference):
                 tool_config.tool_prompt_format,
             )
             total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            await self._add_token_metrics(
+            metrics = await self._compute_and_log_token_usage(
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
                 model,
-                response,
             )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
             return response
 
     async def completion(
@@ -362,13 +357,13 @@ class InferenceRouter(Inference):
                     if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                         completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        await self._add_token_metrics(
+                        metrics = await self._compute_and_log_token_usage(
                             prompt_tokens or 0,
                             completion_tokens or 0,
                             total_tokens,
                             model,
-                            chunk,
                         )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk
 
             return stream_generator()
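In the streaming completion path above, the counters are only computed and attached when a chunk carries a stop_reason, i.e., once at the end of the stream, and only if telemetry is configured. A small sketch of that gating with a hypothetical Chunk stand-in:

# Hypothetical Chunk stand-in; only the final chunk (stop_reason set) gets metrics.
from dataclasses import dataclass
from typing import Iterable, Iterator, List, Optional


@dataclass
class Chunk:
    delta: str
    stop_reason: Optional[str] = None
    metrics: Optional[List[str]] = None


def attach_final_metrics(chunks: Iterable[Chunk], usage: List[str]) -> Iterator[Chunk]:
    for chunk in chunks:
        if chunk.stop_reason:
            chunk.metrics = usage if chunk.metrics is None else chunk.metrics + usage
        yield chunk


out = list(attach_final_metrics([Chunk("Hel"), Chunk("lo", stop_reason="end_of_turn")], ["total_tokens=2"]))
assert out[0].metrics is None
assert out[-1].metrics == ["total_tokens=2"]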
@@ -376,13 +371,13 @@ class InferenceRouter(Inference):
             response = await provider.completion(**params)
             completion_tokens = await self._count_tokens(response.content)
             total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            await self._add_token_metrics(
+            metrics = await self._compute_and_log_token_usage(
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
                 model,
-                response,
             )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
             return response
 
     async def embeddings(