Closes security gaps where RBAC checks could be bypassed:

- Inference router: added RBAC enforcement in the fallback path so access control is applied consistently.
- Model listing: dynamic models fetched via provider_data were returned without RBAC checks; added filtering so users only see models they have permission to access.

Both fixes create temporary ModelWithOwner objects for RBAC validation, keeping access-control enforcement consistent.

Closes: #4269

This is an automatic backport of pull request #4270 done by [Mergify](https://mergify.com).

Signed-off-by: Derek Higgins <derekh@redhat.com>
Signed-off-by: Charlie Doern <cdoern@redhat.com>
Co-authored-by: Derek Higgins <derekh@redhat.com>
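Both fixes follow the same pattern: wrap the otherwise-unregistered model in a temporary `ModelWithOwner` and pass it through the standard policy check. The sketch below is illustrative only; `check_model_access` is a hypothetical helper (the real checks are inline, e.g. in `_get_provider_by_fallback` in the file below), though the imports, the `ModelWithOwner` fields, and the `is_action_allowed` call mirror the code in this file.

```python
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import ModelType
from llama_stack.core.access_control.access_control import is_action_allowed
from llama_stack.core.datatypes import ModelWithOwner
from llama_stack.core.request_headers import get_authenticated_user


def check_model_access(policy, model_id: str, provider_id: str, provider_resource_id: str) -> None:
    """Raise ModelNotFoundError unless the current user may read this model (illustrative helper)."""
    # Temporary object used only for the RBAC check; metadata stays empty.
    temp_model = ModelWithOwner(
        identifier=model_id,
        provider_id=provider_id,
        provider_resource_id=provider_resource_id,
        model_type=ModelType.llm,
        metadata={},
    )
    user = get_authenticated_user()
    if not is_action_allowed(policy, "read", temp_model, user):
        # Report "not found" rather than "forbidden" to avoid leaking model existence.
        raise ModelNotFoundError(model_id)
```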
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from datetime import UTC, datetime
from typing import Annotated, Any

from fastapi import Body
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import TypeAdapter

from llama_stack.apis.common.content_types import (
    InterleavedContent,
)
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    CompletionResponse,
    CompletionResponseStreamChunk,
    Inference,
    ListOpenAIChatCompletionResponse,
    Message,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIChoice,
    OpenAIChoiceLogprobs,
    OpenAICompletion,
    OpenAICompletionRequestWithExtraBody,
    OpenAICompletionWithInputMessages,
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
    OpenAIMessageParam,
    Order,
    StopReason,
    ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.core.access_control.access_control import is_action_allowed
from llama_stack.core.datatypes import ModelWithOwner
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    RoutingTable,
)
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span

logger = get_logger(name=__name__, category="core::routers")


class InferenceRouter(Inference):
"""Routes to an provider based on the model"""
|
|
|
|
def __init__(
|
|
self,
|
|
routing_table: RoutingTable,
|
|
telemetry: Telemetry | None = None,
|
|
store: InferenceStore | None = None,
|
|
) -> None:
|
|
logger.debug("Initializing InferenceRouter")
|
|
self.routing_table = routing_table
|
|
self.telemetry = telemetry
|
|
self.store = store
|
|
if self.telemetry:
|
|
self.tokenizer = Tokenizer.get_instance()
|
|
self.formatter = ChatFormat(self.tokenizer)
|
|
|
|
async def initialize(self) -> None:
|
|
logger.debug("InferenceRouter.initialize")
|
|
|
|
async def shutdown(self) -> None:
|
|
logger.debug("InferenceRouter.shutdown")
|
|
if self.store:
|
|
try:
|
|
await self.store.shutdown()
|
|
except Exception as e:
|
|
logger.warning(f"Error during InferenceStore shutdown: {e}")
|
|
|
|
async def register_model(
|
|
self,
|
|
model_id: str,
|
|
provider_model_id: str | None = None,
|
|
provider_id: str | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
model_type: ModelType | None = None,
|
|
) -> None:
|
|
logger.debug(
|
|
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
|
)
|
|
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
|
|
|
def _construct_metrics(
|
|
self,
|
|
prompt_tokens: int,
|
|
completion_tokens: int,
|
|
total_tokens: int,
|
|
fully_qualified_model_id: str,
|
|
provider_id: str,
|
|
) -> list[MetricEvent]:
|
|
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
|
|
|
Args:
|
|
prompt_tokens: Number of tokens in the prompt
|
|
completion_tokens: Number of tokens in the completion
|
|
total_tokens: Total number of tokens used
|
|
fully_qualified_model_id:
|
|
provider_id: The provider identifier
|
|
|
|
Returns:
|
|
List of MetricEvent objects with token usage metrics
|
|
"""
|
|
        span = get_current_span()
        if span is None:
            logger.warning("No span found for token usage metrics")
            return []

        metrics = [
            ("prompt_tokens", prompt_tokens),
            ("completion_tokens", completion_tokens),
            ("total_tokens", total_tokens),
        ]
        metric_events = []
        for metric_name, value in metrics:
            metric_events.append(
                MetricEvent(
                    trace_id=span.trace_id,
                    span_id=span.span_id,
                    metric=metric_name,
                    value=value,
                    timestamp=datetime.now(UTC),
                    unit="tokens",
                    attributes={
                        "model_id": fully_qualified_model_id,
                        "provider_id": provider_id,
                    },
                )
            )
        return metric_events

    async def _compute_and_log_token_usage(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        model: Model,
    ) -> list[MetricInResponse]:
        metrics = self._construct_metrics(
            prompt_tokens, completion_tokens, total_tokens, model.model_id, model.provider_id
        )
        if self.telemetry:
            for metric in metrics:
                enqueue_event(metric)
        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

    async def _count_tokens(
        self,
        messages: list[Message] | InterleavedContent,
        tool_prompt_format: ToolPromptFormat | None = None,
    ) -> int | None:
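        # Note: the tokenizer/formatter is only initialized in __init__ when telemetry is enabled,
        # so token counting is skipped (None is returned) otherwise.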
        if not hasattr(self, "formatter") or self.formatter is None:
            return None

        if isinstance(messages, list):
            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
        else:
            encoded = self.formatter.encode_content(messages)
        return len(encoded.tokens) if encoded and encoded.tokens else 0

    async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
        model = await self.routing_table.get_object_by_identifier("model", model_id)
        if model:
            if model.model_type != expected_model_type:
                raise ModelTypeError(model_id, model.model_type, expected_model_type)

            provider = await self.routing_table.get_provider_impl(model.identifier)
            return provider, model.provider_resource_id

        # Handles cases where clients use the provider format directly
        return await self._get_provider_by_fallback(model_id, expected_model_type)

    async def _get_provider_by_fallback(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
        """
        Handle fallback case where model_id is in provider_id/provider_resource_id format.
        """
        splits = model_id.split("/", maxsplit=1)
        if len(splits) != 2:
            raise ModelNotFoundError(model_id)

        provider_id, provider_resource_id = splits

        # Check if provider exists
        if provider_id not in self.routing_table.impls_by_provider_id:
            logger.warning(f"Provider {provider_id} not found for model {model_id}")
            raise ModelNotFoundError(model_id)

        # Create a temporary model object for RBAC check
        temp_model = ModelWithOwner(
            identifier=model_id,
            provider_id=provider_id,
            provider_resource_id=provider_resource_id,
            model_type=expected_model_type,
            metadata={},  # Empty metadata for temporary object
        )

        # Perform RBAC check
        user = get_authenticated_user()
        if not is_action_allowed(self.routing_table.policy, "read", temp_model, user):
            logger.debug(
                f"Access denied to model '{model_id}' via fallback path for user {user.principal if user else 'anonymous'}"
            )
            raise ModelNotFoundError(model_id)

        return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id

    async def openai_completion(
        self,
        params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
    ) -> OpenAICompletion:
        logger.debug(
            f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
        )
        request_model_id = params.model
        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
        params.model = provider_resource_id

        if params.stream:
            return await provider.openai_completion(params)
        # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
        # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.

        response = await provider.openai_completion(params)
        response.model = request_model_id
        if self.telemetry:
            metrics = self._construct_metrics(
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                total_tokens=response.usage.total_tokens,
                fully_qualified_model_id=request_model_id,
                provider_id=provider.__provider_id__,
            )
            for metric in metrics:
                enqueue_event(metric)

            # these metrics will show up in the client response.
            response.metrics = (
                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
            )
        return response

    async def openai_chat_completion(
        self,
        params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        logger.debug(
            f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
        )
        request_model_id = params.model
        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
        params.model = provider_resource_id

        # Use the OpenAI client for a bit of extra input validation without
        # exposing the OpenAI client itself as part of our API surface
        if params.tool_choice:
            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
            if params.tools is None:
                raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
        if params.tools:
            for tool in params.tools:
                TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)

        # Some providers make tool calls even when tool_choice is "none"
        # so just clear them both out to avoid unexpected tool calls
        if params.tool_choice == "none" and params.tools is not None:
            params.tool_choice = None
            params.tools = None

        if params.stream:
            response_stream = await provider.openai_chat_completion(params)

            # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
            # We need to add metrics to each chunk and store the final completion
            return self.stream_tokens_and_compute_metrics_openai_chat(
                response=response_stream,
                fully_qualified_model_id=request_model_id,
                provider_id=provider.__provider_id__,
                messages=params.messages,
            )

        response = await self._nonstream_openai_chat_completion(provider, params)
        response.model = request_model_id

        # Store the response with the ID that will be returned to the client
        if self.store:
            asyncio.create_task(self.store.store_chat_completion(response, params.messages))

        if self.telemetry:
            metrics = self._construct_metrics(
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                total_tokens=response.usage.total_tokens,
                fully_qualified_model_id=request_model_id,
                provider_id=provider.__provider_id__,
            )
            for metric in metrics:
                enqueue_event(metric)
            # these metrics will show up in the client response.
            response.metrics = (
                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
            )
        return response

    async def openai_embeddings(
        self,
        params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
    ) -> OpenAIEmbeddingsResponse:
        logger.debug(
            f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
        )
        request_model_id = params.model
        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.embedding)
        params.model = provider_resource_id

        response = await provider.openai_embeddings(params)
        response.model = request_model_id
        return response

    async def list_chat_completions(
        self,
        after: str | None = None,
        limit: int | None = 20,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIChatCompletionResponse:
        if self.store:
            return await self.store.list_chat_completions(after, limit, model, order)
        raise NotImplementedError("List chat completions is not supported: inference store is not configured.")

    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
        if self.store:
            return await self.store.get_chat_completion(completion_id)
        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")

    async def _nonstream_openai_chat_completion(
        self, provider: Inference, params: OpenAIChatCompletionRequestWithExtraBody
    ) -> OpenAIChatCompletion:
        response = await provider.openai_chat_completion(params)
        for choice in response.choices:
            # some providers return an empty list for no tool calls in non-streaming responses
            # but the OpenAI API returns None. So, set tool_calls to None if it's empty
            if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
                choice.message.tool_calls = None
        return response

    async def health(self) -> dict[str, HealthResponse]:
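        """Collect health status from each registered inference provider, guarding each check with a short timeout."""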
        health_statuses = {}
        timeout = 1  # increasing the timeout to 1 second for health checks
        for provider_id, impl in self.routing_table.impls_by_provider_id.items():
            try:
                # check if the provider has a health method
                if not hasattr(impl, "health"):
                    continue
                health = await asyncio.wait_for(impl.health(), timeout=timeout)
                health_statuses[provider_id] = health
            except TimeoutError:
                health_statuses[provider_id] = HealthResponse(
                    status=HealthStatus.ERROR,
                    message=f"Health check timed out after {timeout} seconds",
                )
            except NotImplementedError:
                health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
            except Exception as e:
                health_statuses[provider_id] = HealthResponse(
                    status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
                )
        return health_statuses

    async def stream_tokens_and_compute_metrics(
        self,
        response,
        prompt_tokens,
        fully_qualified_model_id: str,
        provider_id: str,
        tool_prompt_format: ToolPromptFormat | None = None,
    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
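        """Yield chunks from a streaming completion while accumulating the generated text;
        once the stream completes, compute token-usage metrics and attach them to the final chunk."""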
        completion_text = ""
        async for chunk in response:
            complete = False
            if hasattr(chunk, "event"):  # only ChatCompletions have .event
                if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                    if chunk.event.delta.type == "text":
                        completion_text += chunk.event.delta.text
                if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                    complete = True
                    completion_tokens = await self._count_tokens(
                        [
                            CompletionMessage(
                                content=completion_text,
                                stop_reason=StopReason.end_of_turn,
                            )
                        ],
                        tool_prompt_format=tool_prompt_format,
                    )
            else:
                if hasattr(chunk, "delta"):
                    completion_text += chunk.delta
                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                    complete = True
                    completion_tokens = await self._count_tokens(completion_text)
            # if we are done receiving tokens
            if complete:
                total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)

                # Create a separate span for streaming completion metrics
                if self.telemetry:
                    # Log metrics in the new span context
                    completion_metrics = self._construct_metrics(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=total_tokens,
                        fully_qualified_model_id=fully_qualified_model_id,
                        provider_id=provider_id,
                    )
                    for metric in completion_metrics:
                        if metric.metric in [
                            "completion_tokens",
                            "total_tokens",
                        ]:  # Only log completion and total tokens
                            enqueue_event(metric)

                    # Return metrics in response
                    async_metrics = [
                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
                    ]
                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
                else:
                    # Fallback if no telemetry
                    completion_metrics = self._construct_metrics(
                        prompt_tokens or 0,
                        completion_tokens or 0,
                        total_tokens,
                        fully_qualified_model_id=fully_qualified_model_id,
                        provider_id=provider_id,
                    )
                    async_metrics = [
                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
                    ]
                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
            yield chunk

    async def count_tokens_and_compute_metrics(
        self,
        response: ChatCompletionResponse | CompletionResponse,
        prompt_tokens,
        fully_qualified_model_id: str,
        provider_id: str,
        tool_prompt_format: ToolPromptFormat | None = None,
    ):
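        """Count completion tokens for a non-streaming response and return the metrics to embed in it,
        emitting telemetry events for completion and total token counts when telemetry is enabled."""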
        if isinstance(response, ChatCompletionResponse):
            content = [response.completion_message]
        else:
            content = response.content
        completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)

        # Create a separate span for completion metrics
        if self.telemetry:
            # Log metrics in the new span context
            completion_metrics = self._construct_metrics(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                fully_qualified_model_id=fully_qualified_model_id,
                provider_id=provider_id,
            )
            for metric in completion_metrics:
                if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
                    enqueue_event(metric)

            # Return metrics in response
            return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]

        # Fallback if no telemetry
        metrics = self._construct_metrics(
            prompt_tokens or 0,
            completion_tokens or 0,
            total_tokens,
            fully_qualified_model_id=fully_qualified_model_id,
            provider_id=provider_id,
        )
        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

    async def stream_tokens_and_compute_metrics_openai_chat(
        self,
        response: AsyncIterator[OpenAIChatCompletionChunk],
        fully_qualified_model_id: str,
        provider_id: str,
        messages: list[OpenAIMessageParam] | None = None,
    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
        """Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
        id = None
        created = None
        choices_data: dict[int, dict[str, Any]] = {}

        try:
            async for chunk in response:
                # Skip None chunks
                if chunk is None:
                    continue

                # Capture ID and created timestamp from first chunk
                if id is None and chunk.id:
                    id = chunk.id
                if created is None and chunk.created:
                    created = chunk.created

                chunk.model = fully_qualified_model_id

                # Accumulate choice data for final assembly
                if chunk.choices:
                    for choice_delta in chunk.choices:
                        idx = choice_delta.index
                        if idx not in choices_data:
                            choices_data[idx] = {
                                "content_parts": [],
                                "tool_calls_builder": {},
                                "finish_reason": "stop",
                                "logprobs_content_parts": [],
                            }
                        current_choice_data = choices_data[idx]

                        if choice_delta.delta:
                            delta = choice_delta.delta
                            if delta.content:
                                current_choice_data["content_parts"].append(delta.content)
                            if delta.tool_calls:
                                for tool_call_delta in delta.tool_calls:
                                    tc_idx = tool_call_delta.index
                                    if tc_idx not in current_choice_data["tool_calls_builder"]:
                                        current_choice_data["tool_calls_builder"][tc_idx] = {
                                            "id": None,
                                            "type": "function",
                                            "function_name_parts": [],
                                            "function_arguments_parts": [],
                                        }
                                    builder = current_choice_data["tool_calls_builder"][tc_idx]
                                    if tool_call_delta.id:
                                        builder["id"] = tool_call_delta.id
                                    if tool_call_delta.type:
                                        builder["type"] = tool_call_delta.type
                                    if tool_call_delta.function:
                                        if tool_call_delta.function.name:
                                            builder["function_name_parts"].append(tool_call_delta.function.name)
                                        if tool_call_delta.function.arguments:
                                            builder["function_arguments_parts"].append(
                                                tool_call_delta.function.arguments
                                            )
                        if choice_delta.finish_reason:
                            current_choice_data["finish_reason"] = choice_delta.finish_reason
                        if choice_delta.logprobs and choice_delta.logprobs.content:
                            current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)

                # Compute metrics on final chunk
                if chunk.choices and chunk.choices[0].finish_reason:
                    completion_text = ""
                    for choice_data in choices_data.values():
                        completion_text += "".join(choice_data["content_parts"])

                    # Add metrics to the chunk
                    if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
                        metrics = self._construct_metrics(
                            prompt_tokens=chunk.usage.prompt_tokens,
                            completion_tokens=chunk.usage.completion_tokens,
                            total_tokens=chunk.usage.total_tokens,
                            fully_qualified_model_id=fully_qualified_model_id,
                            provider_id=provider_id,
                        )
                        for metric in metrics:
                            enqueue_event(metric)

                yield chunk
        finally:
            # Store the final assembled completion
            if id and self.store and messages:
                assembled_choices: list[OpenAIChoice] = []
                for choice_idx, choice_data in choices_data.items():
                    content_str = "".join(choice_data["content_parts"])
                    assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
                    if choice_data["tool_calls_builder"]:
                        for tc_build_data in choice_data["tool_calls_builder"].values():
                            if tc_build_data["id"]:
                                func_name = "".join(tc_build_data["function_name_parts"])
                                func_args = "".join(tc_build_data["function_arguments_parts"])
                                assembled_tool_calls.append(
                                    OpenAIChatCompletionToolCall(
                                        id=tc_build_data["id"],
                                        type=tc_build_data["type"],
                                        function=OpenAIChatCompletionToolCallFunction(
                                            name=func_name, arguments=func_args
                                        ),
                                    )
                                )
                    message = OpenAIAssistantMessageParam(
                        role="assistant",
                        content=content_str if content_str else None,
                        tool_calls=assembled_tool_calls if assembled_tool_calls else None,
                    )
                    logprobs_content = choice_data["logprobs_content_parts"]
                    final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None

                    assembled_choices.append(
                        OpenAIChoice(
                            finish_reason=choice_data["finish_reason"],
                            index=choice_idx,
                            message=message,
                            logprobs=final_logprobs,
                        )
                    )

                final_response = OpenAIChatCompletion(
                    id=id,
                    choices=assembled_choices,
                    created=created or int(time.time()),
                    model=fully_qualified_model_id,
                    object="chat.completion",
                )
                logger.debug(f"InferenceRouter.completion_response: {final_response}")
                asyncio.create_task(self.store.store_chat_completion(final_response, messages))