diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index c24df384d..624a4f2c2 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -163,7 +163,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
                 module="llama_stack.distribution.routers",
                 routing_table_api=info.routing_table_api,
                 api_dependencies=[info.routing_table_api],
-                deps__=[info.routing_table_api.value],
+                # Add telemetry as an optional dependency to all auto-routed providers
+                optional_api_dependencies=[Api.telemetry],
+                deps__=([info.routing_table_api.value, Api.telemetry.value]),
             ),
         )
     }
diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py
index a54f57fb3..d0fca8771 100644
--- a/llama_stack/distribution/routers/__init__.py
+++ b/llama_stack/distribution/routers/__init__.py
@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl
 
 
-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
         "eval": EvalRouter,
         "tool_runtime": ToolRuntimeRouter,
     }
+    api_to_deps = {
+        "inference": {"telemetry": Api.telemetry},
+    }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")
 
-    impl = api_to_routers[api.value](routing_table)
+    api_to_dep_impl = {}
+    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
+        if dep_api in deps:
+            api_to_dep_impl[dep_name] = deps[dep_api]
+
+    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 691df1988..a921df929 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -4,7 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict, List, Optional
+import time
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
+
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
@@ -21,6 +25,10 @@ from llama_stack.apis.eval import (
     JobStatus,
 )
 from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
@@ -28,13 +36,15 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     SamplingParams,
+    StopReason,
     TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    UserMessage,
 )
-from llama_stack.apis.models import ModelType
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
@@ -43,6 +53,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.apis.telemetry import MetricEvent, Telemetry
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -53,6 +64,7 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
 
 class VectorIORouter(VectorIO):
@@ -121,9 +133,14 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
+        telemetry: Optional[Telemetry] = None,
     ) -> None:
         logcat.debug("core", "Initializing InferenceRouter")
         self.routing_table = routing_table
+        self.telemetry = telemetry
+        if self.telemetry:
+            self.tokenizer = Tokenizer.get_instance()
+            self.formatter = ChatFormat(self.tokenizer)
 
     async def initialize(self) -> None:
         logcat.debug("core", "InferenceRouter.initialize")
@@ -147,6 +164,59 @@ class InferenceRouter(Inference):
         )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
+    def _construct_metrics(
+        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
+    ) -> List[MetricEvent]:
+        span = get_current_span()
+        metrics = [
+            ("prompt_tokens", prompt_tokens),
+            ("completion_tokens", completion_tokens),
+            ("total_tokens", total_tokens),
+        ]
+        metric_events = []
+        for metric_name, value in metrics:
+            metric_events.append(
+                MetricEvent(
+                    trace_id=span.trace_id,
+                    span_id=span.span_id,
+                    metric=metric_name,
+                    value=value,
+                    timestamp=time.time(),
+                    unit="tokens",
+                    attributes={
+                        "model_id": model.model_id,
+                        "provider_id": model.provider_id,
+                    },
+                )
+            )
+        return metric_events
+
+    async def _add_token_metrics(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        total_tokens: int,
+        model: Model,
+        target: Any,
+    ) -> None:
+        metrics = getattr(target, "metrics", None)
+        if metrics is None:
+            target.metrics = []
+
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        target.metrics.extend(metrics)
+        if self.telemetry:
+            for metric in metrics:
+                await self.telemetry.log_event(metric)
+
+    async def _count_tokens(
+        self,
+        messages: List[Message],
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+    ) -> Optional[int]:
+        encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        return len(encoded.tokens) if encoded and encoded.tokens else 0
+
     async def chat_completion(
         self,
         model_id: str,
@@ -159,7 +229,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logcat.debug(
             "core",
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
@@ -208,10 +278,47 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+
         if stream:
-            return (chunk async for chunk in await provider.chat_completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.chat_completion(**params):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                        if chunk.event.delta.type == "text":
+                            completion_text += chunk.event.delta.text
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                        completion_tokens = await self._count_tokens(
+                            [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
+                            tool_config.tool_prompt_format,
+                        )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
+                        )
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.chat_completion(**params)
+            response = await provider.chat_completion(**params)
+            completion_tokens = await self._count_tokens(
+                [response.completion_message],
+                tool_config.tool_prompt_format,
+            )
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
+            )
+            return response
 
     async def completion(
         self,
@@ -240,10 +347,45 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
+
+        prompt_tokens = await self._count_tokens([UserMessage(role="user", content=str(content))])
+
         if stream:
-            return (chunk async for chunk in await provider.completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.completion(**params):
+                    if hasattr(chunk, "delta"):
+                        completion_text += chunk.delta
+                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                        completion_tokens = await self._count_tokens(
+                            [CompletionMessage(content=completion_text, stop_reason=chunk.stop_reason)]
+                        )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
+                        )
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.completion(**params)
+            response = await provider.completion(**params)
+            completion_tokens = await self._count_tokens(
+                [CompletionMessage(content=str(response.content), stop_reason=StopReason.end_of_turn)]
+            )
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
+            )
+            return response
 
     async def embeddings(
         self,
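For context on the `get_auto_router_impl` change: telemetry is wired in as an optional constructor dependency rather than a required one. A minimal, self-contained sketch of that injection pattern follows; the names `Api`, `FakeTelemetry`, `FakeInferenceRouter`, and `build_router` are simplified stand-ins invented for illustration, not the real llama_stack types or signatures.

```python
from enum import Enum
from typing import Any, Dict


class Api(Enum):
    """Simplified stand-in for llama_stack's Api enum."""
    inference = "inference"
    telemetry = "telemetry"


class FakeTelemetry:
    """Hypothetical telemetry impl, used only for this illustration."""
    async def log_event(self, event: Any) -> None:
        print(f"logged: {event}")


class FakeInferenceRouter:
    """Hypothetical router that accepts telemetry as an optional keyword, mirroring InferenceRouter."""
    def __init__(self, routing_table: Any, telemetry: Any = None) -> None:
        self.routing_table = routing_table
        self.telemetry = telemetry


def build_router(api: Api, routing_table: Any, deps: Dict[Api, Any]) -> Any:
    # Per-API map of optional dependencies, keyed by constructor kwarg name,
    # mirroring api_to_deps in get_auto_router_impl above.
    api_to_deps = {
        "inference": {"telemetry": Api.telemetry},
    }
    kwargs = {}
    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
        if dep_api in deps:  # only pass the kwarg if the dependency was actually resolved
            kwargs[dep_name] = deps[dep_api]
    return FakeInferenceRouter(routing_table, **kwargs)


if __name__ == "__main__":
    router = build_router(Api.inference, routing_table=object(), deps={Api.telemetry: FakeTelemetry()})
    print(router.telemetry)  # FakeTelemetry instance
    router = build_router(Api.inference, routing_table=object(), deps={})
    print(router.telemetry)  # None: telemetry stays optional
```

The router never looks up deps directly; absent dependencies simply result in the kwarg not being passed, so the default of `None` applies.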
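The streaming branches in `chat_completion` and `completion` follow the same shape: accumulate the delta text across chunks, count completion tokens once the final chunk arrives, and attach the metrics to that chunk before yielding it. A runnable, self-contained sketch of that pattern is below; `FakeChunk`, `count_tokens`, `fake_provider_stream`, and `stream_with_metrics` are hypothetical simplifications (whitespace splitting stands in for the Llama tokenizer), not the actual router code.

```python
import asyncio
from dataclasses import dataclass, field
from typing import AsyncIterator, List, Optional


@dataclass
class FakeChunk:
    """Hypothetical stream chunk: a text delta plus, on the last chunk, a stop reason."""
    delta: str = ""
    stop_reason: Optional[str] = None
    metrics: List[dict] = field(default_factory=list)


def count_tokens(text: str) -> int:
    # Whitespace split as a stand-in for the tokenizer behind _count_tokens.
    return len(text.split())


async def fake_provider_stream() -> AsyncIterator[FakeChunk]:
    # Stand-in for the provider's streaming response.
    for piece in ["The answer", " is", " forty-two"]:
        yield FakeChunk(delta=piece)
    yield FakeChunk(stop_reason="end_of_turn")


async def stream_with_metrics(prompt_tokens: int) -> AsyncIterator[FakeChunk]:
    """Accumulate delta text; when the final chunk arrives, attach token metrics to it."""
    completion_text = ""
    async for chunk in fake_provider_stream():
        completion_text += chunk.delta
        if chunk.stop_reason is not None:
            completion_tokens = count_tokens(completion_text)
            chunk.metrics.append(
                {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                }
            )
        yield chunk


async def main() -> None:
    async for chunk in stream_with_metrics(prompt_tokens=12):
        if chunk.metrics:
            print(chunk.metrics)  # [{'prompt_tokens': 12, 'completion_tokens': 4, 'total_tokens': 16}]


asyncio.run(main())
```

Because the metrics ride on the final chunk (and on the response object in the non-streaming branch), callers see token counts inline even when no telemetry sink is configured; the telemetry provider only adds the `log_event` side channel.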