add metrics for streaming

Dinesh Yeduguru 2025-02-04 16:20:59 -08:00
parent 38f1337afa
commit 37b7390079
2 changed files with 65 additions and 30 deletions


@@ -230,7 +230,7 @@ Metric = register_schema(
@json_schema_type
class MetricsMixin(BaseModel):
-    metrics: List[Metric] = Field(default_factory=list)
+    metrics: Optional[List[Metric]] = None
@runtime_checkable
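Note: the hunk above relaxes MetricsMixin so that metrics is optional and defaults to None instead of an always-present empty list; responses then carry token counts only when a router actually attaches them. A minimal standalone sketch of the difference, using plain Pydantic stand-ins rather than the real Metric type:

    from typing import List, Optional
    from pydantic import BaseModel, Field

    class OldMixin(BaseModel):
        # previous schema: the list always exists and serializes as []
        metrics: List[dict] = Field(default_factory=list)

    class NewMixin(BaseModel):
        # new schema: stays None until something (e.g. the inference router) appends usage
        metrics: Optional[List[dict]] = None

    print(OldMixin().model_dump())  # {'metrics': []}
    print(NewMixin().model_dump())  # {'metrics': None}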


@@ -5,9 +5,10 @@
# the root directory of this source tree.
import time
-from typing import Any, AsyncGenerator, Dict, List, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import RawMessage
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import (
@@ -24,6 +25,9 @@ from llama_stack.apis.eval import (
JobStatus,
)
from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@@ -37,7 +41,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
-from llama_stack.apis.models import ModelType
+from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
ScoreBatchResponse,
@@ -138,6 +142,31 @@ class InferenceRouter(Inference):
) -> None:
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
+    async def _log_token_usage(
+        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
+    ) -> None:
+        span = get_current_span()
+        metrics = [
+            ("prompt_tokens", prompt_tokens),
+            ("completion_tokens", completion_tokens),
+            ("total_tokens", total_tokens),
+        ]
+        for metric_name, value in metrics:
+            await self.telemetry.log_event(
+                MetricEvent(
+                    trace_id=span.trace_id,
+                    span_id=span.span_id,
+                    metric=metric_name,
+                    value=value,
+                    timestamp=time.time(),
+                    unit="tokens",
+                    attributes={
+                        "model_id": model.model_id,
+                        "provider_id": model.provider_id,
+                    },
+                )
+            )
async def chat_completion(
self,
model_id: str,
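Note: the new _log_token_usage helper above emits one MetricEvent per counter (prompt_tokens, completion_tokens, total_tokens), each tagged with the model and provider ids; unlike the inline logging it replaces further down, it reads span.trace_id without first checking that get_current_span() returned an active span. A standalone sketch of the per-counter fan-out, with a hypothetical FakeMetricEvent standing in for the real MetricEvent class:

    import time
    from dataclasses import dataclass

    @dataclass
    class FakeMetricEvent:  # hypothetical stand-in for llama_stack's MetricEvent
        trace_id: str
        span_id: str
        metric: str
        value: int
        timestamp: float
        unit: str
        attributes: dict

    def token_events(trace_id, span_id, usage, model_id, provider_id):
        # one event per counter, mirroring the loop in _log_token_usage
        return [
            FakeMetricEvent(trace_id, span_id, name, value, time.time(), "tokens",
                            {"model_id": model_id, "provider_id": provider_id})
            for name, value in usage.items()
        ]

    events = token_events("trace-1", "span-1",
                          {"prompt_tokens": 12, "completion_tokens": 40, "total_tokens": 52},
                          "model-x", "provider-y")
    print([e.metric for e in events])  # ['prompt_tokens', 'completion_tokens', 'total_tokens']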
@@ -150,7 +179,7 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@@ -198,7 +227,37 @@ class InferenceRouter(Inference):
)
provider = self.routing_table.get_provider_impl(model_id)
if stream:
-            return (chunk async for chunk in await provider.chat_completion(**params))
+            async def stream_generator():
+                model_input = self.formatter.encode_dialog_prompt(
+                    messages,
+                    tool_config.tool_prompt_format,
+                )
+                prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
+                completion_text = ""
+                async for chunk in await provider.chat_completion(**params):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                        completion_text += chunk.event.delta.text
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                        model_output = self.formatter.encode_dialog_prompt(
+                            [RawMessage(role="assistant", content=completion_text)],
+                            tool_config.tool_prompt_format,
+                        )
+                        completion_tokens = len(model_output.tokens) if model_output.tokens else 0
+                        total_tokens = prompt_tokens + completion_tokens
+                        if chunk.metrics is None:
+                            chunk.metrics = []
+                        chunk.metrics.append(
+                            TokenUsage(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=total_tokens,
+                            )
+                        )
+                        await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+                    yield chunk
+            return stream_generator()
else:
response = await provider.chat_completion(**params)
model_input = self.formatter.encode_dialog_prompt(
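Note: in the streaming branch above, prompt tokens are counted once up front from the encoded dialog, progress deltas are accumulated into completion_text, and on the complete event the assistant text is re-encoded to count completion tokens; the usage is appended to that final chunk's metrics and passed to _log_token_usage before the chunk is yielded. A hedged consumer-side sketch (assumes a configured router instance; print_stream_usage is a made-up helper name):

    from llama_stack.apis.inference import ChatCompletionResponseEventType

    async def print_stream_usage(router, model_id, messages):
        # stream=True: chat_completion returns the async generator built above
        stream = await router.chat_completion(model_id=model_id, messages=messages, stream=True)
        async for chunk in stream:
            if chunk.event.event_type == ChatCompletionResponseEventType.complete and chunk.metrics:
                for usage in chunk.metrics:
                    print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)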
@@ -221,31 +280,7 @@ class InferenceRouter(Inference):
total_tokens=total_tokens,
)
)
-            # Log token usage metrics
-            metrics = [
-                ("prompt_tokens", prompt_tokens),
-                ("completion_tokens", completion_tokens),
-                ("total_tokens", total_tokens),
-            ]
-            span = get_current_span()
-            if span:
-                breakpoint()
-                for metric_name, value in metrics:
-                    await self.telemetry.log_event(
-                        MetricEvent(
-                            trace_id=span.trace_id,
-                            span_id=span.span_id,
-                            metric=metric_name,
-                            value=value,
-                            timestamp=time.time(),
-                            unit="tokens",
-                            attributes={
-                                "model_id": model_id,
-                                "provider_id": model.provider_id,
-                            },
-                        )
-                    )
+            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
return response
async def completion(
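Note: the final hunk swaps the inline telemetry loop in the non-streaming branch (which still contained a stray breakpoint() call) for the shared _log_token_usage helper, so the streaming and non-streaming paths now log token metrics identically. Assuming the non-streaming branch also attaches its TokenUsage entry to response.metrics, as the surrounding context lines suggest, reading it could look like this hypothetical sketch:

    async def print_usage(router, model_id, messages):
        response = await router.chat_completion(model_id=model_id, messages=messages, stream=False)
        for usage in response.metrics or []:  # metrics is Optional and may be None
            print(usage.total_tokens)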