Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)

feat(responses): stream progress of tool calls

Parent: 5b312a80b9
Commit: 8159a9d757
2 changed files with 141 additions and 18 deletions
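Summary of the change, as reconstructed from the hunks below: executing a non-function tool call (an MCP tool or web_search) now streams progress to the client instead of completing silently. `_execute_tool_call` changes from a coroutine returning an (output message, input message) tuple into an async generator yielding `ToolExecutionResult` envelopes: progress events first (`mcp_call.in_progress`, or `web_search_call.in_progress` and `web_search_call.searching`), then a completion or failure event, then the final message pair.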
@@ -35,9 +35,15 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObjectStreamResponseCreated,
     OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
     OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
+    OpenAIResponseObjectStreamResponseMcpCallCompleted,
+    OpenAIResponseObjectStreamResponseMcpCallFailed,
+    OpenAIResponseObjectStreamResponseMcpCallInProgress,
     OpenAIResponseObjectStreamResponseOutputItemAdded,
     OpenAIResponseObjectStreamResponseOutputItemDone,
     OpenAIResponseObjectStreamResponseOutputTextDelta,
+    OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
+    OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
+    OpenAIResponseObjectStreamResponseWebSearchCallSearching,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
@@ -87,6 +93,15 @@ logger = get_logger(name=__name__, category="openai_responses")

 OPENAI_RESPONSES_PREFIX = "openai_responses:"


+class ToolExecutionResult(BaseModel):
+    """Result of streaming tool execution."""
+
+    stream_event: OpenAIResponseObjectStream | None = None
+    sequence_number: int
+    final_output_message: OpenAIResponseOutput | None = None
+    final_input_message: OpenAIMessageParam | None = None
+
+
 async def _convert_response_content_to_chat_content(
     content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
 ) -> str | list[OpenAIChatCompletionContentPartParam]:
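`ToolExecutionResult` is the envelope for the new streaming contract: each yielded instance carries either a stream event to forward or the final message pair, always alongside the current sequence number. For orientation, a minimal runnable sketch of this "yield events, then yield the final result" shape, using simplified stand-in types rather than the real llama-stack models:

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator


@dataclass
class ToolResult:
    # Stand-ins: stream_event would be an OpenAIResponseObjectStream model
    # and final_output an OpenAIResponseOutput in the real code.
    sequence_number: int
    stream_event: str | None = None
    final_output: str | None = None


async def execute_tool(sequence_number: int) -> AsyncIterator[ToolResult]:
    sequence_number += 1
    yield ToolResult(sequence_number, stream_event="response.mcp_call.in_progress")
    await asyncio.sleep(0)  # the actual tool invocation would happen here
    sequence_number += 1
    yield ToolResult(sequence_number, stream_event="response.mcp_call.completed")
    # The final envelope carries no stream event, only the result payload.
    yield ToolResult(sequence_number, final_output="tool output message")


async def main() -> None:
    sequence_number = 0
    async for result in execute_tool(sequence_number):
        if result.stream_event:
            sequence_number = result.sequence_number  # keep numbering monotonic
            print("forward:", result.stream_event)
        if result.final_output is not None:
            print("final:", result.final_output)


asyncio.run(main())
```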
@@ -587,19 +602,38 @@ class OpenAIResponsesImpl:

         # execute non-function tool calls
         for tool_call in non_function_tool_calls:
-            tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
+            # Find the item_id for this tool call
+            matching_item_id = None
+            for index, item_id in tool_call_item_ids.items():
+                response_tool_call = chat_response_tool_calls.get(index)
+                if response_tool_call and response_tool_call.id == tool_call.id:
+                    matching_item_id = item_id
+                    break
+
+            # Use a fallback item_id if not found
+            if not matching_item_id:
+                matching_item_id = f"tc_{uuid.uuid4()}"
+
+            # Execute tool call with streaming
+            tool_call_log = None
+            tool_response_message = None
+            async for result in self._execute_tool_call(
+                tool_call, ctx, sequence_number, response_id, len(output_messages), matching_item_id
+            ):
+                if result.stream_event:
+                    # Forward streaming events
+                    sequence_number = result.sequence_number
+                    yield result.stream_event
+
+                if result.final_output_message is not None:
+                    tool_call_log = result.final_output_message
+                    tool_response_message = result.final_input_message
+                    sequence_number = result.sequence_number
+
             if tool_call_log:
                 output_messages.append(tool_call_log)

             # Emit output_item.done event for completed non-function tool call
-            # Find the item_id for this tool call
-            matching_item_id = None
-            for index, item_id in tool_call_item_ids.items():
-                response_tool_call = chat_response_tool_calls.get(index)
-                if response_tool_call and response_tool_call.id == tool_call.id:
-                    matching_item_id = item_id
-                    break
-
             if matching_item_id:
                 sequence_number += 1
                 yield OpenAIResponseObjectStreamResponseOutputItemDone(
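Design note: the item_id lookup moves ahead of tool execution so the in_progress/searching events can reference the same output item the client already saw in `response.output_item.added`; the uuid fallback guarantees an id even for a call that was never registered during argument streaming. A runnable sketch of the lookup-with-fallback (`resolve_item_id` and `ToolCall` are illustrative names, not from the codebase):

```python
import uuid
from dataclasses import dataclass


@dataclass
class ToolCall:  # simplified stand-in for OpenAIChatCompletionToolCall
    id: str


def resolve_item_id(
    tool_call_id: str,
    tool_call_item_ids: dict[int, str],
    chat_response_tool_calls: dict[int, ToolCall],
) -> str:
    """Return the item_id registered for this tool call, or a fresh fallback."""
    for index, item_id in tool_call_item_ids.items():
        tc = chat_response_tool_calls.get(index)
        if tc and tc.id == tool_call_id:
            return item_id
    # The fallback keeps the in_progress/done events well-formed even when
    # the call was never registered while arguments were streaming.
    return f"tc_{uuid.uuid4()}"


print(resolve_item_id("call_1", {0: "item_0"}, {0: ToolCall(id="call_1")}))  # item_0
print(resolve_item_id("call_2", {0: "item_0"}, {0: ToolCall(id="call_1")}))  # tc_<uuid>
```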
@@ -848,7 +882,11 @@ class OpenAIResponsesImpl:
         self,
         tool_call: OpenAIChatCompletionToolCall,
         ctx: ChatCompletionContext,
-    ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
+        sequence_number: int,
+        response_id: str,
+        output_index: int,
+        item_id: str,
+    ) -> AsyncIterator[ToolExecutionResult]:
         from llama_stack.providers.utils.inference.prompt_adapter import (
             interleaved_content_as_str,
         )
@@ -858,8 +896,41 @@ class OpenAIResponsesImpl:
         tool_kwargs = json.loads(function.arguments) if function.arguments else {}

         if not function or not tool_call_id or not function.name:
-            return None, None
-
+            yield ToolExecutionResult(sequence_number=sequence_number)
+            return
+
+        # Emit in_progress event based on tool type (only for tools with specific streaming events)
+        progress_event = None
+        if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
+            sequence_number += 1
+            progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+        elif function.name == "web_search":
+            sequence_number += 1
+            progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+        # Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
+
+        if progress_event:
+            yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
+
+        # For web search, emit searching event
+        if function.name == "web_search":
+            sequence_number += 1
+            searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+            yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
+
+        # Execute the actual tool call
         error_exc = None
         result = None
         try:
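Why the hunk above uses `yield ToolExecutionResult(...)` followed by a bare `return` instead of the old `return None, None`: a Python async generator may not return a value (that is a SyntaxError), so the early exit has to be signaled by a sentinel yield. A minimal demonstration:

```python
import asyncio
from typing import AsyncIterator


async def gen(valid: bool) -> AsyncIterator[int]:
    if not valid:
        yield -1  # sentinel carrying state back to the caller
        return    # bare return only; `return -1` would not compile here
    yield 1
    yield 2


async def main() -> None:
    print([x async for x in gen(False)])  # [-1]
    print([x async for x in gen(True)])   # [1, 2]


asyncio.run(main())
```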
@@ -894,6 +965,33 @@ class OpenAIResponsesImpl:
         except Exception as e:
             error_exc = e

+        # Emit completion or failure event based on result (only for tools with specific streaming events)
+        has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
+        completion_event = None
+
+        if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
+            sequence_number += 1
+            if has_error:
+                completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
+                    sequence_number=sequence_number,
+                )
+            else:
+                completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
+                    sequence_number=sequence_number,
+                )
+        elif function.name == "web_search":
+            sequence_number += 1
+            completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
+                item_id=item_id,
+                output_index=output_index,
+                sequence_number=sequence_number,
+            )
+        # Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
+
+        if completion_event:
+            yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
+
+        # Build the result message and input message
 if function.name in ctx.mcp_tool_to_server:
             from llama_stack.apis.agents.openai_responses import (
                 OpenAIResponseOutputMessageMCPCall,
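The `has_error` expression folds the previously duplicated error checks into one place and, unlike the old inline checks, guards every attribute access on `result`, which is None when the tool raised before producing output. A self-contained sketch of the same logic (`ToolInvocationResult` here is a simplified stand-in, not the real model):

```python
from dataclasses import dataclass


@dataclass
class ToolInvocationResult:  # stand-in for the real tool result type
    error_code: int | None = None
    error_message: str | None = None


def compute_has_error(error_exc: Exception | None, result: ToolInvocationResult | None) -> bool:
    # Mirrors: error_exc or (result and ((result.error_code and result.error_code > 0)
    #                                    or result.error_message))
    return bool(error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)))


assert compute_has_error(RuntimeError("boom"), None)                      # exception raised
assert compute_has_error(None, ToolInvocationResult(error_code=1))        # nonzero error code
assert compute_has_error(None, ToolInvocationResult(error_message="bad")) # error message set
assert not compute_has_error(None, ToolInvocationResult())                # clean result
assert not compute_has_error(None, None)                                  # no result, no exception
```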
@@ -907,9 +1005,9 @@ class OpenAIResponsesImpl:
             )
             if error_exc:
                 message.error = str(error_exc)
-            elif (result.error_code and result.error_code > 0) or result.error_message:
+            elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
                 message.error = f"Error (code {result.error_code}): {result.error_message}"
-            elif result.content:
+            elif result and result.content:
                 message.output = interleaved_content_as_str(result.content)
         else:
             if function.name == "web_search":
@@ -917,7 +1015,7 @@ class OpenAIResponsesImpl:
                     id=tool_call_id,
                     status="completed",
                 )
-                if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                if has_error:
                     message.status = "failed"
             elif function.name == "knowledge_search":
                 message = OpenAIResponseOutputMessageFileSearchToolCall(
@@ -925,7 +1023,7 @@ class OpenAIResponsesImpl:
                     queries=[tool_kwargs.get("query", "")],
                     status="completed",
                 )
-                if "document_ids" in result.metadata:
+                if result and "document_ids" in result.metadata:
                     message.results = []
                     for i, doc_id in enumerate(result.metadata["document_ids"]):
                         text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
@@ -939,7 +1037,7 @@ class OpenAIResponsesImpl:
                             attributes={},
                         )
                     )
-                if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                if has_error:
                     message.status = "failed"
             else:
                 raise ValueError(f"Unknown tool {function.name} called")
@@ -971,10 +1069,13 @@ class OpenAIResponsesImpl:
                 raise ValueError(f"Unknown result content type: {type(result.content)}")
             input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
         else:
-            text = str(error_exc)
+            text = str(error_exc) if error_exc else "Tool execution failed"
             input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)

-        return message, input_message
+        # Yield the final result
+        yield ToolExecutionResult(
+            sequence_number=sequence_number, final_output_message=message, final_input_message=input_message
+        )


 def _is_function_tool_call(
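The remaining hunks are in the second changed file, the streaming multi-turn tool-execution integration test, which now asserts that the new MCP progress and completion events actually appear in the stream.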
@@ -598,6 +598,10 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
     item_added_events = [chunk for chunk in chunks if chunk.type == "response.output_item.added"]
     item_done_events = [chunk for chunk in chunks if chunk.type == "response.output_item.done"]

+    # Should have tool execution progress events
+    mcp_in_progress_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.in_progress"]
+    mcp_completed_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.completed"]
+
     # Verify we have substantial streaming activity (not just batch events)
     assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks"
@@ -609,6 +613,24 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
     assert len(item_added_events) > 0, f"Expected response.output_item.added events, got chunk types: {chunk_types}"
     assert len(item_done_events) > 0, f"Expected response.output_item.done events, got chunk types: {chunk_types}"

+    # Should have tool execution progress events
+    assert len(mcp_in_progress_events) > 0, (
+        f"Expected response.mcp_call.in_progress events, got chunk types: {chunk_types}"
+    )
+    assert len(mcp_completed_events) > 0, (
+        f"Expected response.mcp_call.completed events, got chunk types: {chunk_types}"
+    )
+    # MCP failed events are optional (only if errors occur)
+
+    # Verify progress events have proper structure
+    for progress_event in mcp_in_progress_events:
+        assert hasattr(progress_event, "item_id"), "Progress event should have 'item_id' field"
+        assert hasattr(progress_event, "output_index"), "Progress event should have 'output_index' field"
+        assert hasattr(progress_event, "sequence_number"), "Progress event should have 'sequence_number' field"
+
+    for completed_event in mcp_completed_events:
+        assert hasattr(completed_event, "sequence_number"), "Completed event should have 'sequence_number' field"
+
     # Verify delta events have proper structure
     for delta_event in delta_events:
         assert hasattr(delta_event, "delta"), "Delta event should have 'delta' field"
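For readers unfamiliar with the test's filtering style, a self-contained sketch of the same assertions run against hand-built chunks (SimpleNamespace objects stand in for real stream events; the data is purely illustrative):

```python
from collections import Counter
from types import SimpleNamespace

# Hand-built chunks; the real ones come from compat_client's streamed response.
chunks = [
    SimpleNamespace(type="response.output_item.added"),
    SimpleNamespace(type="response.mcp_call.in_progress",
                    item_id="tc_1", output_index=0, sequence_number=3),
    SimpleNamespace(type="response.mcp_call.completed", sequence_number=5),
    SimpleNamespace(type="response.output_item.done"),
]

chunk_types = Counter(chunk.type for chunk in chunks)
assert chunk_types["response.mcp_call.in_progress"] > 0
assert chunk_types["response.mcp_call.completed"] > 0
for chunk in chunks:
    if chunk.type == "response.mcp_call.in_progress":
        assert hasattr(chunk, "item_id") and hasattr(chunk, "sequence_number")
print("ok")
```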