mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)

feat: record token usage for inference API

This commit is contained in:
parent 00570fde31
commit 1952ffa410

3 changed files with 162 additions and 10 deletions
@@ -163,7 +163,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
             module="llama_stack.distribution.routers",
             routing_table_api=info.routing_table_api,
             api_dependencies=[info.routing_table_api],
-            deps__=[info.routing_table_api.value],
+            # Add telemetry as an optional dependency to all auto-routed providers
+            optional_api_dependencies=[Api.telemetry],
+            deps__=([info.routing_table_api.value, Api.telemetry.value]),
         ),
     )
 }
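For context, a runnable sketch of what "optional" buys you at resolution time: required api_dependencies must resolve or startup fails, while optional ones are simply skipped when no provider is configured. The spec class and trimmed Api enum below are simplified stand-ins, not the real llama_stack datatypes.

from dataclasses import dataclass, field
from enum import Enum
from typing import List


class Api(Enum):
    inference = "inference"
    telemetry = "telemetry"


@dataclass
class ProviderSpecSketch:
    module: str
    api_dependencies: List[Api]  # must resolve, or startup fails
    optional_api_dependencies: List[Api] = field(default_factory=list)  # injected only if available


spec = ProviderSpecSketch(
    module="llama_stack.distribution.routers",
    api_dependencies=[Api.inference],
    optional_api_dependencies=[Api.telemetry],
)

available = {Api.inference}  # telemetry provider not configured
missing_required = [a for a in spec.api_dependencies if a not in available]
skipped_optional = [a for a in spec.optional_api_dependencies if a not in available]
assert not missing_required and skipped_optional == [Api.telemetry]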
@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl
 
 
-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
         "eval": EvalRouter,
         "tool_runtime": ToolRuntimeRouter,
     }
+    api_to_deps = {
+        "inference": {"telemetry": Api.telemetry},
+    }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")
 
-    impl = api_to_routers[api.value](routing_table)
+    api_to_dep_impl = {}
+    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
+        if dep_api in deps:
+            api_to_dep_impl[dep_name] = deps[dep_api]
+
+    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
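The hunk above threads resolved provider impls into the router constructor by name. A self-contained sketch of that injection pattern, using hypothetical stand-ins (FakeTelemetry, InferenceRouterSketch, a trimmed Api enum) rather than the real llama_stack types:

import asyncio
from enum import Enum
from typing import Any, Dict


class Api(Enum):
    inference = "inference"
    telemetry = "telemetry"


class FakeTelemetry:
    async def log_event(self, event: Any) -> None:
        print(f"telemetry event: {event}")


class InferenceRouterSketch:
    # Mirrors the real router's signature: telemetry arrives as a keyword arg.
    def __init__(self, routing_table: Any, telemetry: Any = None) -> None:
        self.routing_table = routing_table
        self.telemetry = telemetry


async def get_auto_router_impl(api: Api, routing_table: Any, deps: Dict[Api, Any]) -> Any:
    api_to_routers = {"inference": InferenceRouterSketch}
    # Per-API map of constructor kwarg name -> the Api whose impl gets injected.
    api_to_deps = {"inference": {"telemetry": Api.telemetry}}

    api_to_dep_impl = {}
    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
        if dep_api in deps:  # optional: inject only when a provider was resolved
            api_to_dep_impl[dep_name] = deps[dep_api]

    return api_to_routers[api.value](routing_table, **api_to_dep_impl)


router = asyncio.run(get_auto_router_impl(Api.inference, None, {Api.telemetry: FakeTelemetry()}))
assert router.telemetry is not None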
@@ -4,7 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict, List, Optional
+import time
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
+
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
@@ -21,6 +25,10 @@ from llama_stack.apis.eval import (
     JobStatus,
 )
 from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
@@ -28,13 +36,15 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     SamplingParams,
+    StopReason,
     TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    UserMessage,
 )
-from llama_stack.apis.models import ModelType
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
@@ -43,6 +53,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.apis.telemetry import MetricEvent, Telemetry
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -53,6 +64,7 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
 
 class VectorIORouter(VectorIO):
@@ -121,9 +133,14 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
+        telemetry: Optional[Telemetry] = None,
     ) -> None:
         logcat.debug("core", "Initializing InferenceRouter")
         self.routing_table = routing_table
+        self.telemetry = telemetry
+        if self.telemetry:
+            self.tokenizer = Tokenizer.get_instance()
+            self.formatter = ChatFormat(self.tokenizer)
 
     async def initialize(self) -> None:
         logcat.debug("core", "InferenceRouter.initialize")
@@ -147,6 +164,59 @@ class InferenceRouter(Inference):
         )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
+    def _construct_metrics(
+        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
+    ) -> List[MetricEvent]:
+        span = get_current_span()
+        metrics = [
+            ("prompt_tokens", prompt_tokens),
+            ("completion_tokens", completion_tokens),
+            ("total_tokens", total_tokens),
+        ]
+        metric_events = []
+        for metric_name, value in metrics:
+            metric_events.append(
+                MetricEvent(
+                    trace_id=span.trace_id,
+                    span_id=span.span_id,
+                    metric=metric_name,
+                    value=value,
+                    timestamp=time.time(),
+                    unit="tokens",
+                    attributes={
+                        "model_id": model.model_id,
+                        "provider_id": model.provider_id,
+                    },
+                )
+            )
+        return metric_events
+
+    async def _add_token_metrics(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        total_tokens: int,
+        model: Model,
+        target: Any,
+    ) -> None:
+        metrics = getattr(target, "metrics", None)
+        if metrics is None:
+            target.metrics = []
+
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        target.metrics.extend(metrics)
+        if self.telemetry:
+            for metric in metrics:
+                await self.telemetry.log_event(metric)
+
+    async def _count_tokens(
+        self,
+        messages: List[Message],
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+    ) -> Optional[int]:
+        encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        return len(encoded.tokens) if encoded and encoded.tokens else 0
+
     async def chat_completion(
         self,
         model_id: str,
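To make the helper's shape concrete, here is a runnable sketch of the _construct_metrics bookkeeping with a stand-in MetricEventSketch dataclass; in the router, trace_id and span_id come from get_current_span() inside an active trace and the events are logged through the Telemetry API.

import time
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class MetricEventSketch:  # stand-in for llama_stack.apis.telemetry.MetricEvent
    trace_id: str
    span_id: str
    metric: str
    value: int
    timestamp: float
    unit: str
    attributes: Dict[str, str] = field(default_factory=dict)


def construct_token_metrics(prompt_tokens: int, completion_tokens: int, trace_id: str, span_id: str) -> List[MetricEventSketch]:
    # One event per counter, all stamped with the same trace/span identifiers.
    total_tokens = prompt_tokens + completion_tokens
    return [
        MetricEventSketch(trace_id, span_id, name, value, time.time(), "tokens")
        for name, value in [
            ("prompt_tokens", prompt_tokens),
            ("completion_tokens", completion_tokens),
            ("total_tokens", total_tokens),
        ]
    ]


events = construct_token_metrics(12, 34, trace_id="t-1", span_id="s-1")
assert [e.value for e in events] == [12, 34, 46]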
@@ -159,7 +229,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logcat.debug(
             "core",
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
@@ -208,10 +278,47 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+
         if stream:
-            return (chunk async for chunk in await provider.chat_completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.chat_completion(**params):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                        if chunk.event.delta.type == "text":
+                            completion_text += chunk.event.delta.text
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                        completion_tokens = await self._count_tokens(
+                            [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
+                            tool_config.tool_prompt_format,
+                        )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
+                        )
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.chat_completion(**params)
+            response = await provider.chat_completion(**params)
+            completion_tokens = await self._count_tokens(
+                [response.completion_message],
+                tool_config.tool_prompt_format,
+            )
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
+            )
+            return response
 
     async def completion(
         self,
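The streaming branch wraps the provider's async generator so chunks pass through to the caller unchanged while the router accumulates the delta text and attaches metrics once the final chunk arrives. A minimal sketch of that wrapper pattern, where ChunkSketch is a hypothetical stand-in for ChatCompletionResponseStreamChunk:

import asyncio
from dataclasses import dataclass, field
from typing import AsyncIterator, List


@dataclass
class ChunkSketch:
    delta: str
    is_final: bool = False
    metrics: List[str] = field(default_factory=list)


async def fake_provider_stream() -> AsyncIterator[ChunkSketch]:
    for piece in ("Hel", "lo"):
        yield ChunkSketch(delta=piece)
    yield ChunkSketch(delta="", is_final=True)


async def stream_with_metrics(upstream: AsyncIterator[ChunkSketch]) -> AsyncIterator[ChunkSketch]:
    completion_text = ""
    async for chunk in upstream:
        completion_text += chunk.delta
        if chunk.is_final:
            # The router computes real token counts here via _count_tokens();
            # this stand-in just records the accumulated text length.
            chunk.metrics.append(f"completion_chars={len(completion_text)}")
        yield chunk  # chunks pass through to the caller unchanged


async def main() -> None:
    async for chunk in stream_with_metrics(fake_provider_stream()):
        if chunk.is_final:
            print(chunk.metrics)  # ['completion_chars=5']


asyncio.run(main())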
@@ -240,10 +347,45 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
+
+        prompt_tokens = await self._count_tokens([UserMessage(role="user", content=str(content))])
+
         if stream:
-            return (chunk async for chunk in await provider.completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.completion(**params):
+                    if hasattr(chunk, "delta"):
+                        completion_text += chunk.delta
+                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                        completion_tokens = await self._count_tokens(
+                            [CompletionMessage(content=completion_text, stop_reason=chunk.stop_reason)]
+                        )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        await self._add_token_metrics(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                            chunk,
+                        )
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.completion(**params)
+            response = await provider.completion(**params)
+            completion_tokens = await self._count_tokens(
+                [CompletionMessage(content=str(response.content), stop_reason=StopReason.end_of_turn)]
+            )
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            await self._add_token_metrics(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+                response,
+            )
+            return response
 
     async def embeddings(
         self,
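As in chat_completion, the completion path counts prompt tokens up front and completion tokens once the response (or final chunk) is available, then sums them. A toy sketch of that accounting, with a whitespace tokenizer standing in for the llama_models Tokenizer/ChatFormat pair:

from typing import List, Optional


def count_tokens_sketch(texts: List[str]) -> Optional[int]:
    # The real _count_tokens encodes messages with ChatFormat and counts model
    # tokens; whitespace splitting is only an illustration of the bookkeeping.
    return sum(len(t.split()) for t in texts)


prompt_tokens = count_tokens_sketch(["What is the capital of France?"])
completion_tokens = count_tokens_sketch(["Paris is the capital of France."])
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
print(prompt_tokens, completion_tokens, total_tokens)  # 6 6 12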