diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index fa917ac22..d0f5d15c5 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):
 
 
 @json_schema_type
-class CompletionResponse(MetricResponseMixin):
+class CompletionResponse(BaseModel):
     """Response from a completion request.
 
     :param content: The generated completion text
@@ -299,7 +299,7 @@ class CompletionResponse(MetricResponseMixin):
 
 
 @json_schema_type
-class CompletionResponseStreamChunk(MetricResponseMixin):
+class CompletionResponseStreamChunk(BaseModel):
     """A chunk of a streamed completion response.
 
     :param delta: New content generated since last chunk. This can be one or more tokens.
@@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin):
+class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
     """A chunk of a streamed chat completion response.
 
     :param event: The event containing the new content
@@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin):
 
 
 @json_schema_type
-class ChatCompletionResponse(MetricResponseMixin):
+class ChatCompletionResponse(MetricResponseMixin, BaseModel):
    """Response from a chat completion request.
 
     :param completion_message: The complete response message
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 624a4f2c2..c24df384d 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -163,9 +163,7 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
                     module="llama_stack.distribution.routers",
                     routing_table_api=info.routing_table_api,
                     api_dependencies=[info.routing_table_api],
-                    # Add telemetry as an optional dependency to all auto-routed providers
-                    optional_api_dependencies=[Api.telemetry],
-                    deps__=([info.routing_table_api.value, Api.telemetry.value]),
+                    deps__=[info.routing_table_api.value],
                 ),
             )
         }
diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py
index d0fca8771..a54f57fb3 100644
--- a/llama_stack/distribution/routers/__init__.py
+++ b/llama_stack/distribution/routers/__init__.py
@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl
 
 
-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -65,17 +65,9 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict
         "eval": EvalRouter,
         "tool_runtime": ToolRuntimeRouter,
     }
-    api_to_deps = {
-        "inference": {"telemetry": Api.telemetry},
-    }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")
 
-    api_to_dep_impl = {}
-    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
-        if dep_api in deps:
-            api_to_dep_impl[dep_name] = deps[dep_api]
-
-    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
+    impl = api_to_routers[api.value](routing_table)
     await impl.initialize()
     return impl
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 3cfc2b119..f2c70e66f 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -4,8 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import time
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional
 
 from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
@@ -22,10 +21,6 @@ from llama_stack.apis.eval import (
     JobStatus,
 )
 from llama_stack.apis.inference import (
-    ChatCompletionResponse,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionMessage,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
@@ -33,14 +28,13 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     SamplingParams,
-    StopReason,
     TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import Model, ModelType
+from llama_stack.apis.models import ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
@@ -49,7 +43,6 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.telemetry import MetricEvent, Telemetry
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -59,10 +52,7 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.models.llama.llama3.chat_format import ChatFormat
-from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import RoutingTable
-from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
 
 class VectorIORouter(VectorIO):
@@ -131,14 +121,9 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
-        telemetry: Optional[Telemetry] = None,
     ) -> None:
         logcat.debug("core", "Initializing InferenceRouter")
         self.routing_table = routing_table
-        self.telemetry = telemetry
-        if self.telemetry:
-            self.tokenizer = Tokenizer.get_instance()
-            self.formatter = ChatFormat(self.tokenizer)
 
     async def initialize(self) -> None:
         logcat.debug("core", "InferenceRouter.initialize")
@@ -162,57 +147,6 @@ class InferenceRouter(Inference):
         )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
-    def _construct_metrics(
-        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
-    ) -> List[MetricEvent]:
-        span = get_current_span()
-        metrics = [
-            ("prompt_tokens", prompt_tokens),
-            ("completion_tokens", completion_tokens),
-            ("total_tokens", total_tokens),
-        ]
-        metric_events = []
-        for metric_name, value in metrics:
-            metric_events.append(
-                MetricEvent(
-                    trace_id=span.trace_id,
-                    span_id=span.span_id,
-                    metric=metric_name,
-                    value=value,
-                    timestamp=time.time(),
-                    unit="tokens",
-                    attributes={
-                        "model_id": model.model_id,
-                        "provider_id": model.provider_id,
-                    },
-                )
-            )
-        return metric_events
-
-    async def _compute_and_log_token_usage(
-        self,
-        prompt_tokens: int,
-        completion_tokens: int,
-        total_tokens: int,
-        model: Model,
-    ) -> List[MetricEvent]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
-        if self.telemetry:
-            for metric in metrics:
-                await self.telemetry.log_event(metric)
-        return metrics
-
-    async def _count_tokens(
-        self,
-        messages: List[Message] | InterleavedContent,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-    ) -> Optional[int]:
-        if isinstance(messages, list):
-            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
-        else:
-            encoded = self.formatter.encode_content(messages)
-        return len(encoded.tokens) if encoded and encoded.tokens else 0
-
     async def chat_completion(
         self,
         model_id: str,
@@ -225,7 +159,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+    ) -> AsyncGenerator:
         logcat.debug(
             "core",
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
@@ -276,47 +210,10 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
-
         if stream:
-
-            async def stream_generator():
-                completion_text = ""
-                async for chunk in await provider.chat_completion(**params):
-                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
-                        if chunk.event.delta.type == "text":
-                            completion_text += chunk.event.delta.text
-                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
-                        completion_tokens = await self._count_tokens(
-                            [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
-                            tool_config.tool_prompt_format,
-                        )
-                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        metrics = await self._compute_and_log_token_usage(
-                            prompt_tokens or 0,
-                            completion_tokens or 0,
-                            total_tokens,
-                            model,
-                        )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
-                    yield chunk
-
-            return stream_generator()
+            return (chunk async for chunk in await provider.chat_completion(**params))
         else:
-            response = await provider.chat_completion(**params)
-            completion_tokens = await self._count_tokens(
-                [response.completion_message],
-                tool_config.tool_prompt_format,
-            )
-            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            metrics = await self._compute_and_log_token_usage(
-                prompt_tokens or 0,
-                completion_tokens or 0,
-                total_tokens,
-                model,
-            )
-            response.metrics = metrics if response.metrics is None else response.metrics + metrics
-            return response
+            return await provider.chat_completion(**params)
 
     async def completion(
         self,
@@ -347,41 +244,10 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-
-        prompt_tokens = await self._count_tokens(content)
-
         if stream:
-
-            async def stream_generator():
-                completion_text = ""
-                async for chunk in await provider.completion(**params):
-                    if hasattr(chunk, "delta"):
-                        completion_text += chunk.delta
-                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
-                        completion_tokens = await self._count_tokens(completion_text)
-                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        metrics = await self._compute_and_log_token_usage(
-                            prompt_tokens or 0,
-                            completion_tokens or 0,
-                            total_tokens,
-                            model,
-                        )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
-                    yield chunk
-
-            return stream_generator()
+            return (chunk async for chunk in await provider.completion(**params))
         else:
-            response = await provider.completion(**params)
-            completion_tokens = await self._count_tokens(response.content)
-            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            metrics = await self._compute_and_log_token_usage(
-                prompt_tokens or 0,
-                completion_tokens or 0,
-                total_tokens,
-                model,
-            )
-            response.metrics = metrics if response.metrics is None else response.metrics + metrics
-            return response
+            return await provider.completion(**params)
 
     async def embeddings(
         self,
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
index 4cdb420b2..e713a057f 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
@@ -73,7 +73,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
     def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
         self.config = config
         self.datasetio_api = deps.get(Api.datasetio)
-        self.meter = None
 
         resource = Resource.create(
             {
@@ -172,8 +171,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         return _GLOBAL_STORAGE["gauges"][name]
 
     def _log_metric(self, event: MetricEvent) -> None:
-        if self.meter is None:
-            return
         if isinstance(event.value, int):
             counter = self._get_or_create_counter(event.metric, event.unit)
             counter.add(event.value, attributes=event.attributes)