Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-08-12 04:50:39 +00:00
use encode_content for completion token usage
This commit is contained in:
parent 1952ffa410
commit d2c1162021
2 changed files with 12 additions and 14 deletions
@@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):
 
 
 @json_schema_type
-class CompletionResponse(BaseModel):
+class CompletionResponse(MetricResponseMixin):
     """Response from a completion request.
 
     :param content: The generated completion text
@@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):
 
 
 @json_schema_type
-class CompletionResponseStreamChunk(BaseModel):
+class CompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed completion response.
 
     :param delta: New content generated since last chunk. This can be one or more tokens.
@@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
+class ChatCompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed chat completion response.
 
     :param event: The event containing the new content
@@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponse(MetricResponseMixin, BaseModel):
+class ChatCompletionResponse(MetricResponseMixin):
     """Response from a chat completion request.
 
     :param completion_message: The complete response message
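Note on the four hunks above: each response class now inherits only MetricResponseMixin rather than listing BaseModel as a base (or alongside the mixin). Below is a minimal sketch of how that composes, assuming the mixin is itself a Pydantic model that contributes an optional metrics field; the MetricInResponse shape shown here is illustrative, not the actual llama-stack definition.

# Minimal sketch, not the real llama-stack definitions: a Pydantic mixin that
# already derives from BaseModel and contributes a `metrics` field, so the
# response classes need to list only the mixin as a base.
from typing import List, Optional

from pydantic import BaseModel


class MetricInResponse(BaseModel):
    # Illustrative shape of a single metric entry (e.g. prompt/completion tokens).
    metric: str
    value: float


class MetricResponseMixin(BaseModel):
    # Token-usage metrics attached by the router; None until telemetry fills them in.
    metrics: Optional[List[MetricInResponse]] = None


class CompletionResponse(MetricResponseMixin):
    # Inheriting the mixin alone is enough: BaseModel comes in through the
    # mixin's MRO, and the class gains the `metrics` field on top of its own.
    content: str
    stop_reason: Optional[str] = None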
@@ -42,7 +42,6 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
-    UserMessage,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
@@ -211,10 +210,13 @@ class InferenceRouter(Inference):
 
     async def _count_tokens(
         self,
-        messages: List[Message],
+        messages: List[Message] | InterleavedContent,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
     ) -> Optional[int]:
-        encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        if isinstance(messages, list):
+            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        else:
+            encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0
 
     async def chat_completion(
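A toy illustration of the dispatch added to _count_tokens: a list of messages is rendered through the dialog encoder, while raw completion content goes through encode_content. SimpleFormatter is a stand-in for the chat formatter the router actually holds, and its byte-level tokens are fake; only the branching mirrors the diff.

from typing import List, Optional


class Encoded:
    def __init__(self, tokens: List[int]):
        self.tokens = tokens


class SimpleFormatter:
    # Stand-in formatter: byte values act as fake tokens.
    def encode_dialog_prompt(self, messages: List[str], tool_prompt_format=None) -> Encoded:
        rendered = "\n".join(f"<|{i}|>{m}" for i, m in enumerate(messages))
        return Encoded(tokens=list(rendered.encode("utf-8")))

    def encode_content(self, content) -> Encoded:
        return Encoded(tokens=list(str(content).encode("utf-8")))


def count_tokens(formatter: SimpleFormatter, messages_or_content) -> Optional[int]:
    # Mirrors the dispatch in the hunk above: lists are treated as dialogs,
    # anything else as raw completion content.
    if isinstance(messages_or_content, list):
        encoded = formatter.encode_dialog_prompt(messages_or_content)
    else:
        encoded = formatter.encode_content(messages_or_content)
    return len(encoded.tokens) if encoded and encoded.tokens else 0


fmt = SimpleFormatter()
print(count_tokens(fmt, "Write a haiku about routers"))      # raw content path
print(count_tokens(fmt, ["user: hi", "assistant: hello"]))   # dialog path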
@@ -348,7 +350,7 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         )
 
-        prompt_tokens = await self._count_tokens([UserMessage(role="user", content=str(content))])
+        prompt_tokens = await self._count_tokens(content)
 
         if stream:
@@ -358,9 +360,7 @@ class InferenceRouter(Inference):
                     if hasattr(chunk, "delta"):
                         completion_text += chunk.delta
                     if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
-                        completion_tokens = await self._count_tokens(
-                            [CompletionMessage(content=completion_text, stop_reason=chunk.stop_reason)]
-                        )
+                        completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                         await self._add_token_metrics(
                             prompt_tokens or 0,
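Self-contained sketch of the streaming flow after this change: delta text is accumulated and, once a stop_reason arrives, completion tokens are counted from the raw accumulated text rather than from a synthetic CompletionMessage. The chunk/provider objects and the word-count tokenizer below are stand-ins, and the printed usage line replaces the real _add_token_metrics call, whose full argument list is not visible in this hunk.

import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Chunk:
    delta: str = ""
    stop_reason: Optional[str] = None


async def fake_provider_completion():
    # Stand-in for the provider's streamed completion.
    for delta, stop in [("Roses ", None), ("are red", "end_of_turn")]:
        yield Chunk(delta=delta, stop_reason=stop)


async def count_tokens(content: str) -> int:
    # Stand-in for formatter.encode_content(...): one token per word.
    return len(content.split())


async def stream_with_usage(prompt_tokens: int):
    completion_text = ""
    async for chunk in fake_provider_completion():
        if chunk.delta:
            completion_text += chunk.delta
        if chunk.stop_reason:
            completion_tokens = await count_tokens(completion_text)
            total = prompt_tokens + completion_tokens
            # In the router this is where _add_token_metrics would be called.
            print(f"usage: prompt={prompt_tokens} completion={completion_tokens} total={total}")
        yield chunk


async def main():
    async for _ in stream_with_usage(prompt_tokens=5):
        pass


asyncio.run(main())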
@@ -374,9 +374,7 @@ class InferenceRouter(Inference):
             return stream_generator()
         else:
             response = await provider.completion(**params)
-            completion_tokens = await self._count_tokens(
-                [CompletionMessage(content=str(response.content), stop_reason=StopReason.end_of_turn)]
-            )
+            completion_tokens = await self._count_tokens(response.content)
             total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
             await self._add_token_metrics(
                 prompt_tokens or 0,
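The non-streaming hunk above makes the same substitution using response.content. A small self-contained comparison of why the wrapper mattered, using fake encoders rather than llama-models' ChatFormat: wrapping text in a dialog message adds template and role tokens, so counting completion tokens that way would presumably overcount relative to encode_content on the bare text.

def encode_content(text: str) -> list[int]:
    # Fake tokens: one per byte of the raw text.
    return list(text.encode("utf-8"))


def encode_dialog_prompt(messages: list[str]) -> list[int]:
    # Fake chat template: role/header markers inflate the token count.
    rendered = "".join(f"<|assistant|>{m}<|eot|>" for m in messages)
    return list(rendered.encode("utf-8"))


text = "Roses are red"
print(len(encode_content(text)))          # tokens for the bare completion text
print(len(encode_dialog_prompt([text])))  # larger: includes template tokens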