mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-24 08:47:26 +00:00
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 0s
Python Package Build Test / build (3.12) (push) Failing after 1s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (push) Failing after 5s
Test External API and Providers / test-external (venv) (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 10s
UI Tests / ui-tests (22) (push) Successful in 40s
Pre-commit / pre-commit (push) Successful in 1m23s
Applies the same pattern from https://github.com/llamastack/llama-stack/pull/3777 to the embeddings and vector_stores.create() endpoints. This should _not_ be a breaking change: (a) our tests were already passing these parameters to the backend via `extra_body`, but (b) the backend probably wasn't extracting them correctly. This PR fixes that. Updated APIs: `openai_embeddings()`, `openai_create_vector_store()`, `openai_create_vector_store_file_batch()`
586 lines
26 KiB
Python
586 lines
26 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import time
|
|
from collections.abc import AsyncGenerator, AsyncIterator
|
|
from datetime import UTC, datetime
|
|
from typing import Annotated, Any
|
|
|
|
from fastapi import Body
|
|
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
|
|
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
|
from pydantic import TypeAdapter
|
|
|
|
from llama_stack.apis.common.content_types import (
|
|
InterleavedContent,
|
|
)
|
|
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
|
from llama_stack.apis.inference import (
|
|
ChatCompletionResponse,
|
|
ChatCompletionResponseEventType,
|
|
ChatCompletionResponseStreamChunk,
|
|
CompletionMessage,
|
|
CompletionResponse,
|
|
CompletionResponseStreamChunk,
|
|
Inference,
|
|
ListOpenAIChatCompletionResponse,
|
|
Message,
|
|
OpenAIAssistantMessageParam,
|
|
OpenAIChatCompletion,
|
|
OpenAIChatCompletionChunk,
|
|
OpenAIChatCompletionRequestWithExtraBody,
|
|
OpenAIChatCompletionToolCall,
|
|
OpenAIChatCompletionToolCallFunction,
|
|
OpenAIChoice,
|
|
OpenAIChoiceLogprobs,
|
|
OpenAICompletion,
|
|
OpenAICompletionRequestWithExtraBody,
|
|
OpenAICompletionWithInputMessages,
|
|
OpenAIEmbeddingsRequestWithExtraBody,
|
|
OpenAIEmbeddingsResponse,
|
|
OpenAIMessageParam,
|
|
Order,
|
|
StopReason,
|
|
ToolPromptFormat,
|
|
)
|
|
from llama_stack.apis.models import Model, ModelType
|
|
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
|
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
|
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
|
from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
|
|
|
|
logger = get_logger(name=__name__, category="core::routers")
|
|
|
|
|
|
class InferenceRouter(Inference):
|
|
"""Routes to an provider based on the model"""
|
|
|
|
def __init__(
    self,
    routing_table: RoutingTable,
    telemetry: Telemetry | None = None,
    store: InferenceStore | None = None,
) -> None:
    """Store the router's collaborators and prepare token-counting helpers.

    Args:
        routing_table: Table used to look up models and provider implementations.
        telemetry: Optional telemetry API. When truthy, a tokenizer and a chat
            formatter are created for this router.
        store: Optional inference store used to persist chat completions.
    """
    logger.debug("Initializing InferenceRouter")
    self.routing_table = routing_table
    self.telemetry = telemetry
    self.store = store

    # Define both attributes unconditionally so readers never need `hasattr`
    # and never hit an AttributeError; they stay None without telemetry.
    self.tokenizer = None
    self.formatter = None
    if self.telemetry:
        self.tokenizer = Tokenizer.get_instance()
        self.formatter = ChatFormat(self.tokenizer)
|
|
|
|
async def initialize(self) -> None:
    """Lifecycle hook; this router only logs that it was called."""
    logger.debug("InferenceRouter.initialize")
|
|
|
|
async def shutdown(self) -> None:
    """Shut down the inference store, if one is configured.

    Any exception raised by the store's shutdown is logged as a warning
    and swallowed so that router shutdown itself does not fail.
    """
    logger.debug("InferenceRouter.shutdown")
    if not self.store:
        return

    try:
        await self.store.shutdown()
    except Exception as e:
        logger.warning(f"Error during InferenceStore shutdown: {e}")
|
|
|
|
async def register_model(
    self,
    model_id: str,
    provider_model_id: str | None = None,
    provider_id: str | None = None,
    metadata: dict[str, Any] | None = None,
    model_type: ModelType | None = None,
) -> None:
    """Log the request and delegate model registration to the routing table.

    Args:
        model_id: Identifier under which the model is registered.
        provider_model_id: Optional provider-side model identifier.
        provider_id: Optional identifier of the provider serving the model.
        metadata: Optional metadata forwarded to the routing table.
        model_type: Optional model type forwarded to the routing table.
    """
    logger.debug(
        f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
    )
    # Arguments are forwarded positionally, in the same order as received.
    registration_args = (model_id, provider_model_id, provider_id, metadata, model_type)
    await self.routing_table.register_model(*registration_args)
|
|
|
|
def _construct_metrics(
    self,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
    model: Model,
) -> list[MetricEvent]:
    """Build one MetricEvent per token-usage counter for the current span.

    Args:
        prompt_tokens: Number of tokens in the prompt
        completion_tokens: Number of tokens in the completion
        total_tokens: Total number of tokens used
        model: Model object containing model_id and provider_id

    Returns:
        List of MetricEvent objects with token usage metrics, in the order
        prompt, completion, total. Empty when there is no current span.
    """
    span = get_current_span()
    if span is None:
        logger.warning("No span found for token usage metrics")
        return []

    # Insertion order of this dict fixes the order of the emitted events.
    token_counts = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
    }
    return [
        MetricEvent(
            trace_id=span.trace_id,
            span_id=span.span_id,
            metric=name,
            value=count,
            timestamp=datetime.now(UTC),
            unit="tokens",
            attributes={
                "model_id": model.model_id,
                "provider_id": model.provider_id,
            },
        )
        for name, count in token_counts.items()
    ]
|
|
|
|
async def _compute_and_log_token_usage(
    self,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
    model: Model,
) -> list[MetricInResponse]:
    """Build token-usage metrics, enqueue them if telemetry is on, and return them.

    Args:
        prompt_tokens: Number of tokens in the prompt.
        completion_tokens: Number of tokens in the completion.
        total_tokens: Total number of tokens used.
        model: Model the usage is attributed to.

    Returns:
        One MetricInResponse (metric name and value) per constructed event.
    """
    events = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)

    if self.telemetry:
        for event in events:
            enqueue_event(event)

    response_metrics = []
    for event in events:
        response_metrics.append(MetricInResponse(metric=event.metric, value=event.value))
    return response_metrics
|
|
|
|
async def _count_tokens(
    self,
    messages: list[Message] | InterleavedContent,
    tool_prompt_format: ToolPromptFormat | None = None,
) -> int | None:
    """Count tokens for a message list or a piece of content.

    Args:
        messages: A list (encoded as a dialog prompt) or any other content
            value (encoded as plain content).
        tool_prompt_format: Forwarded to the dialog encoder for list input.

    Returns:
        None when no formatter is available on this router; otherwise the
        number of encoded tokens, or 0 when the encoder yields nothing.
    """
    formatter = getattr(self, "formatter", None)
    if formatter is None:
        return None

    if isinstance(messages, list):
        encoded = formatter.encode_dialog_prompt(messages, tool_prompt_format)
    else:
        encoded = formatter.encode_content(messages)

    if encoded and encoded.tokens:
        return len(encoded.tokens)
    return 0
|
|
|
|
async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
    """Look up a model in the routing table and check that its type matches.

    Args:
        model_id: Identifier of the model to look up.
        expected_model_type: Model type the caller requires.

    Returns:
        The model returned by the routing table.

    Raises:
        ModelNotFoundError: If the routing table returns None for the model.
        ModelTypeError: If the model's type differs from the expected type.
    """
    found = await self.routing_table.get_model(model_id)
    if found is None:
        raise ModelNotFoundError(model_id)

    actual_type = found.model_type
    if actual_type != expected_model_type:
        raise ModelTypeError(model_id, actual_type, expected_model_type)
    return found
|
|
|
|
async def openai_completion(
    self,
    params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
) -> OpenAICompletion:
    """Route an OpenAI-style completion request to the provider serving the model.

    Args:
        params: Completion request. ``params.model`` is overwritten in place
            with the resolved model identifier before it is forwarded.

    Returns:
        The provider's response. For non-streaming requests with telemetry
        enabled and usage data present, token metrics are appended to
        ``response.metrics``.

    Raises:
        ModelNotFoundError: If the requested model is not found.
        ModelTypeError: If the requested model is not an LLM.
    """
    logger.debug(
        f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
    )
    model_obj = await self._get_model(params.model, ModelType.llm)

    # Update params with the resolved model identifier
    params.model = model_obj.identifier

    provider = await self.routing_table.get_provider_impl(model_obj.identifier)
    if params.stream:
        # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
        # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
        return await provider.openai_completion(params)

    response = await provider.openai_completion(params)

    # A response without usage data would otherwise raise AttributeError
    # here; skip metrics in that case instead of failing the request.
    usage = getattr(response, "usage", None)
    if self.telemetry and usage is not None:
        metrics = self._construct_metrics(
            prompt_tokens=usage.prompt_tokens,
            completion_tokens=usage.completion_tokens,
            total_tokens=usage.total_tokens,
            model=model_obj,
        )
        for metric in metrics:
            enqueue_event(metric)

        # these metrics will show up in the client response.
        response.metrics = (
            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
        )
    return response
|
|
|
|
async def openai_chat_completion(
    self,
    params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
    """Validate and route an OpenAI-style chat completion request.

    Args:
        params: Chat completion request. It is mutated in place: ``model`` is
            replaced with the resolved identifier, and ``tool_choice`` and
            ``tools`` are cleared when ``tool_choice == "none"``.

    Returns:
        For streaming requests, an async iterator of chunks that also records
        metrics and stores the assembled completion. Otherwise the provider's
        response, with token metrics appended to ``response.metrics`` when
        telemetry is enabled and usage data is present.

    Raises:
        ModelNotFoundError: If the requested model is not found.
        ModelTypeError: If the requested model is not an LLM.
        ValueError: If ``tool_choice`` is given without ``tools``.
    """
    logger.debug(
        f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
    )
    model_obj = await self._get_model(params.model, ModelType.llm)

    # Use the OpenAI client for a bit of extra input validation without
    # exposing the OpenAI client itself as part of our API surface
    if params.tool_choice:
        TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
        if params.tools is None:
            raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
    if params.tools:
        for tool in params.tools:
            TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)

    # Some providers make tool calls even when tool_choice is "none"
    # so just clear them both out to avoid unexpected tool calls
    if params.tool_choice == "none" and params.tools is not None:
        params.tool_choice = None
        params.tools = None

    # Update params with the resolved model identifier
    params.model = model_obj.identifier

    provider = await self.routing_table.get_provider_impl(model_obj.identifier)
    if params.stream:
        response_stream = await provider.openai_chat_completion(params)

        # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
        # We need to add metrics to each chunk and store the final completion
        return self.stream_tokens_and_compute_metrics_openai_chat(
            response=response_stream,
            model=model_obj,
            messages=params.messages,
        )

    response = await self._nonstream_openai_chat_completion(provider, params)

    # Store the response with the ID that will be returned to the client
    if self.store:
        store_task = asyncio.create_task(self.store.store_chat_completion(response, params.messages))
        # The event loop only keeps weak references to tasks, so hold a
        # strong one until the task finishes to keep it from being collected.
        pending = getattr(self, "_background_tasks", None)
        if pending is None:
            pending = self._background_tasks = set()
        pending.add(store_task)
        store_task.add_done_callback(pending.discard)

    # A response without usage data would otherwise raise AttributeError
    # here; skip metrics in that case instead of failing the request.
    usage = getattr(response, "usage", None)
    if self.telemetry and usage is not None:
        metrics = self._construct_metrics(
            prompt_tokens=usage.prompt_tokens,
            completion_tokens=usage.completion_tokens,
            total_tokens=usage.total_tokens,
            model=model_obj,
        )
        for metric in metrics:
            enqueue_event(metric)
        # these metrics will show up in the client response.
        response.metrics = (
            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
        )
    return response
|
|
|
|
async def openai_embeddings(
    self,
    params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
) -> OpenAIEmbeddingsResponse:
    """Route an OpenAI-style embeddings request to the provider serving the model.

    Args:
        params: Embeddings request. ``params.model`` is overwritten in place
            with the resolved model identifier before it is forwarded.

    Returns:
        The provider's embeddings response, unmodified.

    Raises:
        ModelNotFoundError: If the requested model is not found.
        ModelTypeError: If the requested model is not an embedding model.
    """
    logger.debug(
        f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
    )
    embedding_model = await self._get_model(params.model, ModelType.embedding)
    resolved_id = embedding_model.identifier

    # The provider must see the resolved identifier, not the caller's alias.
    params.model = resolved_id

    provider = await self.routing_table.get_provider_impl(resolved_id)
    return await provider.openai_embeddings(params)
|
|
|
|
async def list_chat_completions(
    self,
    after: str | None = None,
    limit: int | None = 20,
    model: str | None = None,
    order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
    """List stored chat completions by delegating to the inference store.

    Args:
        after: Forwarded to the store unchanged.
        limit: Forwarded to the store unchanged.
        model: Forwarded to the store unchanged.
        order: Forwarded to the store unchanged.

    Returns:
        Whatever the store's ``list_chat_completions`` returns.

    Raises:
        NotImplementedError: If no inference store is configured.
    """
    if not self.store:
        raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
    return await self.store.list_chat_completions(after, limit, model, order)
|
|
|
|
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
    """Fetch one stored chat completion by delegating to the inference store.

    Args:
        completion_id: Identifier of the completion to fetch.

    Returns:
        Whatever the store's ``get_chat_completion`` returns.

    Raises:
        NotImplementedError: If no inference store is configured.
    """
    if not self.store:
        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
    return await self.store.get_chat_completion(completion_id)
|
|
|
|
async def _nonstream_openai_chat_completion(
    self, provider: Inference, params: OpenAIChatCompletionRequestWithExtraBody
) -> OpenAIChatCompletion:
    """Call the provider and normalize empty tool-call lists to None.

    Args:
        provider: Provider implementation that serves the request.
        params: Chat completion request forwarded to the provider.

    Returns:
        The provider's response, mutated in place so that any choice whose
        message has an empty ``tool_calls`` list has it set to None.
    """
    response = await provider.openai_chat_completion(params)
    for choice in response.choices:
        message = choice.message
        if not message:
            continue
        # some providers return an empty list for no tool calls in non-streaming responses
        # but the OpenAI API returns None. So, set tool_calls to None if it's empty
        tool_calls = message.tool_calls
        if tool_calls is not None and len(tool_calls) == 0:
            message.tool_calls = None
    return response
|
|
|
|
async def health(self) -> dict[str, HealthResponse]:
    """Collect health statuses from every provider that exposes ``health``.

    Providers are checked one after another, each with a one-second
    timeout. Providers without a ``health`` attribute are left out.

    Returns:
        Mapping of provider id to that provider's health result. Timeouts
        and other failures are reported as ERROR entries, and
        NotImplementedError as NOT_IMPLEMENTED, rather than raised.
    """
    timeout = 1  # increasing the timeout to 1 second for health checks
    statuses = {}
    for provider_id, impl in self.routing_table.impls_by_provider_id.items():
        try:
            # only providers that actually have a health method are queried
            if hasattr(impl, "health"):
                statuses[provider_id] = await asyncio.wait_for(impl.health(), timeout=timeout)
        except TimeoutError:
            statuses[provider_id] = HealthResponse(
                status=HealthStatus.ERROR,
                message=f"Health check timed out after {timeout} seconds",
            )
        except NotImplementedError:
            statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
        except Exception as e:
            statuses[provider_id] = HealthResponse(
                status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
            )
    return statuses
|
|
|
|
async def stream_tokens_and_compute_metrics(
    self,
    response,
    prompt_tokens,
    model,
    tool_prompt_format: ToolPromptFormat | None = None,
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
    """Re-yield stream chunks, attaching token metrics to the completing chunk.

    Text is accumulated across chunks. When a chunk marks the end of the
    stream, completion tokens are counted, metrics are built, and the
    result is attached to that chunk's ``metrics`` before it is yielded.

    Args:
        response: Async iterable of stream chunks. Chunks with an ``event``
            attribute are handled as chat completions; others as completions.
        prompt_tokens: Prompt token count, or None.
        model: Model the usage is attributed to.
        tool_prompt_format: Forwarded to token counting for chat streams.

    Yields:
        Every chunk from ``response``, in order.
    """
    completion_text = ""
    async for chunk in response:
        complete = False
        if hasattr(chunk, "event"):  # only ChatCompletions have .event
            event = chunk.event
            is_text_progress = (
                event.event_type == ChatCompletionResponseEventType.progress and event.delta.type == "text"
            )
            if is_text_progress:
                completion_text += event.delta.text
            if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                complete = True
                final_message = CompletionMessage(
                    content=completion_text,
                    stop_reason=StopReason.end_of_turn,
                )
                completion_tokens = await self._count_tokens(
                    [final_message],
                    tool_prompt_format=tool_prompt_format,
                )
        else:
            if hasattr(chunk, "delta"):
                completion_text += chunk.delta
            # plain completions are only finalized when telemetry is enabled
            if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                complete = True
                completion_tokens = await self._count_tokens(completion_text)

        # if we are done receiving tokens
        if complete:
            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)

            if self.telemetry:
                # With telemetry the raw (possibly None) counts are passed through
                completion_metrics = self._construct_metrics(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=total_tokens,
                    model=model,
                )
                # Only log completion and total tokens
                for metric in completion_metrics:
                    if metric.metric in ("completion_tokens", "total_tokens"):
                        enqueue_event(metric)
            else:
                # Fallback if no telemetry: missing counts become 0
                completion_metrics = self._construct_metrics(
                    prompt_tokens or 0,
                    completion_tokens or 0,
                    total_tokens,
                    model,
                )

            # Return metrics in response
            async_metrics = [
                MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
            ]
            if chunk.metrics is None:
                chunk.metrics = async_metrics
            else:
                chunk.metrics = chunk.metrics + async_metrics
        yield chunk
|
|
|
|
async def count_tokens_and_compute_metrics(
    self,
    response: ChatCompletionResponse | CompletionResponse,
    prompt_tokens,
    model,
    tool_prompt_format: ToolPromptFormat | None = None,
):
    """Count completion tokens for a finished response and build its metrics.

    Args:
        response: Non-streaming chat completion or completion response.
        prompt_tokens: Prompt token count, or None.
        model: Model the usage is attributed to.
        tool_prompt_format: Forwarded to token counting.

    Returns:
        A list of MetricInResponse, one per constructed metric event. With
        telemetry enabled, the completion and total token events are also
        enqueued.
    """
    is_chat = isinstance(response, ChatCompletionResponse)
    content = [response.completion_message] if is_chat else response.content

    completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
    total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)

    if self.telemetry:
        # With telemetry the raw (possibly None) counts are passed through
        events = self._construct_metrics(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            model=model,
        )
        # Only log completion and total tokens
        for event in events:
            if event.metric in ("completion_tokens", "total_tokens"):
                enqueue_event(event)
    else:
        # Fallback if no telemetry: missing counts become 0
        events = self._construct_metrics(
            prompt_tokens or 0,
            completion_tokens or 0,
            total_tokens,
            model,
        )

    # Return metrics in response
    return [MetricInResponse(metric=event.metric, value=event.value) for event in events]
|
|
|
|
async def stream_tokens_and_compute_metrics_openai_chat(
    self,
    response: AsyncIterator[OpenAIChatCompletionChunk],
    model: Model,
    messages: list[OpenAIMessageParam] | None = None,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
    """Stream OpenAI chat completion chunks, compute metrics, and store the final completion.

    Each chunk is re-yielded unchanged while its content, tool-call
    fragments, finish reason and logprobs are accumulated per choice index.
    When the stream ends, including on early termination, the accumulated
    data is assembled into one chat completion and handed to the store.

    Args:
        response: Async iterator of chunks from the provider. None chunks
            are skipped.
        model: Model the usage metrics and stored completion are attributed to.
        messages: Input messages stored with the completion. Nothing is
            stored when this is empty or None.

    Yields:
        Every non-None chunk from ``response``, in order.
    """
    completion_id = None
    created = None
    choices_data: dict[int, dict[str, Any]] = {}
    usage_logged = False

    try:
        async for chunk in response:
            # Skip None chunks
            if chunk is None:
                continue

            # Capture ID and created timestamp from first chunk
            if completion_id is None and chunk.id:
                completion_id = chunk.id
            if created is None and chunk.created:
                created = chunk.created

            # Accumulate choice data for final assembly
            for choice_delta in chunk.choices or []:
                idx = choice_delta.index
                if idx not in choices_data:
                    choices_data[idx] = {
                        "content_parts": [],
                        "tool_calls_builder": {},
                        "finish_reason": "stop",
                        "logprobs_content_parts": [],
                    }
                current_choice_data = choices_data[idx]

                delta = choice_delta.delta
                if delta:
                    if delta.content:
                        current_choice_data["content_parts"].append(delta.content)
                    for tool_call_delta in delta.tool_calls or []:
                        tc_idx = tool_call_delta.index
                        builders = current_choice_data["tool_calls_builder"]
                        if tc_idx not in builders:
                            builders[tc_idx] = {
                                "id": None,
                                "type": "function",
                                "function_name_parts": [],
                                "function_arguments_parts": [],
                            }
                        builder = builders[tc_idx]
                        if tool_call_delta.id:
                            builder["id"] = tool_call_delta.id
                        if tool_call_delta.type:
                            builder["type"] = tool_call_delta.type
                        function_delta = tool_call_delta.function
                        if function_delta:
                            if function_delta.name:
                                builder["function_name_parts"].append(function_delta.name)
                            if function_delta.arguments:
                                builder["function_arguments_parts"].append(function_delta.arguments)
                if choice_delta.finish_reason:
                    current_choice_data["finish_reason"] = choice_delta.finish_reason
                if choice_delta.logprobs and choice_delta.logprobs.content:
                    current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)

            # Usage can arrive on the chunk carrying finish_reason, or on a
            # trailing chunk with no choices (OpenAI `include_usage`).
            # Accept either, but emit the metrics only once per stream.
            is_finish_chunk = bool(chunk.choices and chunk.choices[0].finish_reason)
            is_usage_only_chunk = not chunk.choices
            usage = getattr(chunk, "usage", None)
            if self.telemetry and usage and not usage_logged and (is_finish_chunk or is_usage_only_chunk):
                metrics = self._construct_metrics(
                    prompt_tokens=usage.prompt_tokens,
                    completion_tokens=usage.completion_tokens,
                    total_tokens=usage.total_tokens,
                    model=model,
                )
                for metric in metrics:
                    enqueue_event(metric)
                usage_logged = True

            yield chunk
    finally:
        # Store the final assembled completion
        if completion_id and self.store and messages:
            assembled_choices: list[OpenAIChoice] = []
            for choice_idx, choice_data in choices_data.items():
                content_str = "".join(choice_data["content_parts"])
                # Tool calls that never received an id are dropped.
                assembled_tool_calls: list[OpenAIChatCompletionToolCall] = [
                    OpenAIChatCompletionToolCall(
                        id=tc_build_data["id"],
                        type=tc_build_data["type"],
                        function=OpenAIChatCompletionToolCallFunction(
                            name="".join(tc_build_data["function_name_parts"]),
                            arguments="".join(tc_build_data["function_arguments_parts"]),
                        ),
                    )
                    for tc_build_data in choice_data["tool_calls_builder"].values()
                    if tc_build_data["id"]
                ]
                message = OpenAIAssistantMessageParam(
                    role="assistant",
                    content=content_str if content_str else None,
                    tool_calls=assembled_tool_calls if assembled_tool_calls else None,
                )
                logprobs_content = choice_data["logprobs_content_parts"]
                final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None

                assembled_choices.append(
                    OpenAIChoice(
                        finish_reason=choice_data["finish_reason"],
                        index=choice_idx,
                        message=message,
                        logprobs=final_logprobs,
                    )
                )

            final_response = OpenAIChatCompletion(
                id=completion_id,
                choices=assembled_choices,
                created=created or int(time.time()),
                model=model.identifier,
                object="chat.completion",
            )
            logger.debug(f"InferenceRouter.completion_response: {final_response}")
            store_task = asyncio.create_task(self.store.store_chat_completion(final_response, messages))
            # The event loop only keeps weak references to tasks, so hold a
            # strong one until the task finishes to keep it from being collected.
            pending = getattr(self, "_background_tasks", None)
            if pending is None:
                pending = self._background_tasks = set()
            pending.add(store_task)
            store_task.add_done_callback(pending.discard)
|