mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-07 02:58:21 +00:00

add usage statistics for inference API

This commit is contained in:
parent 9f709387e2
commit 6609362d26

4 changed files with 75 additions and 8 deletions
@@ -186,6 +186,13 @@ ResponseFormat = register_schema(
 )
 
 
+@json_schema_type
+class UsageStatistics(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
 @json_schema_type
 class CompletionRequest(BaseModel):
     model: str
@@ -204,6 +211,7 @@ class CompletionResponse(BaseModel):
     content: str
     stop_reason: StopReason
     logprobs: Optional[List[TokenLogProbs]] = None
+    usage: Optional[UsageStatistics] = None
 
 
 @json_schema_type
@@ -213,6 +221,7 @@ class CompletionResponseStreamChunk(BaseModel):
     delta: str
     stop_reason: Optional[StopReason] = None
     logprobs: Optional[List[TokenLogProbs]] = None
+    usage: Optional[UsageStatistics] = None
 
 
 @json_schema_type
@@ -252,6 +261,7 @@ class ChatCompletionResponseStreamChunk(BaseModel):
     """SSE-stream of these events."""
 
     event: ChatCompletionResponseEvent
+    usage: Optional[UsageStatistics] = None
 
 
 @json_schema_type
@@ -260,6 +270,7 @@ class ChatCompletionResponse(BaseModel):
 
     completion_message: CompletionMessage
     logprobs: Optional[List[TokenLogProbs]] = None
+    usage: Optional[UsageStatistics] = None
 
 
 @json_schema_type
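The hunks above add a UsageStatistics model and an optional usage field to every completion and chat-completion response type. A minimal, self-contained sketch of how a caller might read that field, assuming only the field names shown in the diff; the response values below are invented for illustration.

# Illustration only: a stand-in UsageStatistics mirroring the model added above,
# plus a helper that reads the optional usage field off a response.
from typing import Optional

from pydantic import BaseModel


class UsageStatistics(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


def summarize_usage(usage: Optional[UsageStatistics]) -> str:
    # usage is optional on every response type, so callers must handle None.
    if usage is None:
        return "usage not reported"
    return (
        f"prompt={usage.prompt_tokens} "
        f"completion={usage.completion_tokens} "
        f"total={usage.total_tokens}"
    )


# Invented numbers, just to exercise the helper.
print(summarize_usage(UsageStatistics(prompt_tokens=12, completion_tokens=34, total_tokens=46)))
print(summarize_usage(None))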
@@ -37,7 +37,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
 from llama_models.sku_list import resolve_model
-
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from pydantic import BaseModel
 
@@ -47,7 +46,6 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     ResponseFormatType,
 )
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
@@ -78,6 +76,7 @@ class TokenResult(BaseModel):
     token: int
     text: str
     logprobs: Optional[List[float]] = None
+    input_token_count: Optional[int] = None
 
 
 class Llama:
@@ -348,6 +347,7 @@ class Llama:
                         if logprobs
                         else None
                     ),
+                    input_token_count=len(model_input.tokens),
                 )
 
             prev_pos = cur_pos
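The TokenResult change above is what lets the provider compute prompt-side counts without re-tokenizing: the generator already knows the prompt length and can stamp it on every token it yields. A toy sketch of that idea, with an invented generator standing in for the real Llama generation loop.

# Illustrative only: a toy generator that mimics how input_token_count can be
# attached to each yielded TokenResult (field names follow the diff, the token
# loop itself is made up).
from typing import Iterator, List, Optional

from pydantic import BaseModel


class TokenResult(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None
    input_token_count: Optional[int] = None


def toy_generate(prompt_tokens: List[int], completion_tokens: List[int]) -> Iterator[TokenResult]:
    for tok in completion_tokens:
        yield TokenResult(
            token=tok,
            text=f"<{tok}>",
            # The prompt length is known up front, so every result carries it.
            input_token_count=len(prompt_tokens),
        )


results = list(toy_generate([1, 2, 3], [7, 8]))
assert results[0].input_token_count == 3
print(results)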
@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     TokenLogProbs,
     ToolChoice,
+    UsageStatistics,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -168,8 +169,14 @@ class MetaReferenceInferenceImpl(
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         def impl():
             stop_reason = None
+            input_token_count = 0
+            output_token_count = 0
+            usage_statistics = None
 
             for token_result in self.generator.completion(request):
+                if input_token_count == 0:
+                    input_token_count = token_result.input_token_count
+                output_token_count += len(token_result.token)
                 if token_result.text == "<|eot_id|>":
                     stop_reason = StopReason.end_of_turn
                     text = ""
@@ -191,17 +198,29 @@ class MetaReferenceInferenceImpl(
                                 }
                             )
                         ]
+                else:
+                    usage_statistics = UsageStatistics(
+                        prompt_tokens=input_token_count,
+                        completion_tokens=output_token_count,
+                        total_tokens=input_token_count + output_token_count,
+                    )
 
                 yield CompletionResponseStreamChunk(
                     delta=text,
                     stop_reason=stop_reason,
                     logprobs=logprobs if request.logprobs else None,
+                    usage=usage_statistics,
                 )
 
             if stop_reason is None:
                 yield CompletionResponseStreamChunk(
                     delta="",
                     stop_reason=StopReason.out_of_tokens,
+                    usage=UsageStatistics(
+                        prompt_tokens=input_token_count,
+                        completion_tokens=output_token_count,
+                        total_tokens=input_token_count + output_token_count,
+                    ),
                 )
 
         if self.config.create_distributed_process_group:
@@ -221,7 +240,10 @@ class MetaReferenceInferenceImpl(
             stop_reason = None
 
             tokenizer = self.generator.formatter.tokenizer
+            input_token_count = 0
             for token_result in self.generator.completion(request):
+                if input_token_count == 0:
+                    input_token_count = token_result.input_token_count
                 tokens.append(token_result.token)
                 if token_result.text == "<|eot_id|>":
                     stop_reason = StopReason.end_of_turn
@@ -242,7 +264,7 @@ class MetaReferenceInferenceImpl(
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens
 
-            content = self.generator.formatter.tokenizer.decode(tokens)
+            content = tokenizer.decode(tokens)
             if content.endswith("<|eot_id|>"):
                 content = content[: -len("<|eot_id|>")]
             elif content.endswith("<|eom_id|>"):
@@ -251,6 +273,11 @@ class MetaReferenceInferenceImpl(
                 content=content,
                 stop_reason=stop_reason,
                 logprobs=logprobs if request.logprobs else None,
+                usage_statistics=UsageStatistics(
+                    prompt_tokens=input_token_count,
+                    completion_tokens=len(tokens),
+                    total_tokens=input_token_count + len(tokens),
+                ),
             )
 
         if self.config.create_distributed_process_group:
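On the streaming path above, usage is only attached once a stop reason is known (or on the final out-of-tokens chunk), so a client should keep the last non-None value it sees. A self-contained consumer sketch follows; the fake stream is an assumption standing in for the real CompletionResponseStreamChunk generator.

# Consumer-side sketch (assumed, not part of the diff): collect streamed deltas
# and keep the most recent usage value, since only terminal chunks carry it.
import asyncio
from typing import AsyncIterator, Optional

from pydantic import BaseModel


class UsageStatistics(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class StreamChunk(BaseModel):  # simplified stand-in for CompletionResponseStreamChunk
    delta: str
    usage: Optional[UsageStatistics] = None


async def fake_stream() -> AsyncIterator[StreamChunk]:
    # A made-up stream: intermediate chunks report no usage, the last one does.
    yield StreamChunk(delta="Hello")
    yield StreamChunk(delta=", world")
    yield StreamChunk(
        delta="",
        usage=UsageStatistics(prompt_tokens=5, completion_tokens=2, total_tokens=7),
    )


async def main() -> None:
    text, usage = [], None
    async for chunk in fake_stream():
        text.append(chunk.delta)
        if chunk.usage is not None:
            usage = chunk.usage
    print("".join(text), usage)


asyncio.run(main())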
@@ -12,7 +12,6 @@ from llama_models.datatypes import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from pydantic import BaseModel
@@ -24,7 +23,6 @@ from llama_stack.apis.common.content_types import (
     ToolCallDelta,
     ToolCallParseStatus,
 )
-
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
@@ -35,8 +33,8 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     Message,
     TokenLogProbs,
+    UsageStatistics,
 )
-
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )
@@ -63,8 +61,15 @@ class OpenAICompatCompletionChoice(BaseModel):
     logprobs: Optional[OpenAICompatLogprobs] = None
 
 
+class OpenAICompatCompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
 class OpenAICompatCompletionResponse(BaseModel):
     choices: List[OpenAICompatCompletionChoice]
+    usage: Optional[OpenAICompatCompletionUsage] = None
 
 
 def get_sampling_strategy_options(params: SamplingParams) -> dict:
@@ -124,16 +129,31 @@ def convert_openai_completion_logprobs(
     return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
 
 
+def get_usage_statistics(
+    response: OpenAICompatCompletionResponse,
+) -> Optional[UsageStatistics]:
+    if response.usage:
+        return UsageStatistics(
+            prompt_tokens=response.usage.prompt_tokens,
+            completion_tokens=response.usage.completion_tokens,
+            total_tokens=response.usage.total_tokens,
+        )
+    return None
+
+
 def process_completion_response(
     response: OpenAICompatCompletionResponse, formatter: ChatFormat
 ) -> CompletionResponse:
     choice = response.choices[0]
+    usage_statistics = get_usage_statistics(response)
+
     # drop suffix <eot_id> if present and return stop reason as end of turn
     if choice.text.endswith("<|eot_id|>"):
         return CompletionResponse(
             stop_reason=StopReason.end_of_turn,
             content=choice.text[: -len("<|eot_id|>")],
             logprobs=convert_openai_completion_logprobs(choice.logprobs),
+            usage=usage_statistics,
         )
     # drop suffix <eom_id> if present and return stop reason as end of message
    if choice.text.endswith("<|eom_id|>"):
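The get_usage_statistics helper above is a straight field-for-field copy from the OpenAI-compatible usage block into UsageStatistics, returning None when a provider does not report usage. A runnable illustration using simplified stand-ins for the models in this diff; the sample token counts are invented.

# Simplified stand-ins so the conversion can run on its own; the real models
# carry additional fields not shown here.
from typing import List, Optional

from pydantic import BaseModel


class UsageStatistics(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class OpenAICompatCompletionChoice(BaseModel):
    finish_reason: Optional[str] = None
    text: str = ""


class OpenAICompatCompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class OpenAICompatCompletionResponse(BaseModel):
    choices: List[OpenAICompatCompletionChoice]
    usage: Optional[OpenAICompatCompletionUsage] = None


def get_usage_statistics(response: OpenAICompatCompletionResponse) -> Optional[UsageStatistics]:
    # Same shape as the helper added in the diff: copy the three counters over.
    if response.usage:
        return UsageStatistics(
            prompt_tokens=response.usage.prompt_tokens,
            completion_tokens=response.usage.completion_tokens,
            total_tokens=response.usage.total_tokens,
        )
    return None


resp = OpenAICompatCompletionResponse(
    choices=[OpenAICompatCompletionChoice(finish_reason="stop", text="hi")],
    usage=OpenAICompatCompletionUsage(prompt_tokens=10, completion_tokens=1, total_tokens=11),
)
print(get_usage_statistics(resp))
print(get_usage_statistics(OpenAICompatCompletionResponse(choices=resp.choices)))  # None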
@@ -141,11 +161,13 @@ def process_completion_response(
             stop_reason=StopReason.end_of_message,
             content=choice.text[: -len("<|eom_id|>")],
             logprobs=convert_openai_completion_logprobs(choice.logprobs),
+            usage=usage_statistics,
         )
     return CompletionResponse(
         stop_reason=get_stop_reason(choice.finish_reason),
         content=choice.text,
         logprobs=convert_openai_completion_logprobs(choice.logprobs),
+        usage=usage_statistics,
     )
 
 
@@ -164,6 +186,7 @@ def process_chat_completion_response(
             tool_calls=raw_message.tool_calls,
         ),
         logprobs=None,
+        usage=get_usage_statistics(response),
     )
 
 
@@ -171,10 +194,13 @@ async def process_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
 ) -> AsyncGenerator:
     stop_reason = None
+    usage_statistics = None
 
     async for chunk in stream:
         choice = chunk.choices[0]
         finish_reason = choice.finish_reason
+        # usage statistics are only available in the final chunk
+        usage_statistics = get_usage_statistics(chunk)
 
         text = text_from_choice(choice)
         if text == "<|eot_id|>":
@@ -200,6 +226,7 @@ async def process_completion_stream_response(
     yield CompletionResponseStreamChunk(
         delta="",
         stop_reason=stop_reason,
+        usage=usage_statistics,
    )
 
 
@@ -216,10 +243,11 @@ async def process_chat_completion_stream_response(
     buffer = ""
     ipython = False
     stop_reason = None
-
+    usage_statistics = None
     async for chunk in stream:
         choice = chunk.choices[0]
         finish_reason = choice.finish_reason
+        usage_statistics = get_usage_statistics(chunk)
 
         if finish_reason:
             if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
@@ -313,7 +341,8 @@ async def process_chat_completion_stream_response(
             event_type=ChatCompletionResponseEventType.complete,
             delta=TextDelta(text=""),
             stop_reason=stop_reason,
-        )
+        ),
+        usage=usage_statistics,
     )
 
 
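For the chat streaming path, the final hunk attaches usage to the terminal chunk that carries the complete event, so a client can read it once the stream finishes. A short consumer sketch with simplified stand-ins for the event and chunk types; the stream contents and counts are invented.

# Simplified stand-ins; only the fields needed to show where usage lands on the
# terminal chunk of a chat-completion stream.
import asyncio
from enum import Enum
from typing import AsyncIterator, Optional

from pydantic import BaseModel


class UsageStatistics(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class EventType(str, Enum):  # stand-in for ChatCompletionResponseEventType
    progress = "progress"
    complete = "complete"


class ChatEvent(BaseModel):  # stand-in for ChatCompletionResponseEvent
    event_type: EventType
    text: str = ""


class ChatStreamChunk(BaseModel):  # stand-in for ChatCompletionResponseStreamChunk
    event: ChatEvent
    usage: Optional[UsageStatistics] = None


async def fake_chat_stream() -> AsyncIterator[ChatStreamChunk]:
    yield ChatStreamChunk(event=ChatEvent(event_type=EventType.progress, text="The answer is 42."))
    yield ChatStreamChunk(
        event=ChatEvent(event_type=EventType.complete),
        usage=UsageStatistics(prompt_tokens=20, completion_tokens=6, total_tokens=26),
    )


async def main() -> None:
    usage = None
    async for chunk in fake_chat_stream():
        if chunk.event.event_type == EventType.complete:
            usage = chunk.usage
    print(usage)


asyncio.run(main())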