diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 871f1f633..7154ad746 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -186,6 +186,13 @@ ResponseFormat = register_schema(
 )
 
 
+@json_schema_type
+class UsageStatistics(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
 @json_schema_type
 class CompletionRequest(BaseModel):
     model: str
@@ -204,6 +211,7 @@ class CompletionResponse(BaseModel):
     content: str
     stop_reason: StopReason
     logprobs: Optional[List[TokenLogProbs]] = None
+    usage: Optional[UsageStatistics] = None
 
 
 @json_schema_type
@@ -213,6 +221,7 @@ class CompletionResponseStreamChunk(BaseModel):
     delta: str
     stop_reason: Optional[StopReason] = None
     logprobs: Optional[List[TokenLogProbs]] = None
+    usage: Optional[UsageStatistics] = None
 
 
 @json_schema_type
@@ -252,6 +261,7 @@ class ChatCompletionResponseStreamChunk(BaseModel):
     """SSE-stream of these events."""
 
     event: ChatCompletionResponseEvent
+    usage: Optional[UsageStatistics] = None
 
 
 @json_schema_type
@@ -260,6 +270,7 @@ class ChatCompletionResponse(BaseModel):
 
     completion_message: CompletionMessage
     logprobs: Optional[List[TokenLogProbs]] = None
+    usage: Optional[UsageStatistics] = None
 
 
 @json_schema_type
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index a96409cab..ac5bcc388 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -37,7 +37,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
 from llama_models.sku_list import resolve_model
-
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from pydantic import BaseModel
 
@@ -47,7 +46,6 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     ResponseFormatType,
 )
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
@@ -78,6 +76,7 @@ class TokenResult(BaseModel):
     token: int
     text: str
     logprobs: Optional[List[float]] = None
+    input_token_count: Optional[int] = None
 
 
 class Llama:
@@ -348,6 +347,7 @@ class Llama:
                     if logprobs
                     else None
                 ),
+                input_token_count=len(model_input.tokens),
             )
 
             prev_pos = cur_pos
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 73962ca7f..47d878b0c 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     TokenLogProbs,
     ToolChoice,
+    UsageStatistics,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -168,8 +169,14 @@ class MetaReferenceInferenceImpl(
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         def impl():
             stop_reason = None
+            input_token_count = 0
+            output_token_count = 0
+            usage_statistics = None
 
             for token_result in self.generator.completion(request):
+                if input_token_count == 0:
+                    input_token_count = token_result.input_token_count
+                output_token_count += 1
                 if token_result.text == "<|eot_id|>":
                     stop_reason = StopReason.end_of_turn
                     text = ""
@@ -191,17 +198,29 @@ class MetaReferenceInferenceImpl(
                                 }
                             )
                         ]
+                else:
+                    usage_statistics = UsageStatistics(
+                        prompt_tokens=input_token_count,
+                        completion_tokens=output_token_count,
+                        total_tokens=input_token_count + output_token_count,
+                    )
 
                 yield CompletionResponseStreamChunk(
                     delta=text,
                     stop_reason=stop_reason,
                     logprobs=logprobs if request.logprobs else None,
+                    usage=usage_statistics,
                 )
 
             if stop_reason is None:
                 yield CompletionResponseStreamChunk(
                     delta="",
                     stop_reason=StopReason.out_of_tokens,
+                    usage=UsageStatistics(
+                        prompt_tokens=input_token_count,
+                        completion_tokens=output_token_count,
+                        total_tokens=input_token_count + output_token_count,
+                    ),
                 )
 
         if self.config.create_distributed_process_group:
@@ -221,7 +240,10 @@ class MetaReferenceInferenceImpl(
             stop_reason = None
             tokenizer = self.generator.formatter.tokenizer
 
+            input_token_count = 0
             for token_result in self.generator.completion(request):
+                if input_token_count == 0:
+                    input_token_count = token_result.input_token_count
                 tokens.append(token_result.token)
                 if token_result.text == "<|eot_id|>":
                     stop_reason = StopReason.end_of_turn
@@ -242,7 +264,7 @@ class MetaReferenceInferenceImpl(
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens
 
-            content = self.generator.formatter.tokenizer.decode(tokens)
+            content = tokenizer.decode(tokens)
             if content.endswith("<|eot_id|>"):
                 content = content[: -len("<|eot_id|>")]
             elif content.endswith("<|eom_id|>"):
@@ -251,6 +273,11 @@ class MetaReferenceInferenceImpl(
                 content=content,
                 stop_reason=stop_reason,
                 logprobs=logprobs if request.logprobs else None,
+                usage=UsageStatistics(
+                    prompt_tokens=input_token_count,
+                    completion_tokens=len(tokens),
+                    total_tokens=input_token_count + len(tokens),
+                ),
             )
 
         if self.config.create_distributed_process_group:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 6c93f49c0..feea8adb0 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -12,7 +12,6 @@ from llama_models.datatypes import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from pydantic import BaseModel
@@ -24,7 +23,6 @@ from llama_stack.apis.common.content_types import (
     ToolCallDelta,
     ToolCallParseStatus,
 )
-
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
@@ -35,8 +33,8 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     Message,
     TokenLogProbs,
+    UsageStatistics,
 )
-
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )
@@ -63,8 +61,15 @@ class OpenAICompatCompletionChoice(BaseModel):
     logprobs: Optional[OpenAICompatLogprobs] = None
 
 
+class OpenAICompatCompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
 class OpenAICompatCompletionResponse(BaseModel):
     choices: List[OpenAICompatCompletionChoice]
+    usage: Optional[OpenAICompatCompletionUsage] = None
 
 
 def get_sampling_strategy_options(params: SamplingParams) -> dict:
@@ -124,16 +129,31 @@ def convert_openai_completion_logprobs(
     return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
 
 
+def get_usage_statistics(
+    response: OpenAICompatCompletionResponse,
+) -> Optional[UsageStatistics]:
+    if response.usage:
+        return UsageStatistics(
+            prompt_tokens=response.usage.prompt_tokens,
+            completion_tokens=response.usage.completion_tokens,
+            total_tokens=response.usage.total_tokens,
+        )
+    return None
+
+
 def process_completion_response(
     response: OpenAICompatCompletionResponse, formatter: ChatFormat
 ) -> CompletionResponse:
     choice = response.choices[0]
+    usage_statistics = get_usage_statistics(response)
+
     # drop suffix if present and return stop reason as end of turn
     if choice.text.endswith("<|eot_id|>"):
         return CompletionResponse(
             stop_reason=StopReason.end_of_turn,
             content=choice.text[: -len("<|eot_id|>")],
             logprobs=convert_openai_completion_logprobs(choice.logprobs),
+            usage=usage_statistics,
         )
     # drop suffix if present and return stop reason as end of message
     if choice.text.endswith("<|eom_id|>"):
@@ -141,11 +161,13 @@
             stop_reason=StopReason.end_of_message,
             content=choice.text[: -len("<|eom_id|>")],
             logprobs=convert_openai_completion_logprobs(choice.logprobs),
+            usage=usage_statistics,
         )
     return CompletionResponse(
         stop_reason=get_stop_reason(choice.finish_reason),
         content=choice.text,
         logprobs=convert_openai_completion_logprobs(choice.logprobs),
+        usage=usage_statistics,
     )
 
 
@@ -164,6 +186,7 @@
             tool_calls=raw_message.tool_calls,
         ),
         logprobs=None,
+        usage=get_usage_statistics(response),
     )
 
 
@@ -171,10 +194,13 @@ async def process_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
 ) -> AsyncGenerator:
     stop_reason = None
+    usage_statistics = None
 
     async for chunk in stream:
         choice = chunk.choices[0]
         finish_reason = choice.finish_reason
+        # usage statistics are only available in the final chunk
+        usage_statistics = get_usage_statistics(chunk)
 
         text = text_from_choice(choice)
         if text == "<|eot_id|>":
@@ -200,6 +226,7 @@
     yield CompletionResponseStreamChunk(
         delta="",
         stop_reason=stop_reason,
+        usage=usage_statistics,
     )
 
 
@@ -216,10 +243,11 @@ async def process_chat_completion_stream_response(
     buffer = ""
     ipython = False
     stop_reason = None
-
+    usage_statistics = None
     async for chunk in stream:
         choice = chunk.choices[0]
         finish_reason = choice.finish_reason
+        usage_statistics = get_usage_statistics(chunk)
 
         if finish_reason:
             if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
@@ -313,7 +341,8 @@
             event_type=ChatCompletionResponseEventType.complete,
             delta=TextDelta(text=""),
             stop_reason=stop_reason,
-        )
+        ),
+        usage=usage_statistics,
     )
 
 
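
Not part of the patch: a minimal sketch of how the new usage plumbing could be exercised, assuming only the UsageStatistics model and the OpenAI-compat helpers added above. The token counts are made-up values for illustration, and the sketch requires an environment with this patch applied.

    from llama_stack.apis.inference import UsageStatistics
    from llama_stack.providers.utils.inference.openai_compat import (
        OpenAICompatCompletionChoice,
        OpenAICompatCompletionResponse,
        OpenAICompatCompletionUsage,
        get_usage_statistics,
    )

    # A final streaming chunk (or a non-streaming response) from an
    # OpenAI-compatible provider that reports token counts.
    response = OpenAICompatCompletionResponse(
        choices=[
            OpenAICompatCompletionChoice(finish_reason="stop", text="Hello<|eot_id|>")
        ],
        usage=OpenAICompatCompletionUsage(
            prompt_tokens=12, completion_tokens=3, total_tokens=15
        ),
    )

    # get_usage_statistics maps the provider-level usage block onto the new
    # API model; providers that report no usage simply yield None, and the
    # optional usage field on the response types stays unset.
    usage = get_usage_statistics(response)
    assert usage == UsageStatistics(prompt_tokens=12, completion_tokens=3, total_tokens=15)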