mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00

add metrics for streaming

This commit is contained in:
parent 38f1337afa
commit 37b7390079

2 changed files with 65 additions and 30 deletions
@@ -230,7 +230,7 @@ Metric = register_schema(
 
 @json_schema_type
 class MetricsMixin(BaseModel):
-    metrics: List[Metric] = Field(default_factory=list)
+    metrics: Optional[List[Metric]] = None
 
 
 @runtime_checkable
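The hunk above flips the metrics field on MetricsMixin from an always-present empty list to an optional field that defaults to None, so response objects omit metrics entirely until something attaches them. A minimal sketch of the behavioral difference, assuming a simplified stand-in for the real Metric schema:

# Hedged sketch, not part of the patch: Metric here is a simplified stand-in for
# the registered schema; only the default-value change is the point.
from typing import List, Optional

from pydantic import BaseModel, Field


class Metric(BaseModel):
    metric: str
    value: float


class OldMetricsMixin(BaseModel):
    metrics: List[Metric] = Field(default_factory=list)  # previous: always an empty list


class NewMetricsMixin(BaseModel):
    metrics: Optional[List[Metric]] = None  # now: absent until explicitly populated


assert OldMetricsMixin().metrics == []
assert NewMetricsMixin().metrics is None

This default is why the streaming code later in the commit guards with `if chunk.metrics is None: chunk.metrics = []` before appending a usage entry.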
@@ -5,9 +5,10 @@
 # the root directory of this source tree.
 
 import time
-from typing import Any, AsyncGenerator, Dict, List, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import RawMessage
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.common.content_types import (
@@ -24,6 +25,9 @@ from llama_stack.apis.eval import (
     JobStatus,
 )
 from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
@@ -37,7 +41,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import ModelType
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
@@ -138,6 +142,31 @@ class InferenceRouter(Inference):
     ) -> None:
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
+    async def _log_token_usage(
+        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
+    ) -> None:
+        span = get_current_span()
+        metrics = [
+            ("prompt_tokens", prompt_tokens),
+            ("completion_tokens", completion_tokens),
+            ("total_tokens", total_tokens),
+        ]
+        for metric_name, value in metrics:
+            await self.telemetry.log_event(
+                MetricEvent(
+                    trace_id=span.trace_id,
+                    span_id=span.span_id,
+                    metric=metric_name,
+                    value=value,
+                    timestamp=time.time(),
+                    unit="tokens",
+                    attributes={
+                        "model_id": model.model_id,
+                        "provider_id": model.provider_id,
+                    },
+                )
+            )
+
     async def chat_completion(
         self,
         model_id: str,
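The new _log_token_usage helper fans one request out into three MetricEvents (prompt_tokens, completion_tokens, total_tokens), each tagged with the current span and the model's model_id/provider_id. A rough, self-contained sketch of that fan-out, using stand-in span/model/telemetry objects rather than the real llama-stack types:

# Hedged sketch of the fan-out done by _log_token_usage. FakeSpan, FakeModel and
# FakeTelemetry are stand-ins; only the loop structure mirrors the helper above.
import asyncio
import time
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeSpan:
    trace_id: str = "trace-abc"
    span_id: str = "span-def"


@dataclass
class FakeModel:
    model_id: str = "example-model"
    provider_id: str = "example-provider"


@dataclass
class FakeTelemetry:
    events: List[Dict] = field(default_factory=list)

    async def log_event(self, event: Dict) -> None:
        self.events.append(event)


async def log_token_usage(telemetry, span, model, prompt_tokens, completion_tokens, total_tokens):
    # One event per counter, all in unit="tokens", mirroring the helper's loop.
    for metric_name, value in [
        ("prompt_tokens", prompt_tokens),
        ("completion_tokens", completion_tokens),
        ("total_tokens", total_tokens),
    ]:
        await telemetry.log_event(
            {
                "trace_id": span.trace_id,
                "span_id": span.span_id,
                "metric": metric_name,
                "value": value,
                "timestamp": time.time(),
                "unit": "tokens",
                "attributes": {"model_id": model.model_id, "provider_id": model.provider_id},
            }
        )


telemetry = FakeTelemetry()
asyncio.run(log_token_usage(telemetry, FakeSpan(), FakeModel(), 12, 30, 42))
assert [e["metric"] for e in telemetry.events] == ["prompt_tokens", "completion_tokens", "total_tokens"]

Note that, unlike the old inline block removed further down (which wrapped the loop in `if span:`), the helper dereferences get_current_span() unconditionally, so it assumes an active span.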
@@ -150,7 +179,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
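This hunk only tightens the return annotation: instead of a bare AsyncGenerator, chat_completion is declared to return either a full ChatCompletionResponse (stream=False) or an async iterator of ChatCompletionResponseStreamChunk (stream=True). A hedged caller-side sketch of branching on that union, with placeholder Response/Chunk classes standing in for the real types:

# Hedged sketch of handling the Union return type; Response and Chunk are placeholders.
import asyncio
from typing import AsyncIterator, Union


class Response:
    content = "full response"


class Chunk:
    delta = "partial text"


async def chat_completion(stream: bool = False) -> Union[Response, AsyncIterator[Chunk]]:
    if stream:

        async def gen() -> AsyncIterator[Chunk]:
            yield Chunk()

        return gen()
    return Response()


async def caller() -> None:
    result = await chat_completion(stream=True)
    if isinstance(result, Response):
        print(result.content)  # non-streaming: a single response object
    else:
        async for chunk in result:  # streaming: iterate the chunk stream
            print(chunk.delta)


asyncio.run(caller())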
@@ -198,7 +227,37 @@ class InferenceRouter(Inference):
         )
         provider = self.routing_table.get_provider_impl(model_id)
         if stream:
-            return (chunk async for chunk in await provider.chat_completion(**params))
+
+            async def stream_generator():
+                model_input = self.formatter.encode_dialog_prompt(
+                    messages,
+                    tool_config.tool_prompt_format,
+                )
+                prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
+                completion_text = ""
+                async for chunk in await provider.chat_completion(**params):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                        completion_text += chunk.event.delta.text
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                        model_output = self.formatter.encode_dialog_prompt(
+                            [RawMessage(role="assistant", content=completion_text)],
+                            tool_config.tool_prompt_format,
+                        )
+                        completion_tokens = len(model_output.tokens) if model_output.tokens else 0
+                        total_tokens = prompt_tokens + completion_tokens
+                        if chunk.metrics is None:
+                            chunk.metrics = []
+                        chunk.metrics.append(
+                            TokenUsage(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=total_tokens,
+                            )
+                        )
+                        await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+                    yield chunk
+
+            return stream_generator()
         else:
             response = await provider.chat_completion(**params)
             model_input = self.formatter.encode_dialog_prompt(
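In the streaming branch, the bare pass-through generator expression is replaced by a local stream_generator() that counts prompt tokens up front by re-encoding the messages with self.formatter.encode_dialog_prompt, accumulates progress deltas into completion_text, and, on the complete event, re-encodes the accumulated text to estimate completion tokens, appends a TokenUsage entry to the final chunk's metrics, and awaits _log_token_usage before yielding. A hedged sketch of what a consumer of that wrapped stream sees, with simplified stand-ins for the chunk and usage types:

# Hedged sketch of consuming the wrapped stream; Chunk and TokenUsage below are
# simplified stand-ins for the real stream-chunk and metric types.
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, List, Optional


@dataclass
class TokenUsage:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclass
class Chunk:
    delta_text: str
    metrics: Optional[List[TokenUsage]] = None


async def fake_stream() -> AsyncIterator[Chunk]:
    # Stand-in for stream_generator(): text deltas first, token usage only on the last chunk.
    yield Chunk(delta_text="Hello, ")
    yield Chunk(delta_text="world!")
    yield Chunk(delta_text="", metrics=[TokenUsage(prompt_tokens=12, completion_tokens=4, total_tokens=16)])


async def main() -> None:
    text = ""
    usage: Optional[TokenUsage] = None
    async for chunk in fake_stream():
        text += chunk.delta_text
        if chunk.metrics:  # only the final (complete-event) chunk carries metrics
            usage = chunk.metrics[0]
    print(text)
    print(usage)


asyncio.run(main())

One consequence of the wrapper worth noting: completion tokens are estimated by re-tokenizing the concatenated delta text rather than taken from the provider, so the reported counts are an approximation.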
@@ -221,31 +280,7 @@ class InferenceRouter(Inference):
                     total_tokens=total_tokens,
                 )
             )
-            # Log token usage metrics
-            metrics = [
-                ("prompt_tokens", prompt_tokens),
-                ("completion_tokens", completion_tokens),
-                ("total_tokens", total_tokens),
-            ]
-
-            span = get_current_span()
-            if span:
-                breakpoint()
-                for metric_name, value in metrics:
-                    await self.telemetry.log_event(
-                        MetricEvent(
-                            trace_id=span.trace_id,
-                            span_id=span.span_id,
-                            metric=metric_name,
-                            value=value,
-                            timestamp=time.time(),
-                            unit="tokens",
-                            attributes={
-                                "model_id": model_id,
-                                "provider_id": model.provider_id,
-                            },
-                        )
-                    )
+            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
             return response
 
     async def completion(
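The non-streaming branch above drops roughly 25 inline lines, including a stray breakpoint() and an `if span:` guard, in favor of the same _log_token_usage call used by the streaming path. In both branches the token counts come from re-encoding text with the chat formatter and taking the length of the resulting token list; a hedged sketch of that counting pattern, with a toy stand-in for the real ChatFormat/Tokenizer pair:

# Hedged sketch of the token-counting pattern; WhitespaceFormatter is a toy
# stand-in for the real ChatFormat.encode_dialog_prompt + Tokenizer machinery.
from dataclasses import dataclass
from typing import List


@dataclass
class EncodedPrompt:
    tokens: List[int]


class WhitespaceFormatter:
    """Toy formatter: one 'token' per whitespace-separated word."""

    def encode_dialog_prompt(self, texts: List[str]) -> EncodedPrompt:
        words = " ".join(texts).split()
        return EncodedPrompt(tokens=list(range(len(words))))


formatter = WhitespaceFormatter()

model_input = formatter.encode_dialog_prompt(["What is the capital of France?"])
model_output = formatter.encode_dialog_prompt(["The capital of France is Paris."])

prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
completion_tokens = len(model_output.tokens) if model_output.tokens else 0
total_tokens = prompt_tokens + completion_tokens

print(prompt_tokens, completion_tokens, total_tokens)  # 6 6 12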