Mirror of https://github.com/meta-llama/llama-stack.git

Merge branch 'main' into milvus/search-modes
Commit 2d0d13b826
7 changed files with 621 additions and 56 deletions
@@ -623,6 +623,62 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
     type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed"
 
 
+@json_schema_type
+class OpenAIResponseContentPartOutputText(BaseModel):
+    type: Literal["output_text"] = "output_text"
+    text: str
+    # TODO: add annotations, logprobs, etc.
+
+
+@json_schema_type
+class OpenAIResponseContentPartRefusal(BaseModel):
+    type: Literal["refusal"] = "refusal"
+    refusal: str
+
+
+OpenAIResponseContentPart = Annotated[
+    OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal,
+    Field(discriminator="type"),
+]
+register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
+    """Streaming event for when a new content part is added to a response item.
+
+    :param response_id: Unique identifier of the response containing this content
+    :param item_id: Unique identifier of the output item containing this content part
+    :param part: The content part that was added
+    :param sequence_number: Sequential number for ordering streaming events
+    :param type: Event type identifier, always "response.content_part.added"
+    """
+
+    response_id: str
+    item_id: str
+    part: OpenAIResponseContentPart
+    sequence_number: int
+    type: Literal["response.content_part.added"] = "response.content_part.added"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel):
+    """Streaming event for when a content part is completed.
+
+    :param response_id: Unique identifier of the response containing this content
+    :param item_id: Unique identifier of the output item containing this content part
+    :param part: The completed content part
+    :param sequence_number: Sequential number for ordering streaming events
+    :param type: Event type identifier, always "response.content_part.done"
+    """
+
+    response_id: str
+    item_id: str
+    part: OpenAIResponseContentPart
+    sequence_number: int
+    type: Literal["response.content_part.done"] = "response.content_part.done"
+
+
 OpenAIResponseObjectStream = Annotated[
     OpenAIResponseObjectStreamResponseCreated
     | OpenAIResponseObjectStreamResponseOutputItemAdded
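For orientation (an editor's sketch, not part of the diff): Field(discriminator="type") makes Pydantic dispatch validation on the type literal, which is what lets one OpenAIResponseContentPart union travel over the wire. A minimal standalone version, with the two models re-declared locally:

from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class OutputText(BaseModel):
    type: Literal["output_text"] = "output_text"
    text: str


class Refusal(BaseModel):
    type: Literal["refusal"] = "refusal"
    refusal: str


ContentPart = Annotated[OutputText | Refusal, Field(discriminator="type")]

adapter = TypeAdapter(ContentPart)
part = adapter.validate_python({"type": "output_text", "text": "hello"})
assert isinstance(part, OutputText)  # the "type" field selected the model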
@@ -642,6 +698,8 @@ OpenAIResponseObjectStream = Annotated[
     | OpenAIResponseObjectStreamResponseMcpCallInProgress
     | OpenAIResponseObjectStreamResponseMcpCallFailed
     | OpenAIResponseObjectStreamResponseMcpCallCompleted
+    | OpenAIResponseObjectStreamResponseContentPartAdded
+    | OpenAIResponseObjectStreamResponseContentPartDone
     | OpenAIResponseObjectStreamResponseCompleted,
     Field(discriminator="type"),
 ]
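With the two new union members, a consumer sees content_part.added, then output_text.delta events, then content_part.done for each text part. A hedged sketch of client-side reassembly; the delta field name on the text-delta event is assumed rather than shown in this diff, and stream is any async iterator over the union:

async def collect_text(stream) -> str:
    # Gather text between content_part.added and content_part.done,
    # keyed off the discriminator literals defined above.
    buf: list[str] = []
    async for event in stream:
        if event.type == "response.output_text.delta":
            buf.append(event.delta)  # assumed field name on the delta event
        elif event.type == "response.content_part.done":
            # The done event carries the full accumulated text in part.text.
            return event.part.text
    return "".join(buf)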
@@ -20,6 +20,7 @@ from llama_stack.apis.agents.openai_responses import (
     ListOpenAIResponseInputItem,
     ListOpenAIResponseObject,
     OpenAIDeleteResponseObject,
+    OpenAIResponseContentPartOutputText,
     OpenAIResponseInput,
     OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputMessageContent,
@@ -32,12 +33,22 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
+    OpenAIResponseObjectStreamResponseContentPartAdded,
+    OpenAIResponseObjectStreamResponseContentPartDone,
     OpenAIResponseObjectStreamResponseCreated,
     OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
     OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
+    OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
+    OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
+    OpenAIResponseObjectStreamResponseMcpCallCompleted,
+    OpenAIResponseObjectStreamResponseMcpCallFailed,
+    OpenAIResponseObjectStreamResponseMcpCallInProgress,
     OpenAIResponseObjectStreamResponseOutputItemAdded,
     OpenAIResponseObjectStreamResponseOutputItemDone,
     OpenAIResponseObjectStreamResponseOutputTextDelta,
+    OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
+    OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
+    OpenAIResponseObjectStreamResponseWebSearchCallSearching,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
@@ -87,6 +98,15 @@ logger = get_logger(name=__name__, category="openai_responses")
 OPENAI_RESPONSES_PREFIX = "openai_responses:"
 
 
+class ToolExecutionResult(BaseModel):
+    """Result of streaming tool execution."""
+
+    stream_event: OpenAIResponseObjectStream | None = None
+    sequence_number: int
+    final_output_message: OpenAIResponseOutput | None = None
+    final_input_message: OpenAIMessageParam | None = None
+
+
 async def _convert_response_content_to_chat_content(
     content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
 ) -> str | list[OpenAIChatCompletionContentPartParam]:
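The envelope exists because _execute_tool_call (changed later in this diff) becomes an async generator: intermediate stream events and the final tool output must share one channel, and sequence_number rides along so the caller can keep its counter monotonic. A runnable toy version of the pattern, with plain strings standing in for the event and message types:

import asyncio
from collections.abc import AsyncIterator

from pydantic import BaseModel


class Envelope(BaseModel):
    # Stand-in for ToolExecutionResult: either an intermediate stream event
    # or the final output, never both required.
    stream_event: str | None = None
    sequence_number: int
    final_output_message: str | None = None


async def fake_tool(sequence_number: int) -> AsyncIterator[Envelope]:
    sequence_number += 1
    yield Envelope(stream_event="tool.in_progress", sequence_number=sequence_number)
    sequence_number += 1
    yield Envelope(stream_event="tool.completed", sequence_number=sequence_number)
    # Final yield carries the tool output and the last sequence number.
    yield Envelope(sequence_number=sequence_number, final_output_message="tool output")


async def main() -> None:
    sequence_number = 0
    final = None
    async for result in fake_tool(sequence_number):
        if result.stream_event:
            sequence_number = result.sequence_number  # keep the counter monotonic
            print("forward:", result.stream_event)
        if result.final_output_message is not None:
            final = result.final_output_message
    print("final:", final)


asyncio.run(main())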
@@ -460,6 +480,8 @@ class OpenAIResponsesImpl:
         message_item_id = f"msg_{uuid.uuid4()}"
         # Track tool call items for streaming events
         tool_call_item_ids: dict[int, str] = {}
+        # Track content parts for streaming events
+        content_part_emitted = False
 
         async for chunk in completion_result:
             chat_response_id = chunk.id
@@ -468,6 +490,18 @@
             for chunk_choice in chunk.choices:
                 # Emit incremental text content as delta events
                 if chunk_choice.delta.content:
+                    # Emit content_part.added event for first text chunk
+                    if not content_part_emitted:
+                        content_part_emitted = True
+                        sequence_number += 1
+                        yield OpenAIResponseObjectStreamResponseContentPartAdded(
+                            response_id=response_id,
+                            item_id=message_item_id,
+                            part=OpenAIResponseContentPartOutputText(
+                                text="",  # Will be filled incrementally via text deltas
+                            ),
+                            sequence_number=sequence_number,
+                        )
                     sequence_number += 1
                     yield OpenAIResponseObjectStreamResponseOutputTextDelta(
                         content_index=0,
@@ -514,16 +548,33 @@
                         sequence_number=sequence_number,
                     )
 
-                # Stream function call arguments as they arrive
+                # Stream tool call arguments as they arrive (differentiate between MCP and function calls)
                 if tool_call.function and tool_call.function.arguments:
                     tool_call_item_id = tool_call_item_ids[tool_call.index]
                     sequence_number += 1
-                    yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
-                        delta=tool_call.function.arguments,
-                        item_id=tool_call_item_id,
-                        output_index=len(output_messages),
-                        sequence_number=sequence_number,
-                    )
+
+                    # Check if this is an MCP tool call
+                    is_mcp_tool = (
+                        ctx.mcp_tool_to_server
+                        and tool_call.function.name
+                        and tool_call.function.name in ctx.mcp_tool_to_server
+                    )
+                    if is_mcp_tool:
+                        # Emit MCP-specific argument delta event
+                        yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
+                            delta=tool_call.function.arguments,
+                            item_id=tool_call_item_id,
+                            output_index=len(output_messages),
+                            sequence_number=sequence_number,
+                        )
+                    else:
+                        # Emit function call argument delta event
+                        yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
+                            delta=tool_call.function.arguments,
+                            item_id=tool_call_item_id,
+                            output_index=len(output_messages),
+                            sequence_number=sequence_number,
+                        )
 
                 # Accumulate arguments for final response (only for subsequent chunks)
                 if not is_new_tool_call:
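A compact restatement of the routing predicate used in both the delta and done paths (a sketch; in the diff, ctx.mcp_tool_to_server maps MCP tool names to their servers):

def is_mcp_tool(name: str | None, mcp_tool_to_server: dict[str, object] | None) -> bool:
    # MCP-hosted tools get mcp_call.* events; everything else gets
    # function_call_arguments.* events.
    return bool(mcp_tool_to_server and name and name in mcp_tool_to_server)


assert is_mcp_tool("search", {"search": object()})
assert not is_mcp_tool("search", None)
assert not is_mcp_tool(None, {"search": object()})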
@@ -531,27 +582,55 @@
                         response_tool_call.function.arguments or ""
                     ) + tool_call.function.arguments
 
-        # Emit function_call_arguments.done events for completed tool calls
+        # Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
         for tool_call_index in sorted(chat_response_tool_calls.keys()):
             tool_call_item_id = tool_call_item_ids[tool_call_index]
             final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
+            tool_call_name = chat_response_tool_calls[tool_call_index].function.name
+
+            # Check if this is an MCP tool call
+            is_mcp_tool = ctx.mcp_tool_to_server and tool_call_name and tool_call_name in ctx.mcp_tool_to_server
             sequence_number += 1
-            yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(
-                arguments=final_arguments,
-                item_id=tool_call_item_id,
-                output_index=len(output_messages),
-                sequence_number=sequence_number,
-            )
+            if is_mcp_tool:
+                # Emit MCP-specific argument done event
+                yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(
+                    arguments=final_arguments,
+                    item_id=tool_call_item_id,
+                    output_index=len(output_messages),
+                    sequence_number=sequence_number,
+                )
+            else:
+                # Emit function call argument done event
+                yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(
+                    arguments=final_arguments,
+                    item_id=tool_call_item_id,
+                    output_index=len(output_messages),
+                    sequence_number=sequence_number,
+                )
 
         # Convert collected chunks to complete response
         if chat_response_tool_calls:
             tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
-
-            # when there are tool calls, we need to clear the content
-            chat_response_content = []
         else:
             tool_calls = None
 
+        # Emit content_part.done event if text content was streamed (before content gets cleared)
+        if content_part_emitted:
+            final_text = "".join(chat_response_content)
+            sequence_number += 1
+            yield OpenAIResponseObjectStreamResponseContentPartDone(
+                response_id=response_id,
+                item_id=message_item_id,
+                part=OpenAIResponseContentPartOutputText(
+                    text=final_text,
+                ),
+                sequence_number=sequence_number,
+            )
+
+        # Clear content when there are tool calls (OpenAI spec behavior)
+        if chat_response_tool_calls:
+            chat_response_content = []
+
         assistant_message = OpenAIAssistantMessageParam(
             content="".join(chat_response_content),
             tool_calls=tool_calls,
@@ -587,19 +666,38 @@
 
         # execute non-function tool calls
         for tool_call in non_function_tool_calls:
-            tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
+            # Find the item_id for this tool call
+            matching_item_id = None
+            for index, item_id in tool_call_item_ids.items():
+                response_tool_call = chat_response_tool_calls.get(index)
+                if response_tool_call and response_tool_call.id == tool_call.id:
+                    matching_item_id = item_id
+                    break
+
+            # Use a fallback item_id if not found
+            if not matching_item_id:
+                matching_item_id = f"tc_{uuid.uuid4()}"
+
+            # Execute tool call with streaming
+            tool_call_log = None
+            tool_response_message = None
+            async for result in self._execute_tool_call(
+                tool_call, ctx, sequence_number, response_id, len(output_messages), matching_item_id
+            ):
+                if result.stream_event:
+                    # Forward streaming events
+                    sequence_number = result.sequence_number
+                    yield result.stream_event
+
+                if result.final_output_message is not None:
+                    tool_call_log = result.final_output_message
+                    tool_response_message = result.final_input_message
+                    sequence_number = result.sequence_number
+
             if tool_call_log:
                 output_messages.append(tool_call_log)
 
                 # Emit output_item.done event for completed non-function tool call
-                # Find the item_id for this tool call
-                matching_item_id = None
-                for index, item_id in tool_call_item_ids.items():
-                    response_tool_call = chat_response_tool_calls.get(index)
-                    if response_tool_call and response_tool_call.id == tool_call.id:
-                        matching_item_id = item_id
-                        break
-
                 if matching_item_id:
                     sequence_number += 1
                     yield OpenAIResponseObjectStreamResponseOutputItemDone(
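The item_id lookup above correlates a finished tool call with the id announced when its first streamed chunk arrived. A reduced, self-contained sketch of that lookup; Call is a hypothetical stand-in for the accumulated chat-completion tool call, and the tc_ fallback mirrors the diff:

import uuid
from dataclasses import dataclass


@dataclass
class Call:
    # Hypothetical stand-in for an accumulated chat-completion tool call.
    id: str


def find_item_id(call_id: str, item_ids: dict[int, str], calls: dict[int, Call]) -> str:
    # Map a finished tool call back to the item_id announced when its first
    # streamed chunk arrived; mint a fresh id if no chunk was ever correlated.
    for index, item_id in item_ids.items():
        call = calls.get(index)
        if call and call.id == call_id:
            return item_id
    return f"tc_{uuid.uuid4()}"


assert find_item_id("a", {0: "item_0"}, {0: Call(id="a")}) == "item_0"
assert find_item_id("b", {0: "item_0"}, {0: Call(id="a")}).startswith("tc_")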
@@ -848,7 +946,11 @@
         self,
         tool_call: OpenAIChatCompletionToolCall,
         ctx: ChatCompletionContext,
-    ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
+        sequence_number: int,
+        response_id: str,
+        output_index: int,
+        item_id: str,
+    ) -> AsyncIterator[ToolExecutionResult]:
         from llama_stack.providers.utils.inference.prompt_adapter import (
             interleaved_content_as_str,
         )
@@ -858,8 +960,41 @@
         tool_kwargs = json.loads(function.arguments) if function.arguments else {}
 
         if not function or not tool_call_id or not function.name:
-            return None, None
+            yield ToolExecutionResult(sequence_number=sequence_number)
+            return
+
+        # Emit in_progress event based on tool type (only for tools with specific streaming events)
+        progress_event = None
+        if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
+            sequence_number += 1
+            progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+        elif function.name == "web_search":
+            sequence_number += 1
+            progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+        # Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
+
+        if progress_event:
+            yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
+
+        # For web search, emit searching event
+        if function.name == "web_search":
+            sequence_number += 1
+            searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+            yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
 
+        # Execute the actual tool call
         error_exc = None
         result = None
         try:
@@ -894,6 +1029,33 @@
         except Exception as e:
             error_exc = e
 
+        # Emit completion or failure event based on result (only for tools with specific streaming events)
+        has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
+        completion_event = None
+
+        if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
+            sequence_number += 1
+            if has_error:
+                completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
+                    sequence_number=sequence_number,
+                )
+            else:
+                completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
+                    sequence_number=sequence_number,
+                )
+        elif function.name == "web_search":
+            sequence_number += 1
+            completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+        # Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
+
+        if completion_event:
+            yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
+
+        # Build the result message and input message
         if function.name in ctx.mcp_tool_to_server:
             from llama_stack.apis.agents.openai_responses import (
                 OpenAIResponseOutputMessageMCPCall,
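has_error is doing None-safe triage: result can be None when the call raised, and error_code or error_message may be unset on success. A self-contained restatement of the same expression, with a hypothetical ToolResult stand-in:

from dataclasses import dataclass


@dataclass
class ToolResult:
    # Hypothetical stand-in for the provider's tool invocation result.
    error_code: int | None = None
    error_message: str | None = None


def has_error(error_exc: Exception | None, result: ToolResult | None) -> bool:
    # Mirrors the diff: an exception, a positive error_code, or a non-empty
    # error_message all count as failure; result may be None after an exception.
    return bool(error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)))


assert has_error(RuntimeError("boom"), None)
assert has_error(None, ToolResult(error_code=1))
assert has_error(None, ToolResult(error_message="denied"))
assert not has_error(None, ToolResult())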
@@ -907,9 +1069,9 @@
             )
             if error_exc:
                 message.error = str(error_exc)
-            elif (result.error_code and result.error_code > 0) or result.error_message:
+            elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
                 message.error = f"Error (code {result.error_code}): {result.error_message}"
-            elif result.content:
+            elif result and result.content:
                 message.output = interleaved_content_as_str(result.content)
             else:
                 if function.name == "web_search":
@@ -917,7 +1079,7 @@
                 id=tool_call_id,
                 status="completed",
             )
-            if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+            if has_error:
                 message.status = "failed"
         elif function.name == "knowledge_search":
             message = OpenAIResponseOutputMessageFileSearchToolCall(
@@ -925,7 +1087,7 @@
                 queries=[tool_kwargs.get("query", "")],
                 status="completed",
             )
-            if "document_ids" in result.metadata:
+            if result and "document_ids" in result.metadata:
                 message.results = []
                 for i, doc_id in enumerate(result.metadata["document_ids"]):
                     text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
@@ -939,7 +1101,7 @@
                             attributes={},
                         )
                     )
-            if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+            if has_error:
                 message.status = "failed"
         else:
             raise ValueError(f"Unknown tool {function.name} called")
@@ -971,10 +1133,13 @@
                 raise ValueError(f"Unknown result content type: {type(result.content)}")
             input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
         else:
-            text = str(error_exc)
+            text = str(error_exc) if error_exc else "Tool execution failed"
             input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
 
-        return message, input_message
+        # Yield the final result
+        yield ToolExecutionResult(
+            sequence_number=sequence_number, final_output_message=message, final_input_message=input_message
+        )
 
 
     def _is_function_tool_call(