fix: Revert "feat: record token usage for inference API (#1300)" (#1476)

This reverts commit b8535417e0.

Test plan:
LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run
~/.llama/distributions/together/together-run.yaml
python -m examples.agents.e2e_loop_with_client_tools localhost 8321
This commit is contained in:
Dinesh Yeduguru 2025-03-07 10:16:47 -08:00 committed by GitHub
parent df4fbae35c
commit 60e7f3d705
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 14 additions and 161 deletions

View file

@ -4,8 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
@ -22,10 +21,6 @@ from llama_stack.apis.eval import (
JobStatus,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -33,14 +28,13 @@ from llama_stack.apis.inference import (
Message,
ResponseFormat,
SamplingParams,
StopReason,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
ScoreBatchResponse,
@ -49,7 +43,6 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, Telemetry
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
@ -59,10 +52,7 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.telemetry.tracing import get_current_span
class VectorIORouter(VectorIO):
@ -131,14 +121,9 @@ class InferenceRouter(Inference):
def __init__(
self,
routing_table: RoutingTable,
telemetry: Optional[Telemetry] = None,
) -> None:
logcat.debug("core", "Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
if self.telemetry:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None:
logcat.debug("core", "InferenceRouter.initialize")
@ -162,57 +147,6 @@ class InferenceRouter(Inference):
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
) -> List[MetricEvent]:
span = get_current_span()
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
metric_events = []
for metric_name, value in metrics:
metric_events.append(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=time.time(),
unit="tokens",
attributes={
"model_id": model.model_id,
"provider_id": model.provider_id,
},
)
)
return metric_events
async def _compute_and_log_token_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
await self.telemetry.log_event(metric)
return metrics
async def _count_tokens(
self,
messages: List[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0
async def chat_completion(
self,
model_id: str,
@ -225,7 +159,7 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
) -> AsyncGenerator:
logcat.debug(
"core",
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
@ -276,47 +210,10 @@ class InferenceRouter(Inference):
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
async def stream_generator():
completion_text = ""
async for chunk in await provider.chat_completion(**params):
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
return (chunk async for chunk in await provider.chat_completion(**params))
else:
response = await provider.chat_completion(**params)
completion_tokens = await self._count_tokens(
[response.completion_message],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
return await provider.chat_completion(**params)
async def completion(
self,
@ -347,41 +244,10 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
prompt_tokens = await self._count_tokens(content)
if stream:
async def stream_generator():
completion_text = ""
async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
return (chunk async for chunk in await provider.completion(**params))
else:
response = await provider.completion(**params)
completion_tokens = await self._count_tokens(response.content)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
return await provider.completion(**params)
async def embeddings(
self,