Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)

Add inference token usage metrics

commit a72cdafac0 (parent 0762c61402)
3 changed files with 57 additions and 7 deletions
@@ -17,12 +17,13 @@ from typing import (
     runtime_checkable,
 )
 
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
-from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
+from llama_stack.apis.telemetry.telemetry import MetricsMixin
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     SamplingParams,
@@ -285,7 +286,7 @@ class CompletionRequest(BaseModel):
 
 
 @json_schema_type
-class CompletionResponse(BaseModel):
+class CompletionResponse(MetricsMixin, BaseModel):
     """Response from a completion request.
 
     :param content: The generated completion text
@@ -299,7 +300,7 @@ class CompletionResponse(BaseModel):
 
 
 @json_schema_type
-class CompletionResponseStreamChunk(BaseModel):
+class CompletionResponseStreamChunk(MetricsMixin, BaseModel):
     """A chunk of a streamed completion response.
 
     :param delta: New content generated since last chunk. This can be one or more tokens.
@@ -368,7 +369,7 @@ class ChatCompletionRequest(BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
+class ChatCompletionResponseStreamChunk(MetricsMixin, BaseModel):
     """A chunk of a streamed chat completion response.
 
     :param event: The event containing the new content
@@ -378,7 +379,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponse(MetricResponseMixin, BaseModel):
+class ChatCompletionResponse(MetricsMixin, BaseModel):
     """Response from a chat completion request.
 
     :param completion_message: The complete response message
@@ -390,7 +391,7 @@ class ChatCompletionResponse(MetricResponseMixin, BaseModel):
 
 
 @json_schema_type
-class EmbeddingsResponse(BaseModel):
+class EmbeddingsResponse(MetricsMixin, BaseModel):
     """Response containing generated embeddings.
 
     :param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
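The hunks above only swap the base class of each response model; the practical effect is that every inference response now carries a metrics list that defaults to empty and serializes alongside the regular fields. A minimal, self-contained sketch of that pattern (pydantic v2 assumed; field types are simplified and `content` is an illustrative stand-in, not the real model fields):

from typing import List
from pydantic import BaseModel, Field

class MetricsMixin(BaseModel):
    # The real mixin (next hunk) types this as List[Metric]; dict is a simplification.
    metrics: List[dict] = Field(default_factory=list)

class CompletionResponse(MetricsMixin, BaseModel):
    content: str  # illustrative field only

resp = CompletionResponse(content="hello")
print(resp.model_dump())  # includes "metrics": [] alongside "content"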
@@ -211,6 +211,28 @@ class QuerySpanTreeResponse(BaseModel):
     data: Dict[str, SpanWithStatus]
 
 
+@json_schema_type
+class TokenUsage(BaseModel):
+    type: Literal["token_usage"] = "token_usage"
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+Metric = register_schema(
+    Annotated[
+        Union[TokenUsage],
+        Field(discriminator="type"),
+    ],
+    name="Metric",
+)
+
+
+@json_schema_type
+class MetricsMixin(BaseModel):
+    metrics: List[Metric] = Field(default_factory=list)
+
+
 @runtime_checkable
 class Telemetry(Protocol):
     @webmethod(route="/telemetry/events", method="POST")
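The new TokenUsage model carries a type discriminator so that Metric can stay an extensible tagged union: each concrete metric declares a Literal type tag, and Field(discriminator="type") lets pydantic pick the right class when parsing a serialized metrics list. A hedged, standalone sketch of that mechanism (register_schema omitted; LatencyMetric is a hypothetical second member added only so the Union has more than one arm; pydantic v2 assumed):

from typing import List, Literal, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated

class TokenUsage(BaseModel):
    type: Literal["token_usage"] = "token_usage"
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

class LatencyMetric(BaseModel):  # hypothetical metric, for illustration only
    type: Literal["latency"] = "latency"
    seconds: float

Metric = Annotated[Union[TokenUsage, LatencyMetric], Field(discriminator="type")]

class MetricsMixin(BaseModel):
    metrics: List[Metric] = Field(default_factory=list)

payload = {
    "metrics": [
        {"type": "token_usage", "prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}
    ]
}
parsed = MetricsMixin.model_validate(payload)
assert isinstance(parsed.metrics[0], TokenUsage)  # dispatched via the "type" tag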
@@ -6,6 +6,9 @@
 
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
+
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -42,6 +45,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.apis.telemetry import TokenUsage
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -111,6 +115,8 @@ class InferenceRouter(Inference):
         routing_table: RoutingTable,
     ) -> None:
         self.routing_table = routing_table
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)
 
     async def initialize(self) -> None:
         pass
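Caching a Tokenizer and ChatFormat on the router is what lets it count tokens locally instead of relying on each provider to report usage; the next hunk does the actual counting by encoding a dialog and taking the length of the token list. A rough standalone sketch of that idea, assuming llama_models is installed and using only the calls that appear in this diff (count_dialog_tokens is an illustrative helper name):

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)

def count_dialog_tokens(messages, tool_prompt_format=None) -> int:
    # Encode the dialog the way a Llama 3 prompt would be built, then count the tokens.
    encoded = formatter.encode_dialog_prompt(messages, tool_prompt_format)
    return len(encoded.tokens) if encoded.tokens else 0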
@@ -190,7 +196,28 @@ class InferenceRouter(Inference):
         if stream:
             return (chunk async for chunk in await provider.chat_completion(**params))
         else:
-            return await provider.chat_completion(**params)
+            response = await provider.chat_completion(**params)
+            model_input = self.formatter.encode_dialog_prompt(
+                messages,
+                tool_config.tool_prompt_format,
+            )
+            model_output = self.formatter.encode_dialog_prompt(
+                [response.completion_message],
+                tool_config.tool_prompt_format,
+            )
+            prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
+            completion_tokens = len(model_output.tokens) if model_output.tokens else 0
+            total_tokens = prompt_tokens + completion_tokens
+            if response.metrics is None:
+                response.metrics = []
+            response.metrics.append(
+                TokenUsage(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=total_tokens,
+                )
+            )
+            return response
 
     async def completion(
         self,
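With this change, the non-streaming chat_completion path should return a response whose metrics list contains a TokenUsage entry; the streaming branch is left untouched by this commit. A hedged usage sketch, assuming the usual chat_completion signature with model_id/messages/stream and a router plus dialog constructed elsewhere (show_token_usage is an illustrative helper):

from llama_stack.apis.telemetry import TokenUsage

async def show_token_usage(router, messages, model_id: str) -> None:
    response = await router.chat_completion(
        model_id=model_id,
        messages=messages,
        stream=False,  # token usage is only attached on the non-streaming path here
    )
    for metric in response.metrics:
        if isinstance(metric, TokenUsage):
            print(
                f"prompt={metric.prompt_tokens} "
                f"completion={metric.completion_tokens} "
                f"total={metric.total_tokens}"
            )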