From b8535417e0f9986b096c24d6811689b11c17d7ae Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 5 Mar 2025 12:41:45 -0800 Subject: [PATCH] feat: record token usage for inference API (#1300) # What does this PR do? Inference router computes the token usage related metrics for all providers and returns the metrics as part of response and also logs to telemetry. ## Test Plan LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run ~/.llama/distributions/fireworks/fireworks-run.yaml ``` curl --request POST \ --url http://localhost:8321/v1/inference/chat-completion \ --header 'content-type: application/json' \ --data '{ "model_id": "meta-llama/Llama-3.1-70B-Instruct", "messages": [ { "role": "user", "content": { "type": "text", "text": "where do humans live" } } ], "stream": false }' | jq . { "metrics": [ { "trace_id": "yjv1tf0jS1evOyPm", "span_id": "WqYKvg0_", "timestamp": "2025-02-27T18:55:10.770903Z", "attributes": { "model_id": "meta-llama/Llama-3.1-70B-Instruct", "provider_id": "fireworks" }, "type": "metric", "metric": "prompt_tokens", "value": 10, "unit": "tokens" }, { "trace_id": "yjv1tf0jS1evOyPm", "span_id": "WqYKvg0_", "timestamp": "2025-02-27T18:55:10.770916Z", "attributes": { "model_id": "meta-llama/Llama-3.1-70B-Instruct", "provider_id": "fireworks" }, "type": "metric", "metric": "completion_tokens", "value": 411, "unit": "tokens" }, { "trace_id": "yjv1tf0jS1evOyPm", "span_id": "WqYKvg0_", "timestamp": "2025-02-27T18:55:10.770919Z", "attributes": { "model_id": "meta-llama/Llama-3.1-70B-Instruct", "provider_id": "fireworks" }, "type": "metric", "metric": "total_tokens", "value": 421, "unit": "tokens" } ], "completion_message": { "role": "assistant", "content": "Humans live in various parts of the world, inhabiting almost every continent, country, and region. Here's a breakdown of where humans live:\n\n1. **Continents:** Humans inhabit all seven continents:\n\t* Africa\n\t* Antarctica (research stations only)\n\t* Asia\n\t* Australia\n\t* Europe\n\t* North America\n\t* South America\n2. **Countries:** There are 196 countries recognized by the United Nations, and humans live in almost all of them.\n3. **Regions:** Humans live in diverse regions, including:\n\t* Deserts (e.g., Sahara, Mojave)\n\t* Forests (e.g., Amazon, Congo)\n\t* Grasslands (e.g., Prairies, Steppes)\n\t* Mountains (e.g., Himalayas, Andes)\n\t* Oceans (e.g., coastal areas, islands)\n\t* Tundras (e.g., Arctic, sub-Arctic)\n4. **Cities and towns:** Many humans live in urban areas, such as cities and towns, which are often located near:\n\t* Coastlines\n\t* Rivers\n\t* Lakes\n\t* Mountains\n5. **Rural areas:** Some humans live in rural areas, such as:\n\t* Villages\n\t* Farms\n\t* Countryside\n6. **Islands:** Humans inhabit many islands, including:\n\t* Tropical islands (e.g., Hawaii, Maldives)\n\t* Arctic islands (e.g., Greenland, Iceland)\n\t* Continental islands (e.g., Great Britain, Ireland)\n7. **Extreme environments:** Humans also live in extreme environments, such as:\n\t* High-altitude areas (e.g., Tibet, Andes)\n\t* Low-altitude areas (e.g., Death Valley, Dead Sea)\n\t* Areas with extreme temperatures (e.g., Arctic, Sahara)\n\nOverall, humans have adapted to live in a wide range of environments and ecosystems around the world.", "stop_reason": "end_of_turn", "tool_calls": [] }, "logprobs": null } ``` ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/integration/inference ======================================================================== short test summary info ========================================================================= FAILED tests/integration/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[txt=8B:vis=11B-inference:chat_completion:tool_calling_tools_absent-True] - ValueError: Unsupported tool prompt format: ToolPromptFormat.json FAILED tests/integration/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[txt=8B:vis=11B-inference:chat_completion:tool_calling_tools_absent-False] - ValueError: Unsupported tool prompt format: ToolPromptFormat.json FAILED tests/integration/inference/test_vision_inference.py::test_image_chat_completion_non_streaming[txt=8B:vis=11B] - fireworks.client.error.InvalidRequestError: {'error': {'object': 'error', 'type': 'invalid_request_error', 'message': 'Failed to decode image cannot identify image f... FAILED tests/integration/inference/test_vision_inference.py::test_image_chat_completion_streaming[txt=8B:vis=11B] - fireworks.client.error.InvalidRequestError: {'error': {'object': 'error', 'type': 'invalid_request_error', 'message': 'Failed to decode image cannot identify image f... ========================================================= 4 failed, 16 passed, 23 xfailed, 17 warnings in 44.36s ========================================================= ``` --- llama_stack/apis/inference/inference.py | 8 +- llama_stack/distribution/resolver.py | 4 +- llama_stack/distribution/routers/__init__.py | 12 +- llama_stack/distribution/routers/routers.py | 149 +++++++++++++++++- .../telemetry/meta_reference/telemetry.py | 3 + 5 files changed, 162 insertions(+), 14 deletions(-) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index e517d9c3c..08ceace4f 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -285,7 +285,7 @@ class CompletionRequest(BaseModel): @json_schema_type -class CompletionResponse(BaseModel): +class CompletionResponse(MetricResponseMixin): """Response from a completion request. :param content: The generated completion text @@ -299,7 +299,7 @@ class CompletionResponse(BaseModel): @json_schema_type -class CompletionResponseStreamChunk(BaseModel): +class CompletionResponseStreamChunk(MetricResponseMixin): """A chunk of a streamed completion response. :param delta: New content generated since last chunk. This can be one or more tokens. @@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel): @json_schema_type -class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel): +class ChatCompletionResponseStreamChunk(MetricResponseMixin): """A chunk of a streamed chat completion response. :param event: The event containing the new content @@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel): @json_schema_type -class ChatCompletionResponse(MetricResponseMixin, BaseModel): +class ChatCompletionResponse(MetricResponseMixin): """Response from a chat completion request. :param completion_message: The complete response message diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index c24df384d..624a4f2c2 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -163,7 +163,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, module="llama_stack.distribution.routers", routing_table_api=info.routing_table_api, api_dependencies=[info.routing_table_api], - deps__=[info.routing_table_api.value], + # Add telemetry as an optional dependency to all auto-routed providers + optional_api_dependencies=[Api.telemetry], + deps__=([info.routing_table_api.value, Api.telemetry.value]), ), ) } diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index a54f57fb3..d0fca8771 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -45,7 +45,7 @@ async def get_routing_table_impl( return impl -async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: +async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any: from .routers import ( DatasetIORouter, EvalRouter, @@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> "eval": EvalRouter, "tool_runtime": ToolRuntimeRouter, } + api_to_deps = { + "inference": {"telemetry": Api.telemetry}, + } if api.value not in api_to_routers: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_routers[api.value](routing_table) + api_to_dep_impl = {} + for dep_name, dep_api in api_to_deps.get(api.value, {}).items(): + if dep_api in deps: + api_to_dep_impl[dep_name] = deps[dep_api] + + impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 691df1988..1a95ad45b 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -4,7 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, AsyncGenerator, Dict, List, Optional +import time +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union + +from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack import logcat from llama_stack.apis.common.content_types import ( @@ -21,6 +25,10 @@ from llama_stack.apis.eval import ( JobStatus, ) from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + CompletionMessage, EmbeddingsResponse, EmbeddingTaskType, Inference, @@ -28,13 +36,14 @@ from llama_stack.apis.inference import ( Message, ResponseFormat, SamplingParams, + StopReason, TextTruncation, ToolChoice, ToolConfig, ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.models import ModelType +from llama_stack.apis.models import Model, ModelType from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.scoring import ( ScoreBatchResponse, @@ -43,6 +52,7 @@ from llama_stack.apis.scoring import ( ScoringFnParams, ) from llama_stack.apis.shields import Shield +from llama_stack.apis.telemetry import MetricEvent, Telemetry from llama_stack.apis.tools import ( RAGDocument, RAGQueryConfig, @@ -53,6 +63,7 @@ from llama_stack.apis.tools import ( ) from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import RoutingTable +from llama_stack.providers.utils.telemetry.tracing import get_current_span class VectorIORouter(VectorIO): @@ -121,9 +132,14 @@ class InferenceRouter(Inference): def __init__( self, routing_table: RoutingTable, + telemetry: Optional[Telemetry] = None, ) -> None: logcat.debug("core", "Initializing InferenceRouter") self.routing_table = routing_table + self.telemetry = telemetry + if self.telemetry: + self.tokenizer = Tokenizer.get_instance() + self.formatter = ChatFormat(self.tokenizer) async def initialize(self) -> None: logcat.debug("core", "InferenceRouter.initialize") @@ -147,6 +163,57 @@ class InferenceRouter(Inference): ) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) + def _construct_metrics( + self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model + ) -> List[MetricEvent]: + span = get_current_span() + metrics = [ + ("prompt_tokens", prompt_tokens), + ("completion_tokens", completion_tokens), + ("total_tokens", total_tokens), + ] + metric_events = [] + for metric_name, value in metrics: + metric_events.append( + MetricEvent( + trace_id=span.trace_id, + span_id=span.span_id, + metric=metric_name, + value=value, + timestamp=time.time(), + unit="tokens", + attributes={ + "model_id": model.model_id, + "provider_id": model.provider_id, + }, + ) + ) + return metric_events + + async def _compute_and_log_token_usage( + self, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + model: Model, + ) -> List[MetricEvent]: + metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) + if self.telemetry: + for metric in metrics: + await self.telemetry.log_event(metric) + return metrics + + async def _count_tokens( + self, + messages: List[Message] | InterleavedContent, + tool_prompt_format: Optional[ToolPromptFormat] = None, + ) -> Optional[int]: + if isinstance(messages, list): + encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format) + else: + encoded = self.formatter.encode_content(messages) + return len(encoded.tokens) if encoded and encoded.tokens else 0 + async def chat_completion( self, model_id: str, @@ -159,7 +226,7 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, - ) -> AsyncGenerator: + ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: logcat.debug( "core", f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", @@ -208,10 +275,47 @@ class InferenceRouter(Inference): tool_config=tool_config, ) provider = self.routing_table.get_provider_impl(model_id) + prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) + if stream: - return (chunk async for chunk in await provider.chat_completion(**params)) + + async def stream_generator(): + completion_text = "" + async for chunk in await provider.chat_completion(**params): + if chunk.event.event_type == ChatCompletionResponseEventType.progress: + if chunk.event.delta.type == "text": + completion_text += chunk.event.delta.text + if chunk.event.event_type == ChatCompletionResponseEventType.complete: + completion_tokens = await self._count_tokens( + [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)], + tool_config.tool_prompt_format, + ) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + yield chunk + + return stream_generator() else: - return await provider.chat_completion(**params) + response = await provider.chat_completion(**params) + completion_tokens = await self._count_tokens( + [response.completion_message], + tool_config.tool_prompt_format, + ) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics + return response async def completion( self, @@ -240,10 +344,41 @@ class InferenceRouter(Inference): stream=stream, logprobs=logprobs, ) + + prompt_tokens = await self._count_tokens(content) + if stream: - return (chunk async for chunk in await provider.completion(**params)) + + async def stream_generator(): + completion_text = "" + async for chunk in await provider.completion(**params): + if hasattr(chunk, "delta"): + completion_text += chunk.delta + if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + completion_tokens = await self._count_tokens(completion_text) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + yield chunk + + return stream_generator() else: - return await provider.completion(**params) + response = await provider.completion(**params) + completion_tokens = await self._count_tokens(response.content) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics + return response async def embeddings( self, diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index e713a057f..4cdb420b2 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None: self.config = config self.datasetio_api = deps.get(Api.datasetio) + self.meter = None resource = Resource.create( { @@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): return _GLOBAL_STORAGE["gauges"][name] def _log_metric(self, event: MetricEvent) -> None: + if self.meter is None: + return if isinstance(event.value, int): counter = self._get_or_create_counter(event.metric, event.unit) counter.add(event.value, attributes=event.attributes)