feat: record token usage for inference API

Dinesh Yeduguru 2025-02-27 10:57:08 -08:00
parent 00570fde31
commit 1952ffa410
3 changed files with 162 additions and 10 deletions


@@ -163,7 +163,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
                 module="llama_stack.distribution.routers",
                 routing_table_api=info.routing_table_api,
                 api_dependencies=[info.routing_table_api],
-                deps__=[info.routing_table_api.value],
+                # Add telemetry as an optional dependency to all auto-routed providers
+                optional_api_dependencies=[Api.telemetry],
+                deps__=([info.routing_table_api.value, Api.telemetry.value]),
             ),
         )
     }
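The change above makes telemetry an optional, rather than required, dependency of every auto-routed provider. A minimal sketch of the intended behaviour, assuming the resolver injects optional dependencies only when the corresponding provider is actually configured (resolve_deps below is a hypothetical helper, not llama_stack code):

# Hypothetical sketch (not llama_stack code): how an optional dependency such as
# telemetry differs from a required one at resolution time. Required deps must be
# present; optional deps are injected only when the provider was configured.
from typing import Any, Dict, List


def resolve_deps(
    required: List[str], optional: List[str], available_impls: Dict[str, Any]
) -> Dict[str, Any]:
    deps: Dict[str, Any] = {}
    for api in required:
        if api not in available_impls:
            raise ValueError(f"missing required dependency: {api}")
        deps[api] = available_impls[api]
    for api in optional:
        if api in available_impls:  # silently skipped when not configured
            deps[api] = available_impls[api]
    return deps


# The router still resolves when telemetry is absent, just without metrics.
print(resolve_deps(["routing_table"], ["telemetry"], {"routing_table": object()}))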


@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl


-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
         "eval": EvalRouter,
         "tool_runtime": ToolRuntimeRouter,
     }
+    api_to_deps = {
+        "inference": {"telemetry": Api.telemetry},
+    }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")

-    impl = api_to_routers[api.value](routing_table)
+    api_to_dep_impl = {}
+    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
+        if dep_api in deps:
+            api_to_dep_impl[dep_name] = deps[dep_api]
+
+    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
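Read in isolation, the new injection step amounts to: look up which constructor keyword each API's router expects, and pass the resolved impl only if it is present. A self-contained sketch under that reading (the Api enum and InferenceRouter class here are simplified stand-ins, not the real llama_stack types):

# Standalone sketch of the keyword-argument injection done in get_auto_router_impl.
# Api values and the router class are stand-ins, not the real llama_stack types.
from enum import Enum
from typing import Any, Dict, Optional


class Api(Enum):
    inference = "inference"
    telemetry = "telemetry"


class InferenceRouter:
    def __init__(self, routing_table: Any, telemetry: Optional[Any] = None) -> None:
        self.routing_table = routing_table
        self.telemetry = telemetry


api_to_deps = {"inference": {"telemetry": Api.telemetry}}
deps: Dict[Api, Any] = {Api.telemetry: "telemetry-impl"}  # as handed in by the resolver

kwargs = {
    name: deps[dep_api]
    for name, dep_api in api_to_deps.get(Api.inference.value, {}).items()
    if dep_api in deps  # optional: omitted when telemetry is not configured
}
router = InferenceRouter(routing_table="routing-table-impl", **kwargs)
print(router.telemetry)  # "telemetry-impl"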


@@ -4,7 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, AsyncGenerator, Dict, List, Optional
+import time
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
+
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
@@ -21,6 +25,10 @@ from llama_stack.apis.eval import (
     JobStatus,
 )
 from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
@@ -28,13 +36,15 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     SamplingParams,
+    StopReason,
     TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    UserMessage,
 )
-from llama_stack.apis.models import ModelType
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
@@ -43,6 +53,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.apis.telemetry import MetricEvent, Telemetry
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -53,6 +64,7 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.utils.telemetry.tracing import get_current_span


 class VectorIORouter(VectorIO):
@@ -121,9 +133,14 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
+        telemetry: Optional[Telemetry] = None,
     ) -> None:
         logcat.debug("core", "Initializing InferenceRouter")
         self.routing_table = routing_table
+        self.telemetry = telemetry
+        if self.telemetry:
+            self.tokenizer = Tokenizer.get_instance()
+            self.formatter = ChatFormat(self.tokenizer)

     async def initialize(self) -> None:
         logcat.debug("core", "InferenceRouter.initialize")
@@ -147,6 +164,59 @@ class InferenceRouter(Inference):
         )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

+    def _construct_metrics(
+        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
+    ) -> List[MetricEvent]:
+        span = get_current_span()
+        metrics = [
+            ("prompt_tokens", prompt_tokens),
+            ("completion_tokens", completion_tokens),
+            ("total_tokens", total_tokens),
+        ]
+        metric_events = []
+        for metric_name, value in metrics:
+            metric_events.append(
+                MetricEvent(
+                    trace_id=span.trace_id,
+                    span_id=span.span_id,
+                    metric=metric_name,
+                    value=value,
+                    timestamp=time.time(),
+                    unit="tokens",
+                    attributes={
+                        "model_id": model.model_id,
+                        "provider_id": model.provider_id,
+                    },
+                )
+            )
+        return metric_events
+
+    async def _add_token_metrics(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        total_tokens: int,
+        model: Model,
+        target: Any,
+    ) -> None:
+        metrics = getattr(target, "metrics", None)
+        if metrics is None:
+            target.metrics = []
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        target.metrics.extend(metrics)
+        if self.telemetry:
+            for metric in metrics:
+                await self.telemetry.log_event(metric)
+
+    async def _count_tokens(
+        self,
+        messages: List[Message],
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+    ) -> Optional[int]:
+        encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        return len(encoded.tokens) if encoded and encoded.tokens else 0
+
     async def chat_completion(
         self,
         model_id: str,
@@ -159,7 +229,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logcat.debug(
             "core",
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
@@ -208,10 +278,47 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+
         if stream:
-            return (chunk async for chunk in await provider.chat_completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.chat_completion(**params):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                        if chunk.event.delta.type == "text":
+                            completion_text += chunk.event.delta.text
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                        completion_tokens = await self._count_tokens(
+                            [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
+                            tool_config.tool_prompt_format,
+                        )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
+                        )
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.chat_completion(**params)
+            response = await provider.chat_completion(**params)
+            completion_tokens = await self._count_tokens(
+                [response.completion_message],
+                tool_config.tool_prompt_format,
+            )
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
+            )
+            return response

     async def completion(
         self,
@@ -240,10 +347,45 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
+        prompt_tokens = await self._count_tokens([UserMessage(role="user", content=str(content))])
+
         if stream:
-            return (chunk async for chunk in await provider.completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.completion(**params):
+                    if hasattr(chunk, "delta"):
+                        completion_text += chunk.delta
+                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                        completion_tokens = await self._count_tokens(
+                            [CompletionMessage(content=completion_text, stop_reason=chunk.stop_reason)]
+                        )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
+                        )
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.completion(**params)
+            response = await provider.completion(**params)
+            completion_tokens = await self._count_tokens(
+                [CompletionMessage(content=str(response.content), stop_reason=StopReason.end_of_turn)]
+            )
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
+            )
+            return response

     async def embeddings(
         self,
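Across both chat_completion and completion, the accounting added here follows one pattern: count prompt tokens up front, count completion tokens once the response (or the final stream chunk) is known, attach three metrics to the returned object, and emit them to telemetry when it is configured. A condensed, self-contained sketch of that pattern (the TokenMetric dataclass and build_token_metrics helper are illustrative stand-ins for MetricEvent and the ChatFormat-based counting in the diff):

# Condensed sketch of the token-accounting pattern added to InferenceRouter.
# TokenMetric stands in for MetricEvent; the counts are illustrative numbers.
import time
from dataclasses import dataclass, field
from typing import List


@dataclass
class TokenMetric:
    metric: str
    value: int
    unit: str = "tokens"
    timestamp: float = field(default_factory=time.time)


def build_token_metrics(prompt_tokens: int, completion_tokens: int) -> List[TokenMetric]:
    total = prompt_tokens + completion_tokens
    return [
        TokenMetric("prompt_tokens", prompt_tokens),
        TokenMetric("completion_tokens", completion_tokens),
        TokenMetric("total_tokens", total),
    ]


# Non-streaming path: count once, attach to the response, emit to telemetry.
# Streaming path: the same three metrics are attached only to the final chunk,
# after the accumulated completion text has been re-tokenized.
for m in build_token_metrics(prompt_tokens=128, completion_tokens=42):
    print(m.metric, m.value)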