Merge branch 'main' into fix/divide-by-zero-exception-faiss-query-vector

This commit is contained in:
Ibrahim Haroon 2025-06-06 11:29:14 -04:00 committed by GitHub
commit b05a3db358
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 254 additions and 389 deletions

View file

@ -8,7 +8,7 @@ import json
import time
import uuid
from collections.abc import AsyncIterator
from typing import Any, cast
from typing import Any
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
@ -200,7 +200,6 @@ class ChatCompletionContext(BaseModel):
messages: list[OpenAIMessageParam]
tools: list[ChatCompletionToolParam] | None = None
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
stream: bool
temperature: float | None
response_format: OpenAIResponseFormatParam
@ -281,49 +280,6 @@ class OpenAIResponsesImpl:
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
def _is_function_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
) -> bool:
if not tool_call.function:
return False
for t in tools:
if t.type == "function" and t.name == tool_call.function.name:
return True
return False
async def _process_response_choices(
self,
chat_response: OpenAIChatCompletion,
ctx: ChatCompletionContext,
tools: list[OpenAIResponseInputTool] | None,
) -> list[OpenAIResponseOutput]:
"""Handle tool execution and response message creation."""
output_messages: list[OpenAIResponseOutput] = []
# Execute tool calls if any
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions
if self._is_function_tool_call(choice.message.tool_calls[0], tools):
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
)
else:
tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
output_messages.extend(tool_messages)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
return output_messages
async def _store_response(
self,
response: OpenAIResponseObject,
@ -370,9 +326,48 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
):
stream = False if stream is None else stream
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
stream_gen = self._create_streaming_response(
input=input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
store=store,
temperature=temperature,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
if response is None:
raise ValueError("The response stream never completed")
return response
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
output_messages: list[OpenAIResponseOutput] = []
# Input preprocessing
@ -383,7 +378,7 @@ class OpenAIResponsesImpl:
# Structured outputs
response_format = await _convert_response_text_to_chat_response_format(text)
# Tool setup
# Tool setup, TODO: refactor this slightly since this can also yield events
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
@ -395,136 +390,10 @@ class OpenAIResponsesImpl:
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
response_format=response_format,
)
# Fork to streaming vs non-streaming - let each handle ALL inference rounds
if stream:
return self._create_streaming_response(
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
else:
return await self._create_non_streaming_response(
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
async def _create_non_streaming_response(
    self,
    ctx: ChatCompletionContext,
    output_messages: list[OpenAIResponseOutput],
    input: str | list[OpenAIResponseInput],
    model: str,
    store: bool | None,
    text: OpenAIResponseText,
    tools: list[OpenAIResponseInputTool] | None,
    max_infer_iters: int | None,
) -> OpenAIResponseObject:
    """Run the inference / tool-execution loop without streaming and build the response.

    Each round performs one non-streaming chat completion and then:
      * if any tool call is a function tool call, hands the completion to
        ``_process_response_choices`` and stops;
      * else if there are other tool calls, executes them, appends the
        assistant message and the tool responses to the conversation, and
        runs another round (until the iteration limit is reached);
      * else converts the choices to response messages and stops.

    Args:
        ctx: Chat completion context (model, messages, tools, sampling options).
        output_messages: Output items collected so far; extended in place.
        input: The original request input, passed along when storing.
        model: Model identifier recorded on the response object.
        store: When truthy, the response is persisted via ``_store_response``.
        text: Text format settings recorded on the response object.
        tools: The tools declared on the request, if any.
        max_infer_iters: Maximum number of tool-execution rounds. ``None`` is
            treated as 10 so the limit check never compares against ``None``.

    Returns:
        The completed response object.
    """
    # The caller may forward None here; comparing an int against None raises TypeError.
    iteration_limit = 10 if max_infer_iters is None else max_infer_iters

    n_iter = 0
    messages = ctx.messages.copy()

    while True:
        # Do inference (including the first one)
        inference_result = await self.inference_api.openai_chat_completion(
            model=ctx.model,
            messages=messages,
            tools=ctx.tools,
            stream=False,
            temperature=ctx.temperature,
            response_format=ctx.response_format,
        )
        completion = OpenAIChatCompletion(**inference_result.model_dump())

        # Separate function vs non-function tool calls
        function_tool_calls = []
        non_function_tool_calls = []
        for choice in completion.choices:
            if choice.message.tool_calls and tools:
                for tool_call in choice.message.tool_calls:
                    if self._is_function_tool_call(tool_call, tools):
                        function_tool_calls.append(tool_call)
                    else:
                        non_function_tool_calls.append(tool_call)

        # Process response choices based on tool call types
        if function_tool_calls:
            # For function tool calls, use existing logic and return immediately
            current_output_messages = await self._process_response_choices(
                chat_response=completion,
                ctx=ctx,
                tools=tools,
            )
            output_messages.extend(current_output_messages)
            break
        elif non_function_tool_calls:
            # For non-function tool calls, execute them and continue loop
            for choice in completion.choices:
                tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
                output_messages.extend(tool_outputs)

                # Add assistant message and tool responses to messages for next iteration
                messages.append(choice.message)
                messages.extend(tool_response_messages)

            n_iter += 1
            if n_iter >= iteration_limit:
                break

            # Continue with next iteration of the loop
            continue
        else:
            # No tool calls - convert response to message and we're done
            for choice in completion.choices:
                output_messages.append(await _convert_chat_choice_to_response_message(choice))
            break

    # `completion` is the result of the last inference round executed above.
    response = OpenAIResponseObject(
        created_at=completion.created,
        id=f"resp-{uuid.uuid4()}",
        model=model,
        object="response",
        status="completed",
        output=output_messages,
        text=text,
    )
    logger.debug(f"OpenAI Responses response: {response}")

    # Store response if requested
    if store:
        await self._store_response(
            response=response,
            input=input,
        )

    return response
async def _create_streaming_response(
self,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
text: OpenAIResponseText,
tools: list[OpenAIResponseInputTool] | None,
max_infer_iters: int | None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
@ -539,15 +408,13 @@ class OpenAIResponsesImpl:
text=text,
)
# Emit response.created immediately
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# Implement tool execution loop for streaming - handle ALL inference rounds including the first
n_iter = 0
messages = ctx.messages.copy()
while True:
current_inference_result = await self.inference_api.openai_chat_completion(
completion_result = await self.inference_api.openai_chat_completion(
model=ctx.model,
messages=messages,
tools=ctx.tools,
@ -568,7 +435,7 @@ class OpenAIResponsesImpl:
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in current_inference_result:
async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
@ -628,50 +495,55 @@ class OpenAIResponsesImpl:
model=chunk_model,
)
# Separate function vs non-function tool calls
function_tool_calls = []
non_function_tool_calls = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
next_turn_messages.append(choice.message)
if choice.message.tool_calls and tools:
for tool_call in choice.message.tool_calls:
if self._is_function_tool_call(tool_call, tools):
if _is_function_tool_call(tool_call, tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
# Process response choices based on tool call types
if function_tool_calls:
# For function tool calls, use existing logic and break
current_output_messages = await self._process_response_choices(
chat_response=current_response,
ctx=ctx,
tools=tools,
)
output_messages.extend(current_output_messages)
break
elif non_function_tool_calls:
# For non-function tool calls, execute them and continue loop
for choice in current_response.choices:
tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
output_messages.extend(tool_outputs)
# Add assistant message and tool responses to messages for next iteration
messages.append(choice.message)
messages.extend(tool_response_messages)
n_iter += 1
if n_iter >= (max_infer_iters or 10):
break
# Continue with next iteration of the loop
continue
else:
# No tool calls - convert response to message and we're done
for choice in current_response.choices:
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
# execute non-function tool calls
for tool_call in non_function_tool_calls:
tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if tool_response_message:
next_turn_messages.append(tool_response_message)
for tool_call in function_tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
)
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
break
messages = next_turn_messages
# Create final response
final_response = OpenAIResponseObject(
created_at=created_at,
@ -683,15 +555,15 @@ class OpenAIResponsesImpl:
output=output_messages,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
if store:
await self._store_response(
response=final_response,
input=input,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
) -> tuple[
@ -784,73 +656,6 @@ class OpenAIResponsesImpl:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_tool_calls_only(
    self,
    choice: OpenAIChoice,
    ctx: ChatCompletionContext,
) -> tuple[list[OpenAIResponseOutput], list[OpenAIMessageParam]]:
    """Run every tool call of ``choice`` and collect the results.

    Args:
        choice: The chat completion choice whose tool calls are executed.
        ctx: Chat completion context forwarded to ``_execute_tool_call``.

    Returns:
        A pair ``(output_messages, tool_response_messages)``: the tool call
        log items for the response output, and the tool response messages to
        feed into the next inference round. Both lists are empty when the
        choice's message is not an assistant message or has no tool calls.
    """
    call_logs: list[OpenAIResponseOutput] = []
    replies: list[OpenAIMessageParam] = []

    message = choice.message
    # Only assistant messages can carry tool calls; anything else yields nothing to run.
    pending_calls = message.tool_calls if isinstance(message, OpenAIAssistantMessageParam) else None

    for call in pending_calls or []:
        log_item, reply = await self._execute_tool_call(call, ctx)
        if log_item:
            call_logs.append(log_item)
        if reply:
            replies.append(reply)

    return call_logs, replies
async def _execute_tool_and_return_final_output(
    self,
    choice: OpenAIChoice,
    ctx: ChatCompletionContext,
) -> list[OpenAIResponseOutput]:
    """Execute the tool calls of ``choice`` and run one follow-up inference.

    The tool call log items come first in the result, followed by the
    messages converted from the follow-up chat completion, which is given
    the original context messages, the assistant message and the tool
    responses.

    Args:
        choice: The chat completion choice whose tool calls are executed.
        ctx: Chat completion context (model, messages, stream, temperature).

    Returns:
        The collected output items; empty when the choice's message is not
        an assistant message or has no tool calls.
    """
    results: list[OpenAIResponseOutput] = []

    assistant_message = choice.message
    if not (isinstance(assistant_message, OpenAIAssistantMessageParam) and assistant_message.tool_calls):
        return results

    # Conversation for the follow-up turn: prior messages plus the assistant
    # message that carries the tool calls.
    follow_up_messages = [*ctx.messages, assistant_message]

    for call in assistant_message.tool_calls:
        # TODO: telemetry spans for tool calls
        log_item, tool_reply = await self._execute_tool_call(call, ctx)
        if log_item:
            results.append(log_item)
        if tool_reply:
            follow_up_messages.append(tool_reply)

    raw_completion = await self.inference_api.openai_chat_completion(
        model=ctx.model,
        messages=follow_up_messages,
        stream=ctx.stream,
        temperature=ctx.temperature,
    )
    # type cast to appease mypy: this is needed because we don't handle streaming properly :)
    follow_up = cast(OpenAIChatCompletion, raw_completion)

    # Huge TODO: these are NOT the final outputs, we must keep the loop going
    # TODO: Wire in annotations with URLs, titles, etc to these output messages
    for follow_up_choice in follow_up.choices:
        results.append(await _convert_chat_choice_to_response_message(follow_up_choice))

    return results
async def _execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
@ -939,3 +744,15 @@ class OpenAIResponsesImpl:
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message
def _is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
) -> bool:
if not tool_call.function:
return False
for t in tools:
if t.type == "function" and t.name == tool_call.function.name:
return True
return False