Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)
Add inference token usage metrics
commit a72cdafac0
parent 0762c61402
3 changed files with 57 additions and 7 deletions
@@ -17,12 +17,13 @@ from typing import (
     runtime_checkable,
 )
 
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
-from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
+from llama_stack.apis.telemetry.telemetry import MetricsMixin
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     SamplingParams,
@@ -285,7 +286,7 @@ class CompletionRequest(BaseModel):
 
 
 @json_schema_type
-class CompletionResponse(BaseModel):
+class CompletionResponse(MetricsMixin, BaseModel):
     """Response from a completion request.
 
     :param content: The generated completion text
@@ -299,7 +300,7 @@ class CompletionResponse(BaseModel):
 
 
 @json_schema_type
-class CompletionResponseStreamChunk(BaseModel):
+class CompletionResponseStreamChunk(MetricsMixin, BaseModel):
     """A chunk of a streamed completion response.
 
     :param delta: New content generated since last chunk. This can be one or more tokens.
@@ -368,7 +369,7 @@ class ChatCompletionRequest(BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
+class ChatCompletionResponseStreamChunk(MetricsMixin, BaseModel):
     """A chunk of a streamed chat completion response.
 
     :param event: The event containing the new content
@@ -378,7 +379,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponse(MetricResponseMixin, BaseModel):
+class ChatCompletionResponse(MetricsMixin, BaseModel):
     """Response from a chat completion request.
 
     :param completion_message: The complete response message
@@ -390,7 +391,7 @@ class ChatCompletionResponse(MetricResponseMixin, BaseModel):
 
 
 @json_schema_type
-class EmbeddingsResponse(BaseModel):
+class EmbeddingsResponse(MetricsMixin, BaseModel):
     """Response containing generated embeddings.
 
     :param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
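Illustrative aside (not part of the commit): with the response classes above inheriting MetricsMixin, a caller can read token usage straight off a response. The sketch below assumes llama-stack at this commit is installed; the `metrics` field, the `TokenUsage` model, and the `llama_stack.apis.telemetry` import path are taken from this diff, while the helper name is made up for illustration.

```python
from typing import Optional

# Import path as used by the router change later in this commit.
from llama_stack.apis.telemetry import TokenUsage


def token_usage_of(response) -> Optional[TokenUsage]:
    """Return the TokenUsage metric from any MetricsMixin response, if present.

    `metrics` defaults to an empty list, so this is safe even when a provider
    does not populate it.
    """
    for metric in getattr(response, "metrics", []) or []:
        if isinstance(metric, TokenUsage):
            return metric
    return None
```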
@@ -211,6 +211,28 @@ class QuerySpanTreeResponse(BaseModel):
     data: Dict[str, SpanWithStatus]
 
 
+@json_schema_type
+class TokenUsage(BaseModel):
+    type: Literal["token_usage"] = "token_usage"
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+Metric = register_schema(
+    Annotated[
+        Union[TokenUsage],
+        Field(discriminator="type"),
+    ],
+    name="Metric",
+)
+
+
+@json_schema_type
+class MetricsMixin(BaseModel):
+    metrics: List[Metric] = Field(default_factory=list)
+
+
 @runtime_checkable
 class Telemetry(Protocol):
     @webmethod(route="/telemetry/events", method="POST")
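To make the new telemetry types concrete, here is a minimal, self-contained sketch of the same structure in plain pydantic (v2 assumed). The `json_schema_type`/`register_schema` wrappers from `llama_models.schema_utils` are left out, and `LatencyMetric` is a hypothetical second variant included only so the discriminated union has more than one member; `TokenUsage` and `MetricsMixin` mirror the hunk above.

```python
from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field


class TokenUsage(BaseModel):
    # Mirrors the TokenUsage model added in this commit.
    type: Literal["token_usage"] = "token_usage"
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class LatencyMetric(BaseModel):
    # Hypothetical second variant, only here to show how the union grows.
    type: Literal["latency"] = "latency"
    milliseconds: float


# The "type" field selects which model each metric dict is parsed into.
Metric = Annotated[Union[TokenUsage, LatencyMetric], Field(discriminator="type")]


class MetricsMixin(BaseModel):
    # Response classes gain this mixin so every response can carry metrics.
    metrics: List[Metric] = Field(default_factory=list)


parsed = MetricsMixin.model_validate(
    {
        "metrics": [
            {"type": "token_usage", "prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}
        ]
    }
)
assert isinstance(parsed.metrics[0], TokenUsage)
```

Because `Metric` is an annotated union discriminated on `type`, additional metric kinds can be added later without changing `MetricsMixin` or the response classes that inherit it.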
@@ -6,6 +6,9 @@
 
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
+
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -42,6 +45,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.apis.telemetry import TokenUsage
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -111,6 +115,8 @@ class InferenceRouter(Inference):
         routing_table: RoutingTable,
     ) -> None:
         self.routing_table = routing_table
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)
 
     async def initialize(self) -> None:
         pass
@@ -190,7 +196,28 @@ class InferenceRouter(Inference):
         if stream:
             return (chunk async for chunk in await provider.chat_completion(**params))
         else:
-            return await provider.chat_completion(**params)
+            response = await provider.chat_completion(**params)
+            model_input = self.formatter.encode_dialog_prompt(
+                messages,
+                tool_config.tool_prompt_format,
+            )
+            model_output = self.formatter.encode_dialog_prompt(
+                [response.completion_message],
+                tool_config.tool_prompt_format,
+            )
+            prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
+            completion_tokens = len(model_output.tokens) if model_output.tokens else 0
+            total_tokens = prompt_tokens + completion_tokens
+            if response.metrics is None:
+                response.metrics = []
+            response.metrics.append(
+                TokenUsage(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=total_tokens,
+                )
+            )
+            return response
 
     async def completion(
         self,
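The router change above only touches the non-streaming branch: streamed chunks are passed through unchanged, while a non-streaming response has its dialog re-encoded with the Llama 3 `ChatFormat` so prompt and completion tokens can be counted. Below is a sketch of that counting step in isolation, using only calls that appear in this diff (`Tokenizer.get_instance`, `ChatFormat`, `encode_dialog_prompt`, `TokenUsage`); the standalone helper packaging and its name are illustrative, not part of the commit.

```python
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.telemetry import TokenUsage


def count_token_usage(messages, completion_message, tool_prompt_format) -> TokenUsage:
    # Same encoder the router builds in __init__: the shared tokenizer instance
    # wrapped in the Llama 3 chat format.
    formatter = ChatFormat(Tokenizer.get_instance())

    # Re-encode what went into the model and what came out, then count tokens.
    model_input = formatter.encode_dialog_prompt(messages, tool_prompt_format)
    model_output = formatter.encode_dialog_prompt([completion_message], tool_prompt_format)

    prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
    completion_tokens = len(model_output.tokens) if model_output.tokens else 0
    return TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
```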