diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index e517d9c3c..08ceace4f 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):
 
 
 @json_schema_type
-class CompletionResponse(BaseModel):
+class CompletionResponse(MetricResponseMixin):
     """Response from a completion request.
 
     :param content: The generated completion text
@@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):
 
 
 @json_schema_type
-class CompletionResponseStreamChunk(BaseModel):
+class CompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed completion response.
 
     :param delta: New content generated since last chunk. This can be one or more tokens.
@@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
+class ChatCompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed chat completion response.
 
     :param event: The event containing the new content
@@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponse(MetricResponseMixin, BaseModel):
+class ChatCompletionResponse(MetricResponseMixin):
     """Response from a chat completion request.
 
     :param completion_message: The complete response message
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index a921df929..c637a5b23 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -42,7 +42,6 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
-    UserMessage,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
@@ -211,10 +210,13 @@ class InferenceRouter(Inference):
 
     async def _count_tokens(
         self,
-        messages: List[Message],
+        messages: List[Message] | InterleavedContent,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
     ) -> Optional[int]:
-        encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        if isinstance(messages, list):
+            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        else:
+            encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0
 
     async def chat_completion(
@@ -348,7 +350,7 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         )
 
-        prompt_tokens = await self._count_tokens([UserMessage(role="user", content=str(content))])
+        prompt_tokens = await self._count_tokens(content)
 
         if stream:
 
@@ -358,9 +360,7 @@
                     if hasattr(chunk, "delta"):
                         completion_text += chunk.delta
                     if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
-                        completion_tokens = await self._count_tokens(
-                            [CompletionMessage(content=completion_text, stop_reason=chunk.stop_reason)]
-                        )
+                        completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                         await self._add_token_metrics(
                             prompt_tokens or 0,
@@ -374,9 +374,7 @@
             return stream_generator()
         else:
             response = await provider.completion(**params)
-            completion_tokens = await self._count_tokens(
-                [CompletionMessage(content=str(response.content), stop_reason=StopReason.end_of_turn)]
-            )
+            completion_tokens = await self._count_tokens(response.content)
             total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
             await self._add_token_metrics(
                 prompt_tokens or 0,
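
The routers.py hunks above widen `_count_tokens` so the completion path can pass raw prompt or response content instead of wrapping it in synthetic `UserMessage`/`CompletionMessage` objects. The following is a minimal, self-contained sketch of that dispatch; `SimpleTokenizer` and `count_tokens` are hypothetical stand-ins for the router's formatter, not llama_stack APIs.

```python
from typing import List, Optional, Union


class SimpleTokenizer:
    """Hypothetical whitespace tokenizer standing in for the router's formatter."""

    def encode_dialog_prompt(self, messages: List[dict]) -> List[str]:
        # The real formatter renders role headers and special tokens; here we
        # just concatenate message contents to keep the sketch runnable.
        return " ".join(m["content"] for m in messages).split()

    def encode_content(self, content: str) -> List[str]:
        return content.split()


def count_tokens(tokenizer: SimpleTokenizer, messages: Union[List[dict], str]) -> Optional[int]:
    # Mirrors the branching added to InferenceRouter._count_tokens: a list is
    # treated as a chat dialog, anything else as plain completion content.
    if isinstance(messages, list):
        tokens = tokenizer.encode_dialog_prompt(messages)
    else:
        tokens = tokenizer.encode_content(messages)
    return len(tokens) if tokens else 0


if __name__ == "__main__":
    tok = SimpleTokenizer()
    print(count_tokens(tok, "Write a haiku about telemetry"))         # completion path
    print(count_tokens(tok, [{"role": "user", "content": "Hello"}]))  # chat path
```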
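
On the inference.py side, the base-class changes read as follows: if `MetricResponseMixin` is itself a pydantic model carrying an optional metrics list, then listing `BaseModel` alongside it is redundant, and switching the plain completion responses to the mixin gives them a place for token metrics. The sketch below illustrates that idea under those assumptions; the `metrics` field and the `MetricInResponse` type are illustrative, not the library's actual definitions, and `model_dump` assumes pydantic v2.

```python
from typing import List, Optional

from pydantic import BaseModel


class MetricInResponse(BaseModel):
    # Illustrative metric record, not llama_stack's real type.
    metric: str
    value: float


class MetricResponseMixin(BaseModel):
    # Because the mixin (as assumed here) already derives from BaseModel,
    # response classes can list it as their only base, which is what the
    # diff does for CompletionResponse and friends.
    metrics: Optional[List[MetricInResponse]] = None


class CompletionResponse(MetricResponseMixin):
    content: str
    stop_reason: str


resp = CompletionResponse(
    content="Hello",
    stop_reason="end_of_turn",
    metrics=[MetricInResponse(metric="total_tokens", value=12)],
)
print(resp.model_dump())
```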