Merge branch 'main' into milvus/search-modes

This commit is contained in:
Francisco Arceo 2025-08-14 07:36:48 -06:00 committed by GitHub
commit 2d0d13b826
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 621 additions and 56 deletions

View file

@ -623,6 +623,62 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed"
@json_schema_type
class OpenAIResponseContentPartOutputText(BaseModel):
type: Literal["output_text"] = "output_text"
text: str
# TODO: add annotations, logprobs, etc.
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
type: Literal["refusal"] = "refusal"
refusal: str
OpenAIResponseContentPart = Annotated[
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal,
Field(discriminator="type"),
]
register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
@json_schema_type
class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
"""Streaming event for when a new content part is added to a response item.
:param response_id: Unique identifier of the response containing this content
:param item_id: Unique identifier of the output item containing this content part
:param part: The content part that was added
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.content_part.added"
"""
response_id: str
item_id: str
part: OpenAIResponseContentPart
sequence_number: int
type: Literal["response.content_part.added"] = "response.content_part.added"
@json_schema_type
class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel):
"""Streaming event for when a content part is completed.
:param response_id: Unique identifier of the response containing this content
:param item_id: Unique identifier of the output item containing this content part
:param part: The completed content part
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.content_part.done"
"""
response_id: str
item_id: str
part: OpenAIResponseContentPart
sequence_number: int
type: Literal["response.content_part.done"] = "response.content_part.done"
OpenAIResponseObjectStream = Annotated[
OpenAIResponseObjectStreamResponseCreated
| OpenAIResponseObjectStreamResponseOutputItemAdded
@ -642,6 +698,8 @@ OpenAIResponseObjectStream = Annotated[
| OpenAIResponseObjectStreamResponseMcpCallInProgress
| OpenAIResponseObjectStreamResponseMcpCallFailed
| OpenAIResponseObjectStreamResponseMcpCallCompleted
| OpenAIResponseObjectStreamResponseContentPartAdded
| OpenAIResponseObjectStreamResponseContentPartDone
| OpenAIResponseObjectStreamResponseCompleted,
Field(discriminator="type"),
]

View file

@ -20,6 +20,7 @@ from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseContentPartOutputText,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
@ -32,12 +33,22 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseContentPartAdded,
OpenAIResponseObjectStreamResponseContentPartDone,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpCallCompleted,
OpenAIResponseObjectStreamResponseMcpCallFailed,
OpenAIResponseObjectStreamResponseMcpCallInProgress,
OpenAIResponseObjectStreamResponseOutputItemAdded,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
OpenAIResponseOutput,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
@ -87,6 +98,15 @@ logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:"
class ToolExecutionResult(BaseModel):
"""Result of streaming tool execution."""
stream_event: OpenAIResponseObjectStream | None = None
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
async def _convert_response_content_to_chat_content(
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
) -> str | list[OpenAIChatCompletionContentPartParam]:
@ -460,6 +480,8 @@ class OpenAIResponsesImpl:
message_item_id = f"msg_{uuid.uuid4()}"
# Track tool call items for streaming events
tool_call_item_ids: dict[int, str] = {}
# Track content parts for streaming events
content_part_emitted = False
async for chunk in completion_result:
chat_response_id = chunk.id
@ -468,6 +490,18 @@ class OpenAIResponsesImpl:
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
# Emit content_part.added event for first text chunk
if not content_part_emitted:
content_part_emitted = True
sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
response_id=response_id,
item_id=message_item_id,
part=OpenAIResponseContentPartOutputText(
text="", # Will be filled incrementally via text deltas
),
sequence_number=sequence_number,
)
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
@ -514,16 +548,33 @@ class OpenAIResponsesImpl:
sequence_number=sequence_number,
)
# Stream function call arguments as they arrive
# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
if tool_call.function and tool_call.function.arguments:
tool_call_item_id = tool_call_item_ids[tool_call.index]
sequence_number += 1
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
delta=tool_call.function.arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=sequence_number,
# Check if this is an MCP tool call
is_mcp_tool = (
ctx.mcp_tool_to_server
and tool_call.function.name
and tool_call.function.name in ctx.mcp_tool_to_server
)
if is_mcp_tool:
# Emit MCP-specific argument delta event
yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
delta=tool_call.function.arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=sequence_number,
)
else:
# Emit function call argument delta event
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
delta=tool_call.function.arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=sequence_number,
)
# Accumulate arguments for final response (only for subsequent chunks)
if not is_new_tool_call:
@ -531,27 +582,55 @@ class OpenAIResponsesImpl:
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
# Emit function_call_arguments.done events for completed tool calls
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call_item_id = tool_call_item_ids[tool_call_index]
final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
# Check if this is an MCP tool call
is_mcp_tool = ctx.mcp_tool_to_server and tool_call_name and tool_call_name in ctx.mcp_tool_to_server
sequence_number += 1
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(
arguments=final_arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=sequence_number,
)
if is_mcp_tool:
# Emit MCP-specific argument done event
yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(
arguments=final_arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=sequence_number,
)
else:
# Emit function call argument done event
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(
arguments=final_arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=sequence_number,
)
# Convert collected chunks to complete response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
# when there are tool calls, we need to clear the content
chat_response_content = []
else:
tool_calls = None
# Emit content_part.done event if text content was streamed (before content gets cleared)
if content_part_emitted:
final_text = "".join(chat_response_content)
sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
response_id=response_id,
item_id=message_item_id,
part=OpenAIResponseContentPartOutputText(
text=final_text,
),
sequence_number=sequence_number,
)
# Clear content when there are tool calls (OpenAI spec behavior)
if chat_response_tool_calls:
chat_response_content = []
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
@ -587,19 +666,38 @@ class OpenAIResponsesImpl:
# execute non-function tool calls
for tool_call in non_function_tool_calls:
tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
# Find the item_id for this tool call
matching_item_id = None
for index, item_id in tool_call_item_ids.items():
response_tool_call = chat_response_tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
# Use a fallback item_id if not found
if not matching_item_id:
matching_item_id = f"tc_{uuid.uuid4()}"
# Execute tool call with streaming
tool_call_log = None
tool_response_message = None
async for result in self._execute_tool_call(
tool_call, ctx, sequence_number, response_id, len(output_messages), matching_item_id
):
if result.stream_event:
# Forward streaming events
sequence_number = result.sequence_number
yield result.stream_event
if result.final_output_message is not None:
tool_call_log = result.final_output_message
tool_response_message = result.final_input_message
sequence_number = result.sequence_number
if tool_call_log:
output_messages.append(tool_call_log)
# Emit output_item.done event for completed non-function tool call
# Find the item_id for this tool call
matching_item_id = None
for index, item_id in tool_call_item_ids.items():
response_tool_call = chat_response_tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
if matching_item_id:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
@ -848,7 +946,11 @@ class OpenAIResponsesImpl:
self,
tool_call: OpenAIChatCompletionToolCall,
ctx: ChatCompletionContext,
) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
sequence_number: int,
response_id: str,
output_index: int,
item_id: str,
) -> AsyncIterator[ToolExecutionResult]:
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
@ -858,8 +960,41 @@ class OpenAIResponsesImpl:
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
if not function or not tool_call_id or not function.name:
return None, None
yield ToolExecutionResult(sequence_number=sequence_number)
return
# Emit in_progress event based on tool type (only for tools with specific streaming events)
progress_event = None
if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
elif function.name == "web_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
# For web search, emit searching event
if function.name == "web_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
# Execute the actual tool call
error_exc = None
result = None
try:
@ -894,6 +1029,33 @@ class OpenAIResponsesImpl:
except Exception as e:
error_exc = e
# Emit completion or failure event based on result (only for tools with specific streaming events)
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
completion_event = None
if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
sequence_number += 1
if has_error:
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
sequence_number=sequence_number,
)
else:
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
sequence_number=sequence_number,
)
elif function.name == "web_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
# Build the result message and input message
if function.name in ctx.mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
@ -907,9 +1069,9 @@ class OpenAIResponsesImpl:
)
if error_exc:
message.error = str(error_exc)
elif (result.error_code and result.error_code > 0) or result.error_message:
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result.content:
elif result and result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
@ -917,7 +1079,7 @@ class OpenAIResponsesImpl:
id=tool_call_id,
status="completed",
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
if has_error:
message.status = "failed"
elif function.name == "knowledge_search":
message = OpenAIResponseOutputMessageFileSearchToolCall(
@ -925,7 +1087,7 @@ class OpenAIResponsesImpl:
queries=[tool_kwargs.get("query", "")],
status="completed",
)
if "document_ids" in result.metadata:
if result and "document_ids" in result.metadata:
message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
@ -939,7 +1101,7 @@ class OpenAIResponsesImpl:
attributes={},
)
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
if has_error:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
@ -971,10 +1133,13 @@ class OpenAIResponsesImpl:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
else:
text = str(error_exc)
text = str(error_exc) if error_exc else "Tool execution failed"
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message
# Yield the final result
yield ToolExecutionResult(
sequence_number=sequence_number, final_output_message=message, final_input_message=input_message
)
def _is_function_tool_call(