Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)
use encode_content for completion token usage

commit d2c1162021 (parent 1952ffa410)
2 changed files with 12 additions and 14 deletions
@@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):


 @json_schema_type
-class CompletionResponse(BaseModel):
+class CompletionResponse(MetricResponseMixin):
     """Response from a completion request.

     :param content: The generated completion text

@@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):


 @json_schema_type
-class CompletionResponseStreamChunk(BaseModel):
+class CompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed completion response.

     :param delta: New content generated since last chunk. This can be one or more tokens.

@@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):


 @json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
+class ChatCompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed chat completion response.

     :param event: The event containing the new content

@@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):


 @json_schema_type
-class ChatCompletionResponse(MetricResponseMixin, BaseModel):
+class ChatCompletionResponse(MetricResponseMixin):
     """Response from a chat completion request.

     :param completion_message: The complete response message
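The four hunks above make the same change to the inference API response types: each response class now inherits from MetricResponseMixin alone instead of also listing BaseModel. That only keeps these classes pydantic models if the mixin itself derives from BaseModel, which the diff implies but does not show. A minimal sketch of that pattern; the metrics field shape here is an illustrative assumption, not taken from the commit.

from typing import List, Optional

from pydantic import BaseModel


class MetricResponseMixin(BaseModel):
    # Hypothetical shape: a response may carry usage metrics next to its payload.
    metrics: Optional[List[dict]] = None


class CompletionResponse(MetricResponseMixin):
    # Inheriting the mixin alone is sufficient; listing BaseModel as a second
    # base (MetricResponseMixin, BaseModel) is redundant because the mixin
    # already derives from BaseModel.
    content: str
    stop_reason: str


resp = CompletionResponse(content="hello", stop_reason="end_of_turn")
print(resp.metrics)  # None until the router attaches token-usage metrics

The remaining hunks are in the InferenceRouter, where the counting itself changes.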
@@ -42,7 +42,6 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
-    UserMessage,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety

@@ -211,10 +210,13 @@ class InferenceRouter(Inference):

     async def _count_tokens(
         self,
-        messages: List[Message],
+        messages: List[Message] | InterleavedContent,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
     ) -> Optional[int]:
-        encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        if isinstance(messages, list):
+            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        else:
+            encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0

     async def chat_completion(
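The _count_tokens change above is the core of the commit: plain completion inputs no longer need to be wrapped in a UserMessage just to be counted. A standalone sketch of the same dispatch, assuming only that the formatter exposes the encode_dialog_prompt/encode_content pair used in the diff and returns an object with a .tokens list; the free-standing function name is illustrative, not the router's API.

from typing import Any, Optional


async def count_tokens(
    formatter: Any,
    messages_or_content: Any,  # List[Message] | InterleavedContent in the router
    tool_prompt_format: Optional[Any] = None,
) -> Optional[int]:
    # Dialog-shaped input (a list of messages) goes through the dialog encoder;
    # raw completion content goes through encode_content instead.
    if isinstance(messages_or_content, list):
        encoded = formatter.encode_dialog_prompt(messages_or_content, tool_prompt_format)
    else:
        encoded = formatter.encode_content(messages_or_content)
    return len(encoded.tokens) if encoded and encoded.tokens else 0

With this dispatch in place, the completion path can pass content straight through, which is what the next hunk does.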
@@ -348,7 +350,7 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         )

-        prompt_tokens = await self._count_tokens([UserMessage(role="user", content=str(content))])
+        prompt_tokens = await self._count_tokens(content)

         if stream:
@@ -358,9 +360,7 @@
                     if hasattr(chunk, "delta"):
                         completion_text += chunk.delta
                     if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
-                        completion_tokens = await self._count_tokens(
-                            [CompletionMessage(content=completion_text, stop_reason=chunk.stop_reason)]
-                        )
+                        completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                         await self._add_token_metrics(
                             prompt_tokens or 0,
@@ -374,9 +374,7 @@
             return stream_generator()
         else:
             response = await provider.completion(**params)
-            completion_tokens = await self._count_tokens(
-                [CompletionMessage(content=str(response.content), stop_reason=StopReason.end_of_turn)]
-            )
+            completion_tokens = await self._count_tokens(response.content)
             total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
             await self._add_token_metrics(
                 prompt_tokens or 0,
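The last two hunks change only what gets counted (the generated text or response content, rather than a wrapped CompletionMessage), not how the totals are assembled. A self-contained sketch of that accounting with a stand-in whitespace tokenizer; FakeFormatter and run_completion are placeholders, not part of the commit.

import asyncio


class FakeEncoding:
    def __init__(self, tokens):
        self.tokens = tokens


class FakeFormatter:
    # Stand-in tokenizer: one "token" per whitespace-separated word.
    def encode_content(self, content):
        return FakeEncoding(str(content).split())


async def run_completion(formatter, prompt_content, generated_text):
    # Prompt tokens come from the raw request content, completion tokens from
    # the generated text; the total mirrors the arithmetic in the diff.
    prompt_tokens = len(formatter.encode_content(prompt_content).tokens)
    completion_tokens = len(formatter.encode_content(generated_text).tokens)
    total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
    return prompt_tokens, completion_tokens, total_tokens


print(asyncio.run(run_completion(FakeFormatter(), "Write a haiku", "An old silent pond")))
# (3, 4, 7)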