Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 13:00:39 +00:00)
commit 77c2418a9c (parent e9bb96334b)
metrics for completion API

1 changed file with 84 additions and 34 deletions
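The change routes all token counting in InferenceRouter through two new helpers, _count_tokens() and _add_token_metrics(), and extends the completion API to report token usage the same way chat_completion already does: a TokenUsage entry is appended to the metrics list of the response (or of the final stream chunk), and the counts are also logged when telemetry is enabled. The TokenUsage type itself is not part of this diff; a minimal sketch of the shape implied by the constructor calls below (an assumption, not the actual llama-stack definition) is:

    from pydantic import BaseModel

    class TokenUsage(BaseModel):
        # Fields inferred from the TokenUsage(...) calls in this diff; the real
        # model lives elsewhere in llama-stack and may carry extra metadata.
        prompt_tokens: int
        completion_tokens: int
        total_tokens: int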
@@ -168,6 +168,37 @@ class InferenceRouter(Inference):
             )
         )
 
+    async def _add_token_metrics(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        total_tokens: int,
+        model: Model,
+        target: Any,
+    ) -> None:
+        metrics = getattr(target, "metrics", None)
+        if metrics is None:
+            target.metrics = []
+        target.metrics.append(
+            TokenUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+        )
+        if self.telemetry:
+            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+
+    async def _count_tokens(
+        self,
+        messages: Union[List[Message], List[RawMessage]],
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+    ) -> Optional[int]:
+        if not self.telemetry:
+            return None
+        encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        return len(encoded.tokens) if encoded and encoded.tokens else 0
+
     async def chat_completion(
         self,
         model_id: str,
@@ -227,60 +258,46 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        model_input = self.formatter.encode_dialog_prompt(
-            messages,
-            tool_config.tool_prompt_format,
-        )
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+
         if stream:
 
             async def stream_generator():
-                prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
                     if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
                     if chunk.event.event_type == ChatCompletionResponseEventType.complete:
-                        model_output = self.formatter.encode_dialog_prompt(
+                        completion_tokens = await self._count_tokens(
                             [RawMessage(role="assistant", content=completion_text)],
                             tool_config.tool_prompt_format,
                         )
-                        completion_tokens = len(model_output.tokens) if model_output.tokens else 0
-                        total_tokens = prompt_tokens + completion_tokens
-                        if chunk.metrics is None:
-                            chunk.metrics = []
-                        chunk.metrics.append(
-                            TokenUsage(
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                                total_tokens=total_tokens,
-                            )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
                         )
-                        if self.telemetry:
-                            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
                     yield chunk
 
             return stream_generator()
         else:
             response = await provider.chat_completion(**params)
-            model_output = self.formatter.encode_dialog_prompt(
+            completion_tokens = await self._count_tokens(
                 [response.completion_message],
                 tool_config.tool_prompt_format,
             )
-            prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
-            completion_tokens = len(model_output.tokens) if model_output.tokens else 0
-            total_tokens = prompt_tokens + completion_tokens
-            if response.metrics is None:
-                response.metrics = []
-            response.metrics.append(
-                TokenUsage(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=total_tokens,
-                )
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
             )
-            if self.telemetry:
-                await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
             return response
 
     async def completion(
@@ -306,10 +323,43 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
+
+        prompt_tokens = await self._count_tokens([RawMessage(role="user", content=str(content))])
+
         if stream:
-            return (chunk async for chunk in await provider.completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.completion(**params):
+                    if hasattr(chunk, "delta"):
+                        completion_text += chunk.delta
+                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                        completion_tokens = await self._count_tokens(
+                            [RawMessage(role="assistant", content=completion_text)]
+                        )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
+                        )
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.completion(**params)
+            response = await provider.completion(**params)
+            completion_tokens = await self._count_tokens([RawMessage(role="assistant", content=str(response.content))])
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
+            )
+            return response
 
     async def embeddings(
         self,
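With these changes, a caller of the router's completion API can read the counts off the returned object's metrics list. A hypothetical, non-streaming usage sketch (the router setup, model id, and printing are assumptions, not part of this commit):

    # Assumes `router` is an already-configured InferenceRouter and
    # "my-model" is a registered model id with telemetry enabled.
    async def show_token_usage(router, prompt: str) -> None:
        response = await router.completion(model_id="my-model", content=prompt)
        for usage in response.metrics or []:
            print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)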