Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)
fix: actually propagate inference metrics
Currently, metrics are only propagated in `chat_completion` and `completion`. Since most providers use the `openai_`-prefixed routes as the default for `llama-stack-client inference chat-completion`, metrics are currently not working as expected. In order to get them working, the following had to be done:

1. Get the completion as usual.
2. Use the new `openai_` versions of the metric-gathering functions, which read `.usage` from the OpenAI response types, where the token counts are already populated.
3. Define a `stream_generator` that counts the tokens and computes the metrics.
4. Use a NEW span for `log_metrics`, because the span of the request ends before this processing is complete, leading to no logging unless a custom span is used.
5. Add the metrics to the response.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent c252dfa3ef
commit d52722b0d1

3 changed files with 335 additions and 217 deletions
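
The sketch below distills the pattern the commit message describes: wrap the provider stream, read the token counts that are already populated on the final chunk's `usage`, log them, and attach them to the chunk returned to the client. This is a minimal, self-contained illustration; the `Chunk` and `Usage` types and the `log_metric` helper are hypothetical stand-ins, not llama-stack APIs.

import asyncio
from collections.abc import AsyncIterator
from dataclasses import dataclass, field


@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclass
class Chunk:
    delta: str
    usage: Usage | None = None  # providers populate this on the final chunk
    metrics: list[tuple[str, int]] = field(default_factory=list)


async def log_metric(name: str, value: int) -> None:
    # Stand-in for logging a telemetry event in a fresh span; the request
    # span has already closed by the time the final chunk arrives.
    print(f"metric {name}={value}")


async def stream_with_metrics(stream: AsyncIterator[Chunk]) -> AsyncIterator[Chunk]:
    async for chunk in stream:
        if chunk.usage is not None:  # final chunk: usage is already populated
            metrics = [
                ("prompt_tokens", chunk.usage.prompt_tokens),
                ("completion_tokens", chunk.usage.completion_tokens),
                ("total_tokens", chunk.usage.total_tokens),
            ]
            for name, value in metrics:
                await log_metric(name, value)
            chunk.metrics += metrics  # metrics also show up in the client response
        yield chunk


async def main() -> None:
    async def fake_provider() -> AsyncIterator[Chunk]:
        yield Chunk(delta="Hello")
        yield Chunk(delta="!", usage=Usage(5, 2, 7))

    async for chunk in stream_with_metrics(fake_provider()):
        print(chunk.delta, chunk.metrics)


asyncio.run(main())

Attaching the metrics to the chunk itself is what makes them visible client-side even though the server-side request span has already ended.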
File 1 (inference router, class InferenceRouter):

@@ -7,6 +7,7 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
+from datetime import UTC, datetime
 from typing import Annotated, Any
 
 from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
@@ -25,14 +26,21 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
     ListOpenAIChatCompletionResponse,
     LogProbConfig,
     Message,
+    OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
+    OpenAIChatCompletionToolCall,
+    OpenAIChatCompletionToolCallFunction,
+    OpenAIChoice,
+    OpenAIChoiceLogprobs,
     OpenAICompletion,
     OpenAICompletionWithInputMessages,
     OpenAIEmbeddingsResponse,
@@ -55,7 +63,6 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
 from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
 logger = get_logger(name=__name__, category="core")
@@ -119,6 +126,7 @@ class InferenceRouter(Inference):
         if span is None:
             logger.warning("No span found for token usage metrics")
             return []
+
         metrics = [
             ("prompt_tokens", prompt_tokens),
             ("completion_tokens", completion_tokens),
@@ -132,7 +140,7 @@
                 span_id=span.span_id,
                 metric=metric_name,
                 value=value,
-                timestamp=time.time(),
+                timestamp=datetime.now(UTC),
                 unit="tokens",
                 attributes={
                     "model_id": model.model_id,
@@ -234,49 +242,26 @@
         prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
 
         if stream:
-            async def stream_generator():
-                completion_text = ""
-                async for chunk in await provider.chat_completion(**params):
-                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
-                        if chunk.event.delta.type == "text":
-                            completion_text += chunk.event.delta.text
-                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
-                        completion_tokens = await self._count_tokens(
-                            [
-                                CompletionMessage(
-                                    content=completion_text,
-                                    stop_reason=StopReason.end_of_turn,
-                                )
-                            ],
-                            tool_config.tool_prompt_format,
-                        )
-                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        metrics = await self._compute_and_log_token_usage(
-                            prompt_tokens or 0,
-                            completion_tokens or 0,
-                            total_tokens,
-                            model,
-                        )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
-                    yield chunk
-
-            return stream_generator()
-        else:
-            response = await provider.chat_completion(**params)
-            completion_tokens = await self._count_tokens(
-                [response.completion_message],
-                tool_config.tool_prompt_format,
-            )
-            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            metrics = await self._compute_and_log_token_usage(
-                prompt_tokens or 0,
-                completion_tokens or 0,
-                total_tokens,
-                model,
-            )
-            response.metrics = metrics if response.metrics is None else response.metrics + metrics
-            return response
+            response_stream = await provider.chat_completion(**params)
+            return self.stream_tokens_and_compute_metrics(
+                response=response_stream,
+                prompt_tokens=prompt_tokens,
+                model=model,
+                tool_prompt_format=tool_config.tool_prompt_format,
+            )
+
+        response = await provider.chat_completion(**params)
+        metrics = await self.count_tokens_and_compute_metrics(
+            response=response,
+            prompt_tokens=prompt_tokens,
+            model=model,
+            tool_prompt_format=tool_config.tool_prompt_format,
+        )
+        # these metrics will show up in the client response.
+        response.metrics = (
+            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+        )
+        return response
 
     async def batch_chat_completion(
         self,
@@ -332,39 +317,20 @@
         )
 
         prompt_tokens = await self._count_tokens(content)
+        response = await provider.completion(**params)
         if stream:
-            async def stream_generator():
-                completion_text = ""
-                async for chunk in await provider.completion(**params):
-                    if hasattr(chunk, "delta"):
-                        completion_text += chunk.delta
-                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
-                        completion_tokens = await self._count_tokens(completion_text)
-                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        metrics = await self._compute_and_log_token_usage(
-                            prompt_tokens or 0,
-                            completion_tokens or 0,
-                            total_tokens,
-                            model,
-                        )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
-                    yield chunk
-
-            return stream_generator()
-        else:
-            response = await provider.completion(**params)
-            completion_tokens = await self._count_tokens(response.content)
-            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            metrics = await self._compute_and_log_token_usage(
-                prompt_tokens or 0,
-                completion_tokens or 0,
-                total_tokens,
-                model,
-            )
-            response.metrics = metrics if response.metrics is None else response.metrics + metrics
-            return response
+            return self.stream_tokens_and_compute_metrics(
+                response=response,
+                prompt_tokens=prompt_tokens,
+                model=model,
+            )
+        metrics = await self.count_tokens_and_compute_metrics(
+            response=response, prompt_tokens=prompt_tokens, model=model
+        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
+
+        return response
 
     async def batch_completion(
         self,
@@ -457,9 +423,29 @@
             prompt_logprobs=prompt_logprobs,
             suffix=suffix,
         )
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        return await provider.openai_completion(**params)
+        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        if stream:
+            return await provider.openai_completion(**params)
+        # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
+        # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
+        # response_stream = await provider.openai_completion(**params)
+
+        response = await provider.openai_completion(**params)
+        if self.telemetry:
+            metrics = self._construct_metrics(
+                prompt_tokens=response.usage.prompt_tokens,
+                completion_tokens=response.usage.completion_tokens,
+                total_tokens=response.usage.total_tokens,
+                model=model_obj,
+            )
+            for metric in metrics:
+                await self.telemetry.log_event(metric)
+
+            # these metrics will show up in the client response.
+            response.metrics = (
+                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+            )
+        return response
 
     async def openai_chat_completion(
         self,
@@ -537,18 +523,38 @@
             top_p=top_p,
             user=user,
         )
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = self.routing_table.get_provider_impl(model_obj.identifier)
         if stream:
             response_stream = await provider.openai_chat_completion(**params)
-            if self.store:
-                return stream_and_store_openai_completion(response_stream, model, self.store, messages)
-            return response_stream
-        else:
-            response = await self._nonstream_openai_chat_completion(provider, params)
-            if self.store:
-                await self.store.store_chat_completion(response, messages)
-            return response
+            # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
+            # We need to add metrics to each chunk and store the final completion
+            return self.stream_tokens_and_compute_metrics_openai_chat(
+                response=response_stream,
+                model=model_obj,
+                messages=messages,
+            )
+
+        response = await self._nonstream_openai_chat_completion(provider, params)
+
+        # Store the response with the ID that will be returned to the client
+        if self.store:
+            await self.store.store_chat_completion(response, messages)
+
+        if self.telemetry:
+            metrics = self._construct_metrics(
+                prompt_tokens=response.usage.prompt_tokens,
+                completion_tokens=response.usage.completion_tokens,
+                total_tokens=response.usage.total_tokens,
+                model=model_obj,
+            )
+            for metric in metrics:
+                await self.telemetry.log_event(metric)
+            # these metrics will show up in the client response.
+            response.metrics = (
+                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+            )
+        return response
 
     async def openai_embeddings(
         self,
@@ -625,3 +631,244 @@
                     status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
                 )
         return health_statuses
+
+    async def stream_tokens_and_compute_metrics(
+        self,
+        response,
+        prompt_tokens,
+        model,
+        tool_prompt_format: ToolPromptFormat | None = None,
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
+        completion_text = ""
+        async for chunk in response:
+            complete = False
+            if hasattr(chunk, "event"):  # only ChatCompletions have .event
+                if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                    if chunk.event.delta.type == "text":
+                        completion_text += chunk.event.delta.text
+                if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                    complete = True
+                    completion_tokens = await self._count_tokens(
+                        [
+                            CompletionMessage(
+                                content=completion_text,
+                                stop_reason=StopReason.end_of_turn,
+                            )
+                        ],
+                        tool_prompt_format=tool_prompt_format,
+                    )
+            else:
+                if hasattr(chunk, "delta"):
+                    completion_text += chunk.delta
+                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                    complete = True
+                    completion_tokens = await self._count_tokens(completion_text)
+            # if we are done receiving tokens
+            if complete:
+                total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+
+                # Create a separate span for streaming completion metrics
+                if self.telemetry:
+                    # Log metrics in the new span context
+                    completion_metrics = self._construct_metrics(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=total_tokens,
+                        model=model,
+                    )
+                    for metric in completion_metrics:
+                        if metric.metric in [
+                            "completion_tokens",
+                            "total_tokens",
+                        ]:  # Only log completion and total tokens
+                            await self.telemetry.log_event(metric)
+
+                    # Return metrics in response
+                    async_metrics = [
+                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
+                    ]
+                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
+                else:
+                    # Fallback if no telemetry
+                    completion_metrics = self._construct_metrics(
+                        prompt_tokens or 0,
+                        completion_tokens or 0,
+                        total_tokens,
+                        model,
+                    )
+                    async_metrics = [
+                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
+                    ]
+                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
+            yield chunk
+
+    async def count_tokens_and_compute_metrics(
+        self,
+        response: ChatCompletionResponse | CompletionResponse,
+        prompt_tokens,
+        model,
+        tool_prompt_format: ToolPromptFormat | None = None,
+    ):
+        if isinstance(response, ChatCompletionResponse):
+            content = [response.completion_message]
+        else:
+            content = response.content
+        completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
+        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+
+        # Create a separate span for completion metrics
+        if self.telemetry:
+            # Log metrics in the new span context
+            completion_metrics = self._construct_metrics(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+                model=model,
+            )
+            for metric in completion_metrics:
+                if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
+                    await self.telemetry.log_event(metric)
+
+            # Return metrics in response
+            return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
+
+        # Fallback if no telemetry
+        metrics = self._construct_metrics(
+            prompt_tokens or 0,
+            completion_tokens or 0,
+            total_tokens,
+            model,
+        )
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+
+    async def stream_tokens_and_compute_metrics_openai_chat(
+        self,
+        response: AsyncIterator[OpenAIChatCompletionChunk],
+        model: Model,
+        messages: list[OpenAIMessageParam] | None = None,
+    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
+        """Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
+        id = None
+        created = None
+        choices_data: dict[int, dict[str, Any]] = {}
+
+        try:
+            async for chunk in response:
+                # Skip None chunks
+                if chunk is None:
+                    continue
+
+                # Capture ID and created timestamp from first chunk
+                if id is None and chunk.id:
+                    id = chunk.id
+                if created is None and chunk.created:
+                    created = chunk.created
+
+                # Accumulate choice data for final assembly
+                if chunk.choices:
+                    for choice_delta in chunk.choices:
+                        idx = choice_delta.index
+                        if idx not in choices_data:
+                            choices_data[idx] = {
+                                "content_parts": [],
+                                "tool_calls_builder": {},
+                                "finish_reason": None,
+                                "logprobs_content_parts": [],
+                            }
+                        current_choice_data = choices_data[idx]
+
+                        if choice_delta.delta:
+                            delta = choice_delta.delta
+                            if delta.content:
+                                current_choice_data["content_parts"].append(delta.content)
+                            if delta.tool_calls:
+                                for tool_call_delta in delta.tool_calls:
+                                    tc_idx = tool_call_delta.index
+                                    if tc_idx not in current_choice_data["tool_calls_builder"]:
+                                        current_choice_data["tool_calls_builder"][tc_idx] = {
+                                            "id": None,
+                                            "type": "function",
+                                            "function_name_parts": [],
+                                            "function_arguments_parts": [],
+                                        }
+                                    builder = current_choice_data["tool_calls_builder"][tc_idx]
+                                    if tool_call_delta.id:
+                                        builder["id"] = tool_call_delta.id
+                                    if tool_call_delta.type:
+                                        builder["type"] = tool_call_delta.type
+                                    if tool_call_delta.function:
+                                        if tool_call_delta.function.name:
+                                            builder["function_name_parts"].append(tool_call_delta.function.name)
+                                        if tool_call_delta.function.arguments:
+                                            builder["function_arguments_parts"].append(
+                                                tool_call_delta.function.arguments
+                                            )
+                        if choice_delta.finish_reason:
+                            current_choice_data["finish_reason"] = choice_delta.finish_reason
+                        if choice_delta.logprobs and choice_delta.logprobs.content:
+                            current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
+
+                # Compute metrics on final chunk
+                if chunk.choices and chunk.choices[0].finish_reason:
+                    completion_text = ""
+                    for choice_data in choices_data.values():
+                        completion_text += "".join(choice_data["content_parts"])
+
+                    # Add metrics to the chunk
+                    if self.telemetry and chunk.usage:
+                        metrics = self._construct_metrics(
+                            prompt_tokens=chunk.usage.prompt_tokens,
+                            completion_tokens=chunk.usage.completion_tokens,
+                            total_tokens=chunk.usage.total_tokens,
+                            model=model,
+                        )
+                        for metric in metrics:
+                            await self.telemetry.log_event(metric)
+
+                yield chunk
+        finally:
+            # Store the final assembled completion
+            if id and self.store and messages:
+                assembled_choices: list[OpenAIChoice] = []
+                for choice_idx, choice_data in choices_data.items():
+                    content_str = "".join(choice_data["content_parts"])
+                    assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
+                    if choice_data["tool_calls_builder"]:
+                        for tc_build_data in choice_data["tool_calls_builder"].values():
+                            if tc_build_data["id"]:
+                                func_name = "".join(tc_build_data["function_name_parts"])
+                                func_args = "".join(tc_build_data["function_arguments_parts"])
+                                assembled_tool_calls.append(
+                                    OpenAIChatCompletionToolCall(
+                                        id=tc_build_data["id"],
+                                        type=tc_build_data["type"],
+                                        function=OpenAIChatCompletionToolCallFunction(
+                                            name=func_name, arguments=func_args
+                                        ),
+                                    )
+                                )
+                    message = OpenAIAssistantMessageParam(
+                        role="assistant",
+                        content=content_str if content_str else None,
+                        tool_calls=assembled_tool_calls if assembled_tool_calls else None,
+                    )
+                    logprobs_content = choice_data["logprobs_content_parts"]
+                    final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
+
+                    assembled_choices.append(
+                        OpenAIChoice(
+                            finish_reason=choice_data["finish_reason"],
+                            index=choice_idx,
+                            message=message,
+                            logprobs=final_logprobs,
+                        )
+                    )
+
+                final_response = OpenAIChatCompletion(
+                    id=id,
+                    choices=assembled_choices,
+                    created=created or int(time.time()),
+                    model=model.identifier,
+                    object="chat.completion",
+                )
+                await self.store.store_chat_completion(final_response, messages)
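
The last method above rebuilds a complete OpenAIChatCompletion from streamed deltas so it can be stored after the stream closes. Reduced to its core, the accumulation trick is per-choice part lists keyed by choice index, joined once at the end. A minimal sketch with plain-dict stand-ins for the chunk types (not the real OpenAIChatCompletionChunk objects):

from collections import defaultdict

chunks = [
    {"choices": [{"index": 0, "content": "Hel"}, {"index": 1, "content": "Hi"}]},
    {"choices": [{"index": 0, "content": "lo"}]},
]

# Accumulate text fragments per choice index while "streaming"...
content_parts: dict[int, list[str]] = defaultdict(list)
for chunk in chunks:
    for choice in chunk["choices"]:
        content_parts[choice["index"]].append(choice["content"])

# ...then join each choice's fragments once the stream is exhausted.
assembled = {idx: "".join(parts) for idx, parts in content_parts.items()}
print(assembled)  # {0: 'Hello', 1: 'Hi'}

Appending to a list and joining once avoids quadratic string concatenation, which is why the diff collects "content_parts" rather than growing a string per choice.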
File 2 (llama_stack/providers/utils/inference/stream_utils.py, deleted):

@@ -1,129 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from collections.abc import AsyncIterator
-from datetime import UTC, datetime
-from typing import Any
-
-from llama_stack.apis.inference import (
-    OpenAIAssistantMessageParam,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChatCompletionToolCall,
-    OpenAIChatCompletionToolCallFunction,
-    OpenAIChoice,
-    OpenAIChoiceLogprobs,
-    OpenAIMessageParam,
-)
-from llama_stack.providers.utils.inference.inference_store import InferenceStore
-
-
-async def stream_and_store_openai_completion(
-    provider_stream: AsyncIterator[OpenAIChatCompletionChunk],
-    model: str,
-    store: InferenceStore,
-    input_messages: list[OpenAIMessageParam],
-) -> AsyncIterator[OpenAIChatCompletionChunk]:
-    """
-    Wraps a provider's stream, yields chunks, and stores the full completion at the end.
-    """
-    id = None
-    created = None
-    choices_data: dict[int, dict[str, Any]] = {}
-
-    try:
-        async for chunk in provider_stream:
-            if id is None and chunk.id:
-                id = chunk.id
-            if created is None and chunk.created:
-                created = chunk.created
-
-            if chunk.choices:
-                for choice_delta in chunk.choices:
-                    idx = choice_delta.index
-                    if idx not in choices_data:
-                        choices_data[idx] = {
-                            "content_parts": [],
-                            "tool_calls_builder": {},
-                            "finish_reason": None,
-                            "logprobs_content_parts": [],
-                        }
-                    current_choice_data = choices_data[idx]
-
-                    if choice_delta.delta:
-                        delta = choice_delta.delta
-                        if delta.content:
-                            current_choice_data["content_parts"].append(delta.content)
-                        if delta.tool_calls:
-                            for tool_call_delta in delta.tool_calls:
-                                tc_idx = tool_call_delta.index
-                                if tc_idx not in current_choice_data["tool_calls_builder"]:
-                                    # Initialize with correct structure for _ToolCallBuilderData
-                                    current_choice_data["tool_calls_builder"][tc_idx] = {
-                                        "id": None,
-                                        "type": "function",
-                                        "function_name_parts": [],
-                                        "function_arguments_parts": [],
-                                    }
-                                builder = current_choice_data["tool_calls_builder"][tc_idx]
-                                if tool_call_delta.id:
-                                    builder["id"] = tool_call_delta.id
-                                if tool_call_delta.type:
-                                    builder["type"] = tool_call_delta.type
-                                if tool_call_delta.function:
-                                    if tool_call_delta.function.name:
-                                        builder["function_name_parts"].append(tool_call_delta.function.name)
-                                    if tool_call_delta.function.arguments:
-                                        builder["function_arguments_parts"].append(tool_call_delta.function.arguments)
-                    if choice_delta.finish_reason:
-                        current_choice_data["finish_reason"] = choice_delta.finish_reason
-                    if choice_delta.logprobs and choice_delta.logprobs.content:
-                        # Ensure that we are extending with the correct type
-                        current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
-            yield chunk
-    finally:
-        if id:
-            assembled_choices: list[OpenAIChoice] = []
-            for choice_idx, choice_data in choices_data.items():
-                content_str = "".join(choice_data["content_parts"])
-                assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
-                if choice_data["tool_calls_builder"]:
-                    for tc_build_data in choice_data["tool_calls_builder"].values():
-                        if tc_build_data["id"]:
-                            func_name = "".join(tc_build_data["function_name_parts"])
-                            func_args = "".join(tc_build_data["function_arguments_parts"])
-                            assembled_tool_calls.append(
-                                OpenAIChatCompletionToolCall(
-                                    id=tc_build_data["id"],
-                                    type=tc_build_data["type"],  # No or "function" needed, already set
-                                    function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args),
-                                )
-                            )
-                message = OpenAIAssistantMessageParam(
-                    role="assistant",
-                    content=content_str if content_str else None,
-                    tool_calls=assembled_tool_calls if assembled_tool_calls else None,
-                )
-                logprobs_content = choice_data["logprobs_content_parts"]
-                final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
-
-                assembled_choices.append(
-                    OpenAIChoice(
-                        finish_reason=choice_data["finish_reason"],
-                        index=choice_idx,
-                        message=message,
-                        logprobs=final_logprobs,
-                    )
-                )
-
-            final_response = OpenAIChatCompletion(
-                id=id,
-                choices=assembled_choices,
-                created=created or int(datetime.now(UTC).timestamp()),
-                model=model,
-                object="chat.completion",
-            )
-            await store.store_chat_completion(final_response, input_messages)
File 3 (llama_stack/providers/utils/telemetry/tracing.py):

@@ -81,7 +81,7 @@ BACKGROUND_LOGGER = None
 
 
 class BackgroundLogger:
-    def __init__(self, api: Telemetry, capacity: int = 1000):
+    def __init__(self, api: Telemetry, capacity: int = 100000):
         self.api = api
         self.log_queue = queue.Queue(maxsize=capacity)
         self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
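
The capacity bump matters because the router now emits token-usage events for every completion, including the final chunk of every stream, so far more events flow through this queue. A sketch of the bounded producer/consumer pattern the lines above set up, assuming a non-blocking put on the producer side; BackgroundLogger internals beyond the lines shown are not part of this diff:

import queue
import threading

log_queue: queue.Queue = queue.Queue(maxsize=100000)  # bounded, as in the diff

def process_logs() -> None:
    # Daemon worker drains the queue so logging never blocks request handling.
    while True:
        event = log_queue.get()
        ...  # ship event to the telemetry sink
        log_queue.task_done()

threading.Thread(target=process_logs, daemon=True).start()

def enqueue(event: dict) -> None:
    try:
        log_queue.put_nowait(event)  # never block the request path
    except queue.Full:
        pass  # event is dropped; a larger capacity makes this rarer

With a bounded queue, the only pressure valve is dropping events once it fills, which is why raising the default from 1000 to 100000 reduces lost metrics under streaming load.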