add usage statistics for inference API

Dinesh Yeduguru 2025-01-28 15:26:19 -08:00
parent 9f709387e2
commit 6609362d26
4 changed files with 75 additions and 8 deletions

View file

@@ -186,6 +186,13 @@ ResponseFormat = register_schema(
)


+@json_schema_type
+class UsageStatistics(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int


@json_schema_type
class CompletionRequest(BaseModel):
    model: str

@@ -204,6 +211,7 @@ class CompletionResponse(BaseModel):
    content: str
    stop_reason: StopReason
    logprobs: Optional[List[TokenLogProbs]] = None
+    usage: Optional[UsageStatistics] = None


@json_schema_type

@@ -213,6 +221,7 @@ class CompletionResponseStreamChunk(BaseModel):
    delta: str
    stop_reason: Optional[StopReason] = None
    logprobs: Optional[List[TokenLogProbs]] = None
+    usage: Optional[UsageStatistics] = None


@json_schema_type

@@ -252,6 +261,7 @@ class ChatCompletionResponseStreamChunk(BaseModel):
    """SSE-stream of these events."""

    event: ChatCompletionResponseEvent
+    usage: Optional[UsageStatistics] = None


@json_schema_type

@@ -260,6 +270,7 @@ class ChatCompletionResponse(BaseModel):
    completion_message: CompletionMessage
    logprobs: Optional[List[TokenLogProbs]] = None
+    usage: Optional[UsageStatistics] = None


@json_schema_type
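The schema changes above add an optional `usage` field to both the streaming and non-streaming completion and chat-completion response models. A minimal sketch of how a caller might read it once a provider populates it; the `client` object and model id below are placeholders for illustration, not part of this commit:

# Hypothetical client call; only the response models come from this commit.
response = client.inference.chat_completion(
    model_id="my-model",  # placeholder
    messages=[{"role": "user", "content": "Hello"}],
)

# `usage` is Optional[UsageStatistics]; providers that do not report it leave it as None.
if response.usage is not None:
    print(
        f"prompt={response.usage.prompt_tokens} "
        f"completion={response.usage.completion_tokens} "
        f"total={response.usage.total_tokens}"
    )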

View file

@@ -37,7 +37,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
    CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from pydantic import BaseModel

@@ -47,7 +46,6 @@ from llama_stack.apis.inference import (
    ResponseFormat,
    ResponseFormatType,
)
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
    ChatCompletionRequestWithRawContent,

@@ -78,6 +76,7 @@ class TokenResult(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None
+    input_token_count: Optional[int] = None


class Llama:

@@ -348,6 +347,7 @@ class Llama:
                    if logprobs
                    else None
                ),
+                input_token_count=len(model_input.tokens),
            )

            prev_pos = cur_pos
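The generator change above makes each yielded TokenResult optionally carry `input_token_count`, taken from the length of the already-encoded prompt (`len(model_input.tokens)`), so downstream code does not need to re-tokenize the prompt to count it. A rough, illustrative equivalent, assuming a Llama-style tokenizer with an `encode(text, bos=..., eos=...)` method; the `tokenizer` and `prompt` names are placeholders, not the actual internals:

# Illustrative only: in the commit the value comes from len(model_input.tokens).
prompt = "Write a haiku about the sea."
prompt_tokens = tokenizer.encode(prompt, bos=True, eos=False)
input_token_count = len(prompt_tokens)  # what TokenResult.input_token_count reports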

View file

@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
    ResponseFormat,
    TokenLogProbs,
    ToolChoice,
+    UsageStatistics,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -168,8 +169,14 @@ class MetaReferenceInferenceImpl(
    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        def impl():
            stop_reason = None
+            input_token_count = 0
+            output_token_count = 0
+            usage_statistics = None

            for token_result in self.generator.completion(request):
+                if input_token_count == 0:
+                    input_token_count = token_result.input_token_count
+                output_token_count += 1  # one generated token per TokenResult
                if token_result.text == "<|eot_id|>":
                    stop_reason = StopReason.end_of_turn
                    text = ""
@@ -191,17 +198,29 @@ class MetaReferenceInferenceImpl(
                                }
                            )
                        ]
+                else:
+                    usage_statistics = UsageStatistics(
+                        prompt_tokens=input_token_count,
+                        completion_tokens=output_token_count,
+                        total_tokens=input_token_count + output_token_count,
+                    )

                yield CompletionResponseStreamChunk(
                    delta=text,
                    stop_reason=stop_reason,
                    logprobs=logprobs if request.logprobs else None,
+                    usage=usage_statistics,
                )

            if stop_reason is None:
                yield CompletionResponseStreamChunk(
                    delta="",
                    stop_reason=StopReason.out_of_tokens,
+                    usage=UsageStatistics(
+                        prompt_tokens=input_token_count,
+                        completion_tokens=output_token_count,
+                        total_tokens=input_token_count + output_token_count,
+                    ),
                )

        if self.config.create_distributed_process_group:
@@ -221,7 +240,10 @@ class MetaReferenceInferenceImpl(
            stop_reason = None

            tokenizer = self.generator.formatter.tokenizer
+            input_token_count = 0
            for token_result in self.generator.completion(request):
+                if input_token_count == 0:
+                    input_token_count = token_result.input_token_count
                tokens.append(token_result.token)
                if token_result.text == "<|eot_id|>":
                    stop_reason = StopReason.end_of_turn

@@ -242,7 +264,7 @@ class MetaReferenceInferenceImpl(
            if stop_reason is None:
                stop_reason = StopReason.out_of_tokens

-            content = self.generator.formatter.tokenizer.decode(tokens)
+            content = tokenizer.decode(tokens)
            if content.endswith("<|eot_id|>"):
                content = content[: -len("<|eot_id|>")]
            elif content.endswith("<|eom_id|>"):
@@ -251,6 +273,11 @@ class MetaReferenceInferenceImpl(
                content=content,
                stop_reason=stop_reason,
                logprobs=logprobs if request.logprobs else None,
+                usage=UsageStatistics(
+                    prompt_tokens=input_token_count,
+                    completion_tokens=len(tokens),
+                    total_tokens=input_token_count + len(tokens),
+                ),
            )

        if self.config.create_distributed_process_group:
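With the meta-reference changes above, usage is attached once generation finishes: on the chunk that carries the stop token (or the synthetic out-of-tokens chunk) when streaming, and on the final CompletionResponse otherwise. A minimal consumer sketch; `collect` is a hypothetical helper, and `stream` is assumed to be an async iterator of CompletionResponseStreamChunk objects obtained elsewhere:

async def collect(stream):
    # Earlier chunks have usage=None; only the final chunk(s) carry statistics.
    text_parts = []
    usage = None
    async for chunk in stream:
        text_parts.append(chunk.delta)
        if chunk.usage is not None:
            usage = chunk.usage
    return "".join(text_parts), usage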

View file

@@ -12,7 +12,6 @@ from llama_models.datatypes import (
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from pydantic import BaseModel

@@ -24,7 +23,6 @@ from llama_stack.apis.common.content_types import (
    ToolCallDelta,
    ToolCallParseStatus,
)
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEvent,

@@ -35,8 +33,8 @@ from llama_stack.apis.inference import (
    CompletionResponseStreamChunk,
    Message,
    TokenLogProbs,
+    UsageStatistics,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_image_content_to_url,
)
@@ -63,8 +61,15 @@ class OpenAICompatCompletionChoice(BaseModel):
    logprobs: Optional[OpenAICompatLogprobs] = None


+class OpenAICompatCompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int


class OpenAICompatCompletionResponse(BaseModel):
    choices: List[OpenAICompatCompletionChoice]
+    usage: Optional[OpenAICompatCompletionUsage] = None


def get_sampling_strategy_options(params: SamplingParams) -> dict:
@@ -124,16 +129,31 @@ def convert_openai_completion_logprobs(
    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]


+def get_usage_statistics(
+    response: OpenAICompatCompletionResponse,
+) -> Optional[UsageStatistics]:
+    if response.usage:
+        return UsageStatistics(
+            prompt_tokens=response.usage.prompt_tokens,
+            completion_tokens=response.usage.completion_tokens,
+            total_tokens=response.usage.total_tokens,
+        )
+    return None


def process_completion_response(
    response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse:
    choice = response.choices[0]
+    usage_statistics = get_usage_statistics(response)
    # drop suffix <eot_id> if present and return stop reason as end of turn
    if choice.text.endswith("<|eot_id|>"):
        return CompletionResponse(
            stop_reason=StopReason.end_of_turn,
            content=choice.text[: -len("<|eot_id|>")],
            logprobs=convert_openai_completion_logprobs(choice.logprobs),
+            usage=usage_statistics,
        )
    # drop suffix <eom_id> if present and return stop reason as end of message
    if choice.text.endswith("<|eom_id|>"):

@@ -141,11 +161,13 @@ process_completion_response(
            stop_reason=StopReason.end_of_message,
            content=choice.text[: -len("<|eom_id|>")],
            logprobs=convert_openai_completion_logprobs(choice.logprobs),
+            usage=usage_statistics,
        )
    return CompletionResponse(
        stop_reason=get_stop_reason(choice.finish_reason),
        content=choice.text,
        logprobs=convert_openai_completion_logprobs(choice.logprobs),
+        usage=usage_statistics,
    )
@@ -164,6 +186,7 @@ def process_chat_completion_response(
            tool_calls=raw_message.tool_calls,
        ),
        logprobs=None,
+        usage=get_usage_statistics(response),
    )


@@ -171,10 +194,13 @@ async def process_completion_stream_response(
    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
    stop_reason = None
+    usage_statistics = None

    async for chunk in stream:
        choice = chunk.choices[0]
        finish_reason = choice.finish_reason
+        # usage statistics are only available in the final chunk
+        usage_statistics = get_usage_statistics(chunk)

        text = text_from_choice(choice)
        if text == "<|eot_id|>":
@@ -200,6 +226,7 @@ async def process_completion_stream_response(
    yield CompletionResponseStreamChunk(
        delta="",
        stop_reason=stop_reason,
+        usage=usage_statistics,
    )


@@ -216,10 +243,11 @@ async def process_chat_completion_stream_response(
    buffer = ""
    ipython = False
    stop_reason = None
+    usage_statistics = None

    async for chunk in stream:
        choice = chunk.choices[0]
        finish_reason = choice.finish_reason
+        usage_statistics = get_usage_statistics(chunk)

        if finish_reason:
            if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
@@ -313,7 +341,8 @@ async def process_chat_completion_stream_response(
            event_type=ChatCompletionResponseEventType.complete,
            delta=TextDelta(text=""),
            stop_reason=stop_reason,
-        )
+        ),
+        usage=usage_statistics,
    )
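For OpenAI-compatible providers, `get_usage_statistics` simply copies the provider's usage block when one is present; for streaming APIs that block typically arrives only on the final chunk, which is why the stream processors above overwrite `usage_statistics` on every chunk and yield it at the end. A small sketch of the mapping with made-up numbers, assuming the module's own names are in scope and that OpenAICompatCompletionChoice accepts `finish_reason` and `text` keyword arguments, as the processing code above implies:

# Made-up values to show the mapping performed by get_usage_statistics().
compat_response = OpenAICompatCompletionResponse(
    choices=[OpenAICompatCompletionChoice(finish_reason="stop", text="Hi there!")],
    usage=OpenAICompatCompletionUsage(
        prompt_tokens=12, completion_tokens=4, total_tokens=16
    ),
)

stats = get_usage_statistics(compat_response)
assert stats is not None and stats.total_tokens == 16

# A response without a usage block maps to None rather than zeros.
assert get_usage_statistics(
    OpenAICompatCompletionResponse(choices=compat_response.choices)
) is None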