mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 14:08:00 +00:00
feat(responses): stream progress of tool calls
This commit is contained in:
parent
5b312a80b9
commit
8159a9d757
2 changed files with 141 additions and 18 deletions
|
@ -35,9 +35,15 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
|
||||
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
|
||||
OpenAIResponseObjectStreamResponseMcpCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseMcpCallFailed,
|
||||
OpenAIResponseObjectStreamResponseMcpCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseOutputItemAdded,
|
||||
OpenAIResponseObjectStreamResponseOutputItemDone,
|
||||
OpenAIResponseObjectStreamResponseOutputTextDelta,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
|
@ -87,6 +93,15 @@ logger = get_logger(name=__name__, category="openai_responses")
|
|||
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
||||
|
||||
|
||||
class ToolExecutionResult(BaseModel):
|
||||
"""Result of streaming tool execution."""
|
||||
|
||||
stream_event: OpenAIResponseObjectStream | None = None
|
||||
sequence_number: int
|
||||
final_output_message: OpenAIResponseOutput | None = None
|
||||
final_input_message: OpenAIMessageParam | None = None
|
||||
|
||||
|
||||
async def _convert_response_content_to_chat_content(
|
||||
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
|
@ -587,19 +602,38 @@ class OpenAIResponsesImpl:
|
|||
|
||||
# execute non-function tool calls
|
||||
for tool_call in non_function_tool_calls:
|
||||
tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
|
||||
# Find the item_id for this tool call
|
||||
matching_item_id = None
|
||||
for index, item_id in tool_call_item_ids.items():
|
||||
response_tool_call = chat_response_tool_calls.get(index)
|
||||
if response_tool_call and response_tool_call.id == tool_call.id:
|
||||
matching_item_id = item_id
|
||||
break
|
||||
|
||||
# Use a fallback item_id if not found
|
||||
if not matching_item_id:
|
||||
matching_item_id = f"tc_{uuid.uuid4()}"
|
||||
|
||||
# Execute tool call with streaming
|
||||
tool_call_log = None
|
||||
tool_response_message = None
|
||||
async for result in self._execute_tool_call(
|
||||
tool_call, ctx, sequence_number, response_id, len(output_messages), matching_item_id
|
||||
):
|
||||
if result.stream_event:
|
||||
# Forward streaming events
|
||||
sequence_number = result.sequence_number
|
||||
yield result.stream_event
|
||||
|
||||
if result.final_output_message is not None:
|
||||
tool_call_log = result.final_output_message
|
||||
tool_response_message = result.final_input_message
|
||||
sequence_number = result.sequence_number
|
||||
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
|
||||
# Emit output_item.done event for completed non-function tool call
|
||||
# Find the item_id for this tool call
|
||||
matching_item_id = None
|
||||
for index, item_id in tool_call_item_ids.items():
|
||||
response_tool_call = chat_response_tool_calls.get(index)
|
||||
if response_tool_call and response_tool_call.id == tool_call.id:
|
||||
matching_item_id = item_id
|
||||
break
|
||||
|
||||
if matching_item_id:
|
||||
sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
|
@ -848,7 +882,11 @@ class OpenAIResponsesImpl:
|
|||
self,
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
ctx: ChatCompletionContext,
|
||||
) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
|
||||
sequence_number: int,
|
||||
response_id: str,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
@ -858,8 +896,41 @@ class OpenAIResponsesImpl:
|
|||
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
||||
|
||||
if not function or not tool_call_id or not function.name:
|
||||
return None, None
|
||||
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||
return
|
||||
|
||||
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||
progress_event = None
|
||||
if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function.name == "web_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
|
||||
|
||||
if progress_event:
|
||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
||||
|
||||
# For web search, emit searching event
|
||||
if function.name == "web_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
# Execute the actual tool call
|
||||
error_exc = None
|
||||
result = None
|
||||
try:
|
||||
|
@ -894,6 +965,33 @@ class OpenAIResponsesImpl:
|
|||
except Exception as e:
|
||||
error_exc = e
|
||||
|
||||
# Emit completion or failure event based on result (only for tools with specific streaming events)
|
||||
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
||||
completion_event = None
|
||||
|
||||
if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
if has_error:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
else:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function.name == "web_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
|
||||
|
||||
if completion_event:
|
||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
||||
|
||||
# Build the result message and input message
|
||||
if function.name in ctx.mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
|
@ -907,9 +1005,9 @@ class OpenAIResponsesImpl:
|
|||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
elif (result.error_code and result.error_code > 0) or result.error_message:
|
||||
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||
elif result.content:
|
||||
elif result and result.content:
|
||||
message.output = interleaved_content_as_str(result.content)
|
||||
else:
|
||||
if function.name == "web_search":
|
||||
|
@ -917,7 +1015,7 @@ class OpenAIResponsesImpl:
|
|||
id=tool_call_id,
|
||||
status="completed",
|
||||
)
|
||||
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
elif function.name == "knowledge_search":
|
||||
message = OpenAIResponseOutputMessageFileSearchToolCall(
|
||||
|
@ -925,7 +1023,7 @@ class OpenAIResponsesImpl:
|
|||
queries=[tool_kwargs.get("query", "")],
|
||||
status="completed",
|
||||
)
|
||||
if "document_ids" in result.metadata:
|
||||
if result and "document_ids" in result.metadata:
|
||||
message.results = []
|
||||
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
||||
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
||||
|
@ -939,7 +1037,7 @@ class OpenAIResponsesImpl:
|
|||
attributes={},
|
||||
)
|
||||
)
|
||||
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
else:
|
||||
raise ValueError(f"Unknown tool {function.name} called")
|
||||
|
@ -971,10 +1069,13 @@ class OpenAIResponsesImpl:
|
|||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
else:
|
||||
text = str(error_exc)
|
||||
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||
|
||||
return message, input_message
|
||||
# Yield the final result
|
||||
yield ToolExecutionResult(
|
||||
sequence_number=sequence_number, final_output_message=message, final_input_message=input_message
|
||||
)
|
||||
|
||||
|
||||
def _is_function_tool_call(
|
||||
|
|
|
@ -598,6 +598,10 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
|
|||
item_added_events = [chunk for chunk in chunks if chunk.type == "response.output_item.added"]
|
||||
item_done_events = [chunk for chunk in chunks if chunk.type == "response.output_item.done"]
|
||||
|
||||
# Should have tool execution progress events
|
||||
mcp_in_progress_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.in_progress"]
|
||||
mcp_completed_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.completed"]
|
||||
|
||||
# Verify we have substantial streaming activity (not just batch events)
|
||||
assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks"
|
||||
|
||||
|
@ -609,6 +613,24 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
|
|||
assert len(item_added_events) > 0, f"Expected response.output_item.added events, got chunk types: {chunk_types}"
|
||||
assert len(item_done_events) > 0, f"Expected response.output_item.done events, got chunk types: {chunk_types}"
|
||||
|
||||
# Should have tool execution progress events
|
||||
assert len(mcp_in_progress_events) > 0, (
|
||||
f"Expected response.mcp_call.in_progress events, got chunk types: {chunk_types}"
|
||||
)
|
||||
assert len(mcp_completed_events) > 0, (
|
||||
f"Expected response.mcp_call.completed events, got chunk types: {chunk_types}"
|
||||
)
|
||||
# MCP failed events are optional (only if errors occur)
|
||||
|
||||
# Verify progress events have proper structure
|
||||
for progress_event in mcp_in_progress_events:
|
||||
assert hasattr(progress_event, "item_id"), "Progress event should have 'item_id' field"
|
||||
assert hasattr(progress_event, "output_index"), "Progress event should have 'output_index' field"
|
||||
assert hasattr(progress_event, "sequence_number"), "Progress event should have 'sequence_number' field"
|
||||
|
||||
for completed_event in mcp_completed_events:
|
||||
assert hasattr(completed_event, "sequence_number"), "Completed event should have 'sequence_number' field"
|
||||
|
||||
# Verify delta events have proper structure
|
||||
for delta_event in delta_events:
|
||||
assert hasattr(delta_event, "delta"), "Delta event should have 'delta' field"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue