fix: telemetry fixes (inference and core telemetry) (#2733)

# What does this PR do?

I found a few issues while adding new metrics for various APIs:

currently metrics are only propagated in `chat_completion` and
`completion`

since most providers use the `openai_..` routes as the default in
`llama-stack-client inference chat-completion`, metrics are currently
not working as expected.

in order to get them working the following had to be done:

1. get the completion as usual
2. use new `openai_` versions of the metric gathering functions which
use `.usage` from the `OpenAI..` response types to gather the metrics
which are already populated.
3. define a `stream_generator` which counts the tokens and computes the
metrics (only for stream=True)
5. add metrics to response


NOTE: I could not add metrics to `openai_completion` where stream=True
because that ONLY returns an `OpenAICompletion` not an AsyncGenerator
that we can manipulate.


acquire the lock, and add event to the span as the other `_log_...`
methods do

some new output:

`llama-stack-client inference chat-completion --message hi`

<img width="2416" height="425" alt="Screenshot 2025-07-16 at 8 28 20 AM"
src="https://github.com/user-attachments/assets/ccdf1643-a184-4ddd-9641-d426c4d51326"
/>


and in the client:

<img width="763" height="319" alt="Screenshot 2025-07-16 at 8 28 32 AM"
src="https://github.com/user-attachments/assets/6bceb811-5201-47e9-9e16-8130f0d60007"
/>

these were not previously being recorded nor were they being printed to
the server due to the improper console sink handling

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-08-06 16:37:40 -04:00 committed by GitHub
parent c252dfa3ef
commit 0caef40e0d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 1595 additions and 246 deletions

View file

@ -7,6 +7,7 @@
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from datetime import UTC, datetime
from typing import Annotated, Any
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
@ -25,14 +26,21 @@ from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAICompletion,
OpenAICompletionWithInputMessages,
OpenAIEmbeddingsResponse,
@ -55,7 +63,6 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
@ -119,6 +126,7 @@ class InferenceRouter(Inference):
if span is None:
logger.warning("No span found for token usage metrics")
return []
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
@ -132,7 +140,7 @@ class InferenceRouter(Inference):
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=time.time(),
timestamp=datetime.now(UTC),
unit="tokens",
attributes={
"model_id": model.model_id,
@ -234,49 +242,26 @@ class InferenceRouter(Inference):
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
async def stream_generator():
completion_text = ""
async for chunk in await provider.chat_completion(**params):
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
response = await provider.chat_completion(**params)
completion_tokens = await self._count_tokens(
[response.completion_message],
tool_config.tool_prompt_format,
response_stream = await provider.chat_completion(**params)
return self.stream_tokens_and_compute_metrics(
response=response_stream,
prompt_tokens=prompt_tokens,
model=model,
tool_prompt_format=tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
response = await provider.chat_completion(**params)
metrics = await self.count_tokens_and_compute_metrics(
response=response,
prompt_tokens=prompt_tokens,
model=model,
tool_prompt_format=tool_config.tool_prompt_format,
)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def batch_chat_completion(
self,
@ -332,39 +317,20 @@ class InferenceRouter(Inference):
)
prompt_tokens = await self._count_tokens(content)
response = await provider.completion(**params)
if stream:
async def stream_generator():
completion_text = ""
async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
response = await provider.completion(**params)
completion_tokens = await self._count_tokens(response.content)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
return self.stream_tokens_and_compute_metrics(
response=response,
prompt_tokens=prompt_tokens,
model=model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
metrics = await self.count_tokens_and_compute_metrics(
response=response, prompt_tokens=prompt_tokens, model=model
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def batch_completion(
self,
@ -457,9 +423,29 @@ class InferenceRouter(Inference):
prompt_logprobs=prompt_logprobs,
suffix=suffix,
)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_completion(**params)
if stream:
return await provider.openai_completion(**params)
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
# response_stream = await provider.openai_completion(**params)
response = await provider.openai_completion(**params)
if self.telemetry:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model=model_obj,
)
for metric in metrics:
await self.telemetry.log_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def openai_chat_completion(
self,
@ -537,18 +523,38 @@ class InferenceRouter(Inference):
top_p=top_p,
user=user,
)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
response_stream = await provider.openai_chat_completion(**params)
if self.store:
return stream_and_store_openai_completion(response_stream, model, self.store, messages)
return response_stream
else:
response = await self._nonstream_openai_chat_completion(provider, params)
if self.store:
await self.store.store_chat_completion(response, messages)
return response
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
# We need to add metrics to each chunk and store the final completion
return self.stream_tokens_and_compute_metrics_openai_chat(
response=response_stream,
model=model_obj,
messages=messages,
)
response = await self._nonstream_openai_chat_completion(provider, params)
# Store the response with the ID that will be returned to the client
if self.store:
await self.store.store_chat_completion(response, messages)
if self.telemetry:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model=model_obj,
)
for metric in metrics:
await self.telemetry.log_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def openai_embeddings(
self,
@ -625,3 +631,244 @@ class InferenceRouter(Inference):
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses
async def stream_tokens_and_compute_metrics(
self,
response,
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
completion_text = ""
async for chunk in response:
complete = False
if hasattr(chunk, "event"): # only ChatCompletions have .event
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
complete = True
completion_tokens = await self._count_tokens(
[
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_prompt_format=tool_prompt_format,
)
else:
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
complete = True
completion_tokens = await self._count_tokens(completion_text)
# if we are done receiving tokens
if complete:
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for streaming completion metrics
if self.telemetry:
# Log metrics in the new span context
completion_metrics = self._construct_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model=model,
)
for metric in completion_metrics:
if metric.metric in [
"completion_tokens",
"total_tokens",
]: # Only log completion and total tokens
await self.telemetry.log_event(metric)
# Return metrics in response
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
else:
# Fallback if no telemetry
completion_metrics = self._construct_metrics(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
yield chunk
async def count_tokens_and_compute_metrics(
self,
response: ChatCompletionResponse | CompletionResponse,
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
):
if isinstance(response, ChatCompletionResponse):
content = [response.completion_message]
else:
content = response.content
completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for completion metrics
if self.telemetry:
# Log metrics in the new span context
completion_metrics = self._construct_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model=model,
)
for metric in completion_metrics:
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
await self.telemetry.log_event(metric)
# Return metrics in response
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
# Fallback if no telemetry
metrics = self._construct_metrics(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def stream_tokens_and_compute_metrics_openai_chat(
self,
response: AsyncIterator[OpenAIChatCompletionChunk],
model: Model,
messages: list[OpenAIMessageParam] | None = None,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
id = None
created = None
choices_data: dict[int, dict[str, Any]] = {}
try:
async for chunk in response:
# Skip None chunks
if chunk is None:
continue
# Capture ID and created timestamp from first chunk
if id is None and chunk.id:
id = chunk.id
if created is None and chunk.created:
created = chunk.created
# Accumulate choice data for final assembly
if chunk.choices:
for choice_delta in chunk.choices:
idx = choice_delta.index
if idx not in choices_data:
choices_data[idx] = {
"content_parts": [],
"tool_calls_builder": {},
"finish_reason": None,
"logprobs_content_parts": [],
}
current_choice_data = choices_data[idx]
if choice_delta.delta:
delta = choice_delta.delta
if delta.content:
current_choice_data["content_parts"].append(delta.content)
if delta.tool_calls:
for tool_call_delta in delta.tool_calls:
tc_idx = tool_call_delta.index
if tc_idx not in current_choice_data["tool_calls_builder"]:
current_choice_data["tool_calls_builder"][tc_idx] = {
"id": None,
"type": "function",
"function_name_parts": [],
"function_arguments_parts": [],
}
builder = current_choice_data["tool_calls_builder"][tc_idx]
if tool_call_delta.id:
builder["id"] = tool_call_delta.id
if tool_call_delta.type:
builder["type"] = tool_call_delta.type
if tool_call_delta.function:
if tool_call_delta.function.name:
builder["function_name_parts"].append(tool_call_delta.function.name)
if tool_call_delta.function.arguments:
builder["function_arguments_parts"].append(
tool_call_delta.function.arguments
)
if choice_delta.finish_reason:
current_choice_data["finish_reason"] = choice_delta.finish_reason
if choice_delta.logprobs and choice_delta.logprobs.content:
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
# Compute metrics on final chunk
if chunk.choices and chunk.choices[0].finish_reason:
completion_text = ""
for choice_data in choices_data.values():
completion_text += "".join(choice_data["content_parts"])
# Add metrics to the chunk
if self.telemetry and chunk.usage:
metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
model=model,
)
for metric in metrics:
await self.telemetry.log_event(metric)
yield chunk
finally:
# Store the final assembled completion
if id and self.store and messages:
assembled_choices: list[OpenAIChoice] = []
for choice_idx, choice_data in choices_data.items():
content_str = "".join(choice_data["content_parts"])
assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
if choice_data["tool_calls_builder"]:
for tc_build_data in choice_data["tool_calls_builder"].values():
if tc_build_data["id"]:
func_name = "".join(tc_build_data["function_name_parts"])
func_args = "".join(tc_build_data["function_arguments_parts"])
assembled_tool_calls.append(
OpenAIChatCompletionToolCall(
id=tc_build_data["id"],
type=tc_build_data["type"],
function=OpenAIChatCompletionToolCallFunction(
name=func_name, arguments=func_args
),
)
)
message = OpenAIAssistantMessageParam(
role="assistant",
content=content_str if content_str else None,
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
)
logprobs_content = choice_data["logprobs_content_parts"]
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
assembled_choices.append(
OpenAIChoice(
finish_reason=choice_data["finish_reason"],
index=choice_idx,
message=message,
logprobs=final_logprobs,
)
)
final_response = OpenAIChatCompletion(
id=id,
choices=assembled_choices,
created=created or int(time.time()),
model=model.identifier,
object="chat.completion",
)
await self.store.store_chat_completion(final_response, messages)

View file

@ -28,9 +28,6 @@ class ConsoleSpanProcessor(SpanProcessor):
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
def on_end(self, span: ReadableSpan) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
if span.status.status_code == StatusCode.ERROR:
@ -67,7 +64,7 @@ class ConsoleSpanProcessor(SpanProcessor):
for key, value in event.attributes.items():
if key.startswith("__") or key in ["message", "severity"]:
continue
logger.info(f"/r[dim]{key}[/dim]: {value}")
logger.info(f"[dim]{key}[/dim]: {value}")
def shutdown(self) -> None:
"""Shutdown the processor."""

View file

@ -4,10 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import threading
from typing import Any
from opentelemetry import metrics, trace
logger = logging.getLogger(__name__)
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
@ -110,7 +113,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if TelemetrySink.SQLITE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(SQLiteSpanProcessor(self.config.sqlite_db_path))
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor(print_attributes=True))
if TelemetrySink.OTEL_METRIC in self.config.sinks:
self.meter = metrics.get_meter(__name__)
@ -126,9 +129,11 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
trace.get_tracer_provider().force_flush()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
logger.debug(f"DEBUG: log_event called with event type: {type(event).__name__}")
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event, ttl_seconds)
elif isinstance(event, MetricEvent):
logger.debug("DEBUG: Routing MetricEvent to _log_metric")
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event, ttl_seconds)
@ -188,6 +193,38 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
# Always log to console if console sink is enabled (debug)
if TelemetrySink.CONSOLE in self.config.sinks:
logger.debug(f"METRIC: {event.metric}={event.value} {event.unit} {event.attributes}")
# Add metric as an event to the current span
try:
with self._lock:
# Only try to add to span if we have a valid span_id
if event.span_id:
try:
span_id = int(event.span_id, 16)
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=f"metric.{event.metric}",
attributes={
"value": event.value,
"unit": event.unit,
**(event.attributes or {}),
},
timestamp=timestamp_ns,
)
except (ValueError, KeyError):
# Invalid span_id or span not found, but we already logged to console above
pass
except Exception:
# Lock acquisition failed
logger.debug("Failed to acquire lock to add metric to span")
# Log to OpenTelemetry meter if available
if self.meter is None:
return
if isinstance(event.value, int):

View file

@ -1,129 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from datetime import UTC, datetime
from typing import Any
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAIMessageParam,
)
from llama_stack.providers.utils.inference.inference_store import InferenceStore
async def stream_and_store_openai_completion(
provider_stream: AsyncIterator[OpenAIChatCompletionChunk],
model: str,
store: InferenceStore,
input_messages: list[OpenAIMessageParam],
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""
Wraps a provider's stream, yields chunks, and stores the full completion at the end.
"""
id = None
created = None
choices_data: dict[int, dict[str, Any]] = {}
try:
async for chunk in provider_stream:
if id is None and chunk.id:
id = chunk.id
if created is None and chunk.created:
created = chunk.created
if chunk.choices:
for choice_delta in chunk.choices:
idx = choice_delta.index
if idx not in choices_data:
choices_data[idx] = {
"content_parts": [],
"tool_calls_builder": {},
"finish_reason": None,
"logprobs_content_parts": [],
}
current_choice_data = choices_data[idx]
if choice_delta.delta:
delta = choice_delta.delta
if delta.content:
current_choice_data["content_parts"].append(delta.content)
if delta.tool_calls:
for tool_call_delta in delta.tool_calls:
tc_idx = tool_call_delta.index
if tc_idx not in current_choice_data["tool_calls_builder"]:
# Initialize with correct structure for _ToolCallBuilderData
current_choice_data["tool_calls_builder"][tc_idx] = {
"id": None,
"type": "function",
"function_name_parts": [],
"function_arguments_parts": [],
}
builder = current_choice_data["tool_calls_builder"][tc_idx]
if tool_call_delta.id:
builder["id"] = tool_call_delta.id
if tool_call_delta.type:
builder["type"] = tool_call_delta.type
if tool_call_delta.function:
if tool_call_delta.function.name:
builder["function_name_parts"].append(tool_call_delta.function.name)
if tool_call_delta.function.arguments:
builder["function_arguments_parts"].append(tool_call_delta.function.arguments)
if choice_delta.finish_reason:
current_choice_data["finish_reason"] = choice_delta.finish_reason
if choice_delta.logprobs and choice_delta.logprobs.content:
# Ensure that we are extending with the correct type
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
yield chunk
finally:
if id:
assembled_choices: list[OpenAIChoice] = []
for choice_idx, choice_data in choices_data.items():
content_str = "".join(choice_data["content_parts"])
assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
if choice_data["tool_calls_builder"]:
for tc_build_data in choice_data["tool_calls_builder"].values():
if tc_build_data["id"]:
func_name = "".join(tc_build_data["function_name_parts"])
func_args = "".join(tc_build_data["function_arguments_parts"])
assembled_tool_calls.append(
OpenAIChatCompletionToolCall(
id=tc_build_data["id"],
type=tc_build_data["type"], # No or "function" needed, already set
function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args),
)
)
message = OpenAIAssistantMessageParam(
role="assistant",
content=content_str if content_str else None,
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
)
logprobs_content = choice_data["logprobs_content_parts"]
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
assembled_choices.append(
OpenAIChoice(
finish_reason=choice_data["finish_reason"],
index=choice_idx,
message=message,
logprobs=final_logprobs,
)
)
final_response = OpenAIChatCompletion(
id=id,
choices=assembled_choices,
created=created or int(datetime.now(UTC).timestamp()),
model=model,
object="chat.completion",
)
await store.store_chat_completion(final_response, input_messages)

View file

@ -81,7 +81,7 @@ BACKGROUND_LOGGER = None
class BackgroundLogger:
def __init__(self, api: Telemetry, capacity: int = 1000):
def __init__(self, api: Telemetry, capacity: int = 100000):
self.api = api
self.log_queue = queue.Queue(maxsize=capacity)
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)