add metrics for streaming

Dinesh Yeduguru 2025-02-04 16:20:59 -08:00
parent 38f1337afa
commit 37b7390079
2 changed files with 65 additions and 30 deletions


@@ -230,7 +230,7 @@ Metric = register_schema(
 @json_schema_type
 class MetricsMixin(BaseModel):
-    metrics: List[Metric] = Field(default_factory=list)
+    metrics: Optional[List[Metric]] = None

 @runtime_checkable

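For orientation, here is a minimal sketch of what the MetricsMixin change means for response payloads; Metric and ChatResponse below are hypothetical stand-ins, not the actual llama_stack schemas. With Optional[List[Metric]] = None, a response that never had metrics attached stays None and is distinguishable from one carrying an empty list.

from typing import List, Optional
from pydantic import BaseModel


class Metric(BaseModel):
    # Hypothetical stand-in for the real llama_stack Metric schema.
    metric: str
    value: float


class MetricsMixin(BaseModel):
    # Before: metrics: List[Metric] = Field(default_factory=list)  -- always an empty list.
    # After:  unset stays None, so "no metrics attached" differs from "zero metrics".
    metrics: Optional[List[Metric]] = None


class ChatResponse(MetricsMixin):
    content: str


r = ChatResponse(content="hello")
print(r.metrics)  # None until the router attaches usage metrics
r.metrics = [Metric(metric="total_tokens", value=12)]
print(r.metrics)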

@@ -5,9 +5,10 @@
 # the root directory of this source tree.
 import time
-from typing import Any, AsyncGenerator, Dict, List, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import RawMessage
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import (
@@ -24,6 +25,9 @@ from llama_stack.apis.eval import (
     JobStatus,
 )
 from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
@@ -37,7 +41,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import ModelType
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
@@ -138,6 +142,31 @@ class InferenceRouter(Inference):
     ) -> None:
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

+    async def _log_token_usage(
+        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
+    ) -> None:
+        span = get_current_span()
+        metrics = [
+            ("prompt_tokens", prompt_tokens),
+            ("completion_tokens", completion_tokens),
+            ("total_tokens", total_tokens),
+        ]
+        for metric_name, value in metrics:
+            await self.telemetry.log_event(
+                MetricEvent(
+                    trace_id=span.trace_id,
+                    span_id=span.span_id,
+                    metric=metric_name,
+                    value=value,
+                    timestamp=time.time(),
+                    unit="tokens",
+                    attributes={
+                        "model_id": model.model_id,
+                        "provider_id": model.provider_id,
+                    },
+                )
+            )
+
     async def chat_completion(
         self,
         model_id: str,
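The new _log_token_usage helper fans one usage measurement out into three named metric events that share the same trace/span attribution. A self-contained sketch of that fan-out, using stand-in types rather than the llama_stack telemetry API; the model and provider identifiers below are placeholders.

import time
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeMetricEvent:
    # Stand-in mirroring the fields passed to MetricEvent in this diff.
    trace_id: str
    span_id: str
    metric: str
    value: int
    timestamp: float
    unit: str
    attributes: Dict[str, str]


def build_token_events(prompt_tokens: int, completion_tokens: int, trace_id: str, span_id: str) -> List[FakeMetricEvent]:
    total_tokens = prompt_tokens + completion_tokens
    return [
        FakeMetricEvent(
            trace_id=trace_id,
            span_id=span_id,
            metric=name,
            value=value,
            timestamp=time.time(),
            unit="tokens",
            attributes={"model_id": "example-model", "provider_id": "example-provider"},
        )
        for name, value in [
            ("prompt_tokens", prompt_tokens),
            ("completion_tokens", completion_tokens),
            ("total_tokens", total_tokens),
        ]
    ]


for event in build_token_events(9, 12, trace_id="t-1", span_id="s-1"):
    print(event.metric, event.value)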
@@ -150,7 +179,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -198,7 +227,37 @@ class InferenceRouter(Inference):
         )
         provider = self.routing_table.get_provider_impl(model_id)
         if stream:
-            return (chunk async for chunk in await provider.chat_completion(**params))
+
+            async def stream_generator():
+                model_input = self.formatter.encode_dialog_prompt(
+                    messages,
+                    tool_config.tool_prompt_format,
+                )
+                prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
+                completion_text = ""
+                async for chunk in await provider.chat_completion(**params):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                        completion_text += chunk.event.delta.text
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                        model_output = self.formatter.encode_dialog_prompt(
+                            [RawMessage(role="assistant", content=completion_text)],
+                            tool_config.tool_prompt_format,
+                        )
+                        completion_tokens = len(model_output.tokens) if model_output.tokens else 0
+                        total_tokens = prompt_tokens + completion_tokens
+                        if chunk.metrics is None:
+                            chunk.metrics = []
+                        chunk.metrics.append(
+                            TokenUsage(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=total_tokens,
+                            )
+                        )
+                        await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+                    yield chunk
+
+            return stream_generator()
         else:
             response = await provider.chat_completion(**params)
             model_input = self.formatter.encode_dialog_prompt(
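The streaming branch follows an accumulate-then-count pattern: progress chunks contribute delta text, and only the complete event triggers tokenization of the assembled assistant message, attachment of usage to the final chunk, and telemetry logging. A self-contained sketch of the same pattern with stand-in types (the real code uses the formatter and chunk types imported above):

import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, List, Optional


@dataclass
class Chunk:
    event_type: str                       # "progress" or "complete"
    delta_text: str = ""
    metrics: Optional[List[dict]] = None


def count_tokens(text: str) -> int:
    # Placeholder for formatter.encode_dialog_prompt(...).tokens in the real router.
    return len(text.split())


async def with_usage(chunks: AsyncIterator[Chunk], prompt_tokens: int) -> AsyncIterator[Chunk]:
    completion_text = ""
    async for chunk in chunks:
        if chunk.event_type == "progress":
            completion_text += chunk.delta_text
        if chunk.event_type == "complete":
            completion_tokens = count_tokens(completion_text)
            chunk.metrics = (chunk.metrics or []) + [{
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }]
        yield chunk


async def main() -> None:
    async def fake_stream() -> AsyncIterator[Chunk]:
        yield Chunk("progress", "Hello ")
        yield Chunk("progress", "world")
        yield Chunk("complete")

    async for chunk in with_usage(fake_stream(), prompt_tokens=3):
        print(chunk.event_type, chunk.metrics)


asyncio.run(main())

Counting only at the complete event avoids re-encoding text on every chunk while still letting each stream attach exact usage once the full completion is known.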
@@ -221,31 +280,7 @@ class InferenceRouter(Inference):
                    total_tokens=total_tokens,
                )
            )
-            # Log token usage metrics
-            metrics = [
-                ("prompt_tokens", prompt_tokens),
-                ("completion_tokens", completion_tokens),
-                ("total_tokens", total_tokens),
-            ]
-            span = get_current_span()
-            if span:
-                breakpoint()
-                for metric_name, value in metrics:
-                    await self.telemetry.log_event(
-                        MetricEvent(
-                            trace_id=span.trace_id,
-                            span_id=span.span_id,
-                            metric=metric_name,
-                            value=value,
-                            timestamp=time.time(),
-                            unit="tokens",
-                            attributes={
-                                "model_id": model_id,
-                                "provider_id": model.provider_id,
-                            },
-                        )
-                    )
+            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
         return response

     async def completion(
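From a caller's perspective, usage now arrives on the final streamed chunk and, in parallel, as telemetry metric events. A hedged consumption sketch; router, model_id, and messages are assumed to be supplied by the surrounding application, and parameter names follow this diff.

from llama_stack.apis.inference import ChatCompletionResponseEventType


async def stream_and_report(router, model_id: str, messages) -> None:
    stream = await router.chat_completion(model_id=model_id, messages=messages, stream=True)
    async for chunk in stream:
        if chunk.event.event_type == ChatCompletionResponseEventType.complete and chunk.metrics:
            for usage in chunk.metrics:
                # prompt/completion/total token counts appended by stream_generator()
                print(usage)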