use encode_content for completion token usage

Dinesh Yeduguru 2025-03-05 11:00:00 -08:00
parent 1952ffa410
commit d2c1162021
2 changed files with 12 additions and 14 deletions


@@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):
 @json_schema_type
-class CompletionResponse(BaseModel):
+class CompletionResponse(MetricResponseMixin):
     """Response from a completion request.
     :param content: The generated completion text
@@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):
 @json_schema_type
-class CompletionResponseStreamChunk(BaseModel):
+class CompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed completion response.
     :param delta: New content generated since last chunk. This can be one or more tokens.
@@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
 @json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
+class ChatCompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed chat completion response.
     :param event: The event containing the new content
@@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
 @json_schema_type
-class ChatCompletionResponse(MetricResponseMixin, BaseModel):
+class ChatCompletionResponse(MetricResponseMixin):
     """Response from a chat completion request.
     :param completion_message: The complete response message


@@ -42,7 +42,6 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
-    UserMessage,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
@@ -211,10 +210,13 @@ class InferenceRouter(Inference):
     async def _count_tokens(
         self,
-        messages: List[Message],
+        messages: List[Message] | InterleavedContent,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
     ) -> Optional[int]:
-        encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        if isinstance(messages, list):
+            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        else:
+            encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0
 
     async def chat_completion(
@@ -348,7 +350,7 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         )
-        prompt_tokens = await self._count_tokens([UserMessage(role="user", content=str(content))])
+        prompt_tokens = await self._count_tokens(content)
 
         if stream:
@@ -358,9 +360,7 @@
                     if hasattr(chunk, "delta"):
                         completion_text += chunk.delta
                     if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
-                        completion_tokens = await self._count_tokens(
-                            [CompletionMessage(content=completion_text, stop_reason=chunk.stop_reason)]
-                        )
+                        completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                         await self._add_token_metrics(
                             prompt_tokens or 0,
@@ -374,9 +374,7 @@
             return stream_generator()
         else:
             response = await provider.completion(**params)
-            completion_tokens = await self._count_tokens(
-                [CompletionMessage(content=str(response.content), stop_reason=StopReason.end_of_turn)]
-            )
+            completion_tokens = await self._count_tokens(response.content)
             total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
             await self._add_token_metrics(
                 prompt_tokens or 0,
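
Taken together, the router changes stop wrapping raw completion content in synthetic UserMessage/CompletionMessage objects just to count tokens: _count_tokens now accepts either a list of chat messages or raw interleaved content and picks the matching encoder. A standalone sketch of that dispatch with the formatter passed in explicitly; the import paths for Message and InterleavedContent are assumptions (the diff only shows where ToolPromptFormat comes from):

from typing import List, Optional

from llama_stack.apis.common.content_types import InterleavedContent  # assumed path
from llama_stack.apis.inference import Message, ToolPromptFormat  # Message's path assumed


async def count_tokens(
    formatter,
    messages: List[Message] | InterleavedContent,
    tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
    # Chat requests still go through the dialog encoder; raw completion
    # content (a string or interleaved content) is encoded directly.
    if isinstance(messages, list):
        encoded = formatter.encode_dialog_prompt(messages, tool_prompt_format)
    else:
        encoded = formatter.encode_content(messages)
    return len(encoded.tokens) if encoded and encoded.tokens else 0

Usage mirrors the router: pass the request's content for prompt tokens, the accumulated completion_text (streaming) or response.content (non-streaming) for completion tokens, and a message list plus tool_prompt_format for chat completions.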