diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index f338aeea0..dbfe65960 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -7283,6 +7283,9 @@
"items": {
"$ref": "#/components/schemas/OpenAIResponseInputTool"
}
+ },
+ "max_infer_iters": {
+ "type": "integer"
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index a87c6a80b..c185488b4 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -5149,6 +5149,8 @@ components:
type: array
items:
$ref: '#/components/schemas/OpenAIResponseInputTool'
+ max_infer_iters:
+ type: integer
additionalProperties: false
required:
- input
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index b79c512b8..956f4a614 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -604,6 +604,7 @@ class Agents(Protocol):
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
+ max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a new OpenAI response.
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index bcbfcbe31..854f8b285 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -325,9 +325,10 @@ class MetaReferenceAgentsImpl(Agents):
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
+ max_infer_iters: int | None = 10,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
- input, model, instructions, previous_response_id, store, stream, temperature, tools
+ input, model, instructions, previous_response_id, store, stream, temperature, tools, max_infer_iters
)
async def list_openai_responses(
diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
index 19d7ea56f..f4f1bac43 100644
--- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
@@ -258,6 +258,18 @@ class OpenAIResponsesImpl:
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
+ def _is_function_tool_call(
+ self,
+ tool_call: OpenAIChatCompletionToolCall,
+ tools: list[OpenAIResponseInputTool],
+ ) -> bool:
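+        # Treat a call as a "function" tool call only if it names a function-type tool from the request's tools list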
+ if not tool_call.function:
+ return False
+ for t in tools:
+ if t.type == "function" and t.name == tool_call.function.name:
+ return True
+ return False
+
async def _process_response_choices(
self,
chat_response: OpenAIChatCompletion,
@@ -270,7 +282,7 @@ class OpenAIResponsesImpl:
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions
- if tools[0].type == "function":
+ if self._is_function_tool_call(choice.message.tool_calls[0], tools):
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
@@ -332,6 +344,7 @@ class OpenAIResponsesImpl:
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
+ max_infer_iters: int | None = 10,
):
stream = False if stream is None else stream
@@ -358,58 +371,100 @@ class OpenAIResponsesImpl:
temperature=temperature,
)
- inference_result = await self.inference_api.openai_chat_completion(
- model=model,
- messages=messages,
- tools=chat_tools,
- stream=stream,
- temperature=temperature,
- )
-
+ # Fork to streaming vs non-streaming - let each handle ALL inference rounds
if stream:
return self._create_streaming_response(
- inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
+ max_infer_iters=max_infer_iters,
)
else:
return await self._create_non_streaming_response(
- inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
+ max_infer_iters=max_infer_iters,
)
async def _create_non_streaming_response(
self,
- inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
+ max_infer_iters: int | None,
) -> OpenAIResponseObject:
- chat_response = OpenAIChatCompletion(**inference_result.model_dump())
+ # Implement tool execution loop - handle ALL inference rounds including the first
+ n_iter = 0
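+        # Work on a copy of the context messages so tool results can be appended between rounds without mutating ctx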
+ messages = ctx.messages.copy()
+ current_response = None
- # Process response choices (tool execution and message creation)
- output_messages.extend(
- await self._process_response_choices(
- chat_response=chat_response,
- ctx=ctx,
- tools=tools,
+ while True:
+ # Do inference (including the first one)
+ inference_result = await self.inference_api.openai_chat_completion(
+ model=ctx.model,
+ messages=messages,
+ tools=ctx.tools,
+ stream=False,
+ temperature=ctx.temperature,
)
- )
+ current_response = OpenAIChatCompletion(**inference_result.model_dump())
+
+ # Separate function vs non-function tool calls
+ function_tool_calls = []
+ non_function_tool_calls = []
+
+ for choice in current_response.choices:
+ if choice.message.tool_calls and tools:
+ for tool_call in choice.message.tool_calls:
+ if self._is_function_tool_call(tool_call, tools):
+ function_tool_calls.append(tool_call)
+ else:
+ non_function_tool_calls.append(tool_call)
+
+ # Process response choices based on tool call types
+ if function_tool_calls:
+                # Function tool calls are returned to the caller rather than executed here, so emit them and stop iterating
+ current_output_messages = await self._process_response_choices(
+ chat_response=current_response,
+ ctx=ctx,
+ tools=tools,
+ )
+ output_messages.extend(current_output_messages)
+ break
+ elif non_function_tool_calls:
+ # For non-function tool calls, execute them and continue loop
+ for choice in current_response.choices:
+ tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
+ output_messages.extend(tool_outputs)
+
+ # Add assistant message and tool responses to messages for next iteration
+ messages.append(choice.message)
+ messages.extend(tool_response_messages)
+
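+                # Stop once the configured number of inference rounds is reached so tool calling cannot loop indefinitely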
+ n_iter += 1
+ if n_iter >= (max_infer_iters or 10):
+ break
+
+ # Continue with next iteration of the loop
+ continue
+ else:
+ # No tool calls - convert response to message and we're done
+ for choice in current_response.choices:
+ output_messages.append(await _convert_chat_choice_to_response_message(choice))
+ break
response = OpenAIResponseObject(
- created_at=chat_response.created,
+ created_at=current_response.created,
id=f"resp-{uuid.uuid4()}",
model=model,
object="response",
@@ -429,13 +484,13 @@ class OpenAIResponsesImpl:
async def _create_streaming_response(
self,
- inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
+ max_infer_iters: int | None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
@@ -453,87 +508,135 @@ class OpenAIResponsesImpl:
# Emit response.created immediately
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
- # For streaming, inference_result is an async iterator of chunks
- # Stream chunks and emit delta events as they arrive
- chat_response_id = ""
- chat_response_content = []
- chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
- chunk_created = 0
- chunk_model = ""
- chunk_finish_reason = ""
- sequence_number = 0
+ # Implement tool execution loop for streaming - handle ALL inference rounds including the first
+ n_iter = 0
+ messages = ctx.messages.copy()
- # Create a placeholder message item for delta events
- message_item_id = f"msg_{uuid.uuid4()}"
-
- async for chunk in inference_result:
- chat_response_id = chunk.id
- chunk_created = chunk.created
- chunk_model = chunk.model
- for chunk_choice in chunk.choices:
- # Emit incremental text content as delta events
- if chunk_choice.delta.content:
- sequence_number += 1
- yield OpenAIResponseObjectStreamResponseOutputTextDelta(
- content_index=0,
- delta=chunk_choice.delta.content,
- item_id=message_item_id,
- output_index=0,
- sequence_number=sequence_number,
- )
-
- # Collect content for final response
- chat_response_content.append(chunk_choice.delta.content or "")
- if chunk_choice.finish_reason:
- chunk_finish_reason = chunk_choice.finish_reason
-
- # Aggregate tool call arguments across chunks, using their index as the aggregation key
- if chunk_choice.delta.tool_calls:
- for tool_call in chunk_choice.delta.tool_calls:
- response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
- if response_tool_call:
- # Don't attempt to concatenate arguments if we don't have any new arguments
- if tool_call.function.arguments:
- # Guard against an initial None argument before we concatenate
- response_tool_call.function.arguments = (
- response_tool_call.function.arguments or ""
- ) + tool_call.function.arguments
- else:
- tool_call_dict: dict[str, Any] = tool_call.model_dump()
- tool_call_dict.pop("type", None)
- response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
- chat_response_tool_calls[tool_call.index] = response_tool_call
-
- # Convert collected chunks to complete response
- if chat_response_tool_calls:
- tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
- else:
- tool_calls = None
- assistant_message = OpenAIAssistantMessageParam(
- content="".join(chat_response_content),
- tool_calls=tool_calls,
- )
- chat_response_obj = OpenAIChatCompletion(
- id=chat_response_id,
- choices=[
- OpenAIChoice(
- message=assistant_message,
- finish_reason=chunk_finish_reason,
- index=0,
- )
- ],
- created=chunk_created,
- model=chunk_model,
- )
-
- # Process response choices (tool execution and message creation)
- output_messages.extend(
- await self._process_response_choices(
- chat_response=chat_response_obj,
- ctx=ctx,
- tools=tools,
+ while True:
+ # Do inference (including the first one) - streaming
+ current_inference_result = await self.inference_api.openai_chat_completion(
+ model=ctx.model,
+ messages=messages,
+ tools=ctx.tools,
+ stream=True,
+ temperature=ctx.temperature,
)
- )
+
+ # Process streaming chunks and build complete response
+ chat_response_id = ""
+ chat_response_content = []
+ chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
+ chunk_created = 0
+ chunk_model = ""
+ chunk_finish_reason = ""
+ sequence_number = 0
+
+ # Create a placeholder message item for delta events
+ message_item_id = f"msg_{uuid.uuid4()}"
+
+ async for chunk in current_inference_result:
+ chat_response_id = chunk.id
+ chunk_created = chunk.created
+ chunk_model = chunk.model
+ for chunk_choice in chunk.choices:
+ # Emit incremental text content as delta events
+ if chunk_choice.delta.content:
+ sequence_number += 1
+ yield OpenAIResponseObjectStreamResponseOutputTextDelta(
+ content_index=0,
+ delta=chunk_choice.delta.content,
+ item_id=message_item_id,
+ output_index=0,
+ sequence_number=sequence_number,
+ )
+
+ # Collect content for final response
+ chat_response_content.append(chunk_choice.delta.content or "")
+ if chunk_choice.finish_reason:
+ chunk_finish_reason = chunk_choice.finish_reason
+
+ # Aggregate tool call arguments across chunks
+ if chunk_choice.delta.tool_calls:
+ for tool_call in chunk_choice.delta.tool_calls:
+ response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
+ if response_tool_call:
+                                # Don't attempt to concatenate arguments if we don't have any new arguments
+ if tool_call.function.arguments:
+ # Guard against an initial None argument before we concatenate
+ response_tool_call.function.arguments = (
+ response_tool_call.function.arguments or ""
+ ) + tool_call.function.arguments
+ else:
+ tool_call_dict: dict[str, Any] = tool_call.model_dump()
+ tool_call_dict.pop("type", None)
+ response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
+ chat_response_tool_calls[tool_call.index] = response_tool_call
+
+ # Convert collected chunks to complete response
+ if chat_response_tool_calls:
+ tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
+ else:
+ tool_calls = None
+ assistant_message = OpenAIAssistantMessageParam(
+ content="".join(chat_response_content),
+ tool_calls=tool_calls,
+ )
+ current_response = OpenAIChatCompletion(
+ id=chat_response_id,
+ choices=[
+ OpenAIChoice(
+ message=assistant_message,
+ finish_reason=chunk_finish_reason,
+ index=0,
+ )
+ ],
+ created=chunk_created,
+ model=chunk_model,
+ )
+
+ # Separate function vs non-function tool calls
+ function_tool_calls = []
+ non_function_tool_calls = []
+
+ for choice in current_response.choices:
+ if choice.message.tool_calls and tools:
+ for tool_call in choice.message.tool_calls:
+ if self._is_function_tool_call(tool_call, tools):
+ function_tool_calls.append(tool_call)
+ else:
+ non_function_tool_calls.append(tool_call)
+
+ # Process response choices based on tool call types
+ if function_tool_calls:
+ # For function tool calls, use existing logic and break
+ current_output_messages = await self._process_response_choices(
+ chat_response=current_response,
+ ctx=ctx,
+ tools=tools,
+ )
+ output_messages.extend(current_output_messages)
+ break
+ elif non_function_tool_calls:
+ # For non-function tool calls, execute them and continue loop
+ for choice in current_response.choices:
+ tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
+ output_messages.extend(tool_outputs)
+
+ # Add assistant message and tool responses to messages for next iteration
+ messages.append(choice.message)
+ messages.extend(tool_response_messages)
+
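+                # Apply the same inference-round cap as the non-streaming path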
+ n_iter += 1
+ if n_iter >= (max_infer_iters or 10):
+ break
+
+ # Continue with next iteration of the loop
+ continue
+ else:
+ # No tool calls - convert response to message and we're done
+ for choice in current_response.choices:
+ output_messages.append(await _convert_chat_choice_to_response_message(choice))
+ break
# Create final response
final_response = OpenAIResponseObject(
@@ -646,6 +749,30 @@ class OpenAIResponsesImpl:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools, mcp_tool_to_server, mcp_list_message
+ async def _execute_tool_calls_only(
+ self,
+ choice: OpenAIChoice,
+ ctx: ChatCompletionContext,
+ ) -> tuple[list[OpenAIResponseOutput], list[OpenAIMessageParam]]:
+ """Execute tool calls and return output messages and tool response messages for next inference."""
+ output_messages: list[OpenAIResponseOutput] = []
+ tool_response_messages: list[OpenAIMessageParam] = []
+
+ if not isinstance(choice.message, OpenAIAssistantMessageParam):
+ return output_messages, tool_response_messages
+
+ if not choice.message.tool_calls:
+ return output_messages, tool_response_messages
+
+ for tool_call in choice.message.tool_calls:
+ tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
+ if tool_call_log:
+ output_messages.append(tool_call_log)
+ if further_input:
+ tool_response_messages.append(further_input)
+
+ return output_messages, tool_response_messages
+
async def _execute_tool_and_return_final_output(
self,
choice: OpenAIChoice,
@@ -772,5 +899,8 @@ class OpenAIResponsesImpl:
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
+ else:
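+            # Fall back to the captured error (if any) as the tool output so the model always receives a tool response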
+ text = str(error_exc)
+ input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message
diff --git a/tests/common/mcp.py b/tests/common/mcp.py
index fd7040c6c..775e38295 100644
--- a/tests/common/mcp.py
+++ b/tests/common/mcp.py
@@ -5,6 +5,7 @@
# the root directory of this source tree.
# we want the mcp server to be authenticated OR not, depends
+from collections.abc import Callable
from contextlib import contextmanager
# Unfortunately the toolgroup id must be tied to the tool names because the registry
@@ -13,15 +14,158 @@ from contextlib import contextmanager
MCP_TOOLGROUP_ID = "mcp::localmcp"
+def default_tools():
+ """Default tools for backward compatibility."""
+ from mcp import types
+ from mcp.server.fastmcp import Context
+
+ async def greet_everyone(
+ url: str, ctx: Context
+ ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
+ return [types.TextContent(type="text", text="Hello, world!")]
+
+ async def get_boiling_point(liquid_name: str, celsius: bool = True) -> int:
+ """
+ Returns the boiling point of a liquid in Celsius or Fahrenheit.
+
+ :param liquid_name: The name of the liquid
+ :param celsius: Whether to return the boiling point in Celsius
+        :return: The boiling point of the liquid in Celsius or Fahrenheit
+ """
+ if liquid_name.lower() == "myawesomeliquid":
+ if celsius:
+ return -100
+ else:
+ return -212
+ else:
+ return -1
+
+ return {"greet_everyone": greet_everyone, "get_boiling_point": get_boiling_point}
+
+
+def dependency_tools():
+ """Tools with natural dependencies for multi-turn testing."""
+ from mcp import types
+ from mcp.server.fastmcp import Context
+
+ async def get_user_id(username: str, ctx: Context) -> str:
+ """
+ Get the user ID for a given username. This ID is needed for other operations.
+
+ :param username: The username to look up
+ :return: The user ID for the username
+ """
+ # Simple mapping for testing
+ user_mapping = {"alice": "user_12345", "bob": "user_67890", "charlie": "user_11111", "admin": "user_00000"}
+ return user_mapping.get(username.lower(), "user_99999")
+
+ async def get_user_permissions(user_id: str, ctx: Context) -> str:
+ """
+ Get the permissions for a user ID. Requires a valid user ID from get_user_id.
+
+ :param user_id: The user ID to check permissions for
+ :return: The permissions for the user
+ """
+ # Permission mapping based on user IDs
+ permission_mapping = {
+ "user_12345": "read,write", # alice
+ "user_67890": "read", # bob
+ "user_11111": "admin", # charlie
+ "user_00000": "superadmin", # admin
+ "user_99999": "none", # unknown users
+ }
+ return permission_mapping.get(user_id, "none")
+
+    async def check_file_access(user_id: str, filename: str, ctx: Context) -> list[types.TextContent]:
+ """
+ Check if a user can access a specific file. Requires a valid user ID.
+
+ :param user_id: The user ID to check access for
+ :param filename: The filename to check access to
+ :return: Whether the user can access the file (yes/no)
+ """
+ # Get permissions first
+ permission_mapping = {
+ "user_12345": "read,write", # alice
+ "user_67890": "read", # bob
+ "user_11111": "admin", # charlie
+ "user_00000": "superadmin", # admin
+ "user_99999": "none", # unknown users
+ }
+ permissions = permission_mapping.get(user_id, "none")
+
+ # Check file access based on permissions and filename
+ if permissions == "superadmin":
+ access = "yes"
+ elif permissions == "admin":
+ access = "yes" if not filename.startswith("secret_") else "no"
+ elif "write" in permissions:
+ access = "yes" if filename.endswith(".txt") else "no"
+ elif "read" in permissions:
+ access = "yes" if filename.endswith(".txt") or filename.endswith(".md") else "no"
+ else:
+ access = "no"
+
+ return [types.TextContent(type="text", text=access)]
+
+ async def get_experiment_id(experiment_name: str, ctx: Context) -> str:
+ """
+ Get the experiment ID for a given experiment name. This ID is needed to get results.
+
+ :param experiment_name: The name of the experiment
+ :return: The experiment ID
+ """
+ # Simple mapping for testing
+ experiment_mapping = {
+ "temperature_test": "exp_001",
+ "pressure_test": "exp_002",
+ "chemical_reaction": "exp_003",
+ "boiling_point": "exp_004",
+ }
+ exp_id = experiment_mapping.get(experiment_name.lower(), "exp_999")
+ return exp_id
+
+ async def get_experiment_results(experiment_id: str, ctx: Context) -> str:
+ """
+ Get the results for an experiment ID. Requires a valid experiment ID from get_experiment_id.
+
+ :param experiment_id: The experiment ID to get results for
+ :return: The experiment results
+ """
+ # Results mapping based on experiment IDs
+ results_mapping = {
+ "exp_001": "Temperature: 25°C, Status: Success",
+ "exp_002": "Pressure: 1.2 atm, Status: Success",
+ "exp_003": "Yield: 85%, Status: Complete",
+ "exp_004": "Boiling Point: 100°C, Status: Verified",
+ "exp_999": "No results found",
+ }
+ results = results_mapping.get(experiment_id, "Invalid experiment ID")
+ return results
+
+ return {
+ "get_user_id": get_user_id,
+ "get_user_permissions": get_user_permissions,
+ "check_file_access": check_file_access,
+ "get_experiment_id": get_experiment_id,
+ "get_experiment_results": get_experiment_results,
+ }
+
+
@contextmanager
-def make_mcp_server(required_auth_token: str | None = None):
+def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Callable] | None = None):
+ """
+ Create an MCP server with the specified tools.
+
+ :param required_auth_token: Optional auth token required for access
+ :param tools: Dictionary of tool_name -> tool_function. If None, uses default tools.
+ """
import threading
import time
import httpx
import uvicorn
- from mcp import types
- from mcp.server.fastmcp import Context, FastMCP
+ from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
@@ -29,35 +173,18 @@ def make_mcp_server(required_auth_token: str | None = None):
server = FastMCP("FastMCP Test Server", log_level="WARNING")
- @server.tool()
- async def greet_everyone(
- url: str, ctx: Context
- ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
- return [types.TextContent(type="text", text="Hello, world!")]
+ tools = tools or default_tools()
- @server.tool()
- async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
- """
- Returns the boiling point of a liquid in Celcius or Fahrenheit.
-
- :param liquid_name: The name of the liquid
- :param celcius: Whether to return the boiling point in Celcius
- :return: The boiling point of the liquid in Celcius or Fahrenheit
- """
- if liquid_name.lower() == "polyjuice":
- if celcius:
- return -100
- else:
- return -212
- else:
- return -1
+ # Register all tools with the server
+ for tool_func in tools.values():
+ server.tool()(tool_func)
sse = SseServerTransport("/messages/")
async def handle_sse(request):
from starlette.exceptions import HTTPException
- auth_header = request.headers.get("Authorization")
+ auth_header: str | None = request.headers.get("Authorization")
auth_token = None
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]
diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
index 5b6cee0ec..7a367e394 100644
--- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py
+++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
@@ -224,16 +224,16 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
],
)
- # Verify
+ # Check that we got the content from our mocked tool execution result
+ chunks = [chunk async for chunk in result]
+ assert len(chunks) == 2 # Should have response.created and response.completed
+
+ # Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
- # Check that we got the content from our mocked tool execution result
- chunks = [chunk async for chunk in result]
- assert len(chunks) == 2 # Should have response.created and response.completed
-
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
diff --git a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
index 51c7814a3..4d6c19b59 100644
--- a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
+++ b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
@@ -36,7 +36,7 @@ test_response_mcp_tool:
test_params:
case:
- case_id: "boiling_point_tool"
- input: "What is the boiling point of polyjuice?"
+ input: "What is the boiling point of myawesomeliquid in Celsius?"
tools:
- type: mcp
server_label: "localmcp"
@@ -94,3 +94,43 @@ test_response_multi_turn_image:
output: "llama"
- input: "What country do you find this animal primarily in? What continent?"
output: "peru"
+
+test_response_multi_turn_tool_execution:
+ test_name: test_response_multi_turn_tool_execution
+ test_params:
+ case:
+ - case_id: "user_file_access_check"
+ input: "I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response."
+ tools:
+ - type: mcp
+ server_label: "localmcp"
+ server_url: ""
+ output: "yes"
+ - case_id: "experiment_results_lookup"
+ input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me what you found."
+ tools:
+ - type: mcp
+ server_label: "localmcp"
+ server_url: ""
+ output: "100°C"
+
+test_response_multi_turn_tool_execution_streaming:
+ test_name: test_response_multi_turn_tool_execution_streaming
+ test_params:
+ case:
+ - case_id: "user_permissions_workflow"
+ input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step."
+ tools:
+ - type: mcp
+ server_label: "localmcp"
+ server_url: ""
+ stream: true
+ output: "no"
+ - case_id: "experiment_analysis_streaming"
+ input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Please stream your analysis process."
+ tools:
+ - type: mcp
+ server_label: "localmcp"
+ server_url: ""
+ stream: true
+ output: "85%"
diff --git a/tests/verifications/openai_api/test_responses.py b/tests/verifications/openai_api/test_responses.py
index 2ce0a3e9c..c9b190e62 100644
--- a/tests/verifications/openai_api/test_responses.py
+++ b/tests/verifications/openai_api/test_responses.py
@@ -12,7 +12,7 @@ import pytest
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.datatypes import AuthenticationRequiredError
-from tests.common.mcp import make_mcp_server
+from tests.common.mcp import dependency_tools, make_mcp_server
from tests.verifications.openai_api.fixtures.fixtures import (
case_id_generator,
get_base_test_name,
@@ -280,6 +280,7 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
tools=tools,
stream=False,
)
+
assert len(response.output) >= 3
list_tools = response.output[0]
assert list_tools.type == "mcp_list_tools"
@@ -290,11 +291,12 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
call = response.output[1]
assert call.type == "mcp_call"
assert call.name == "get_boiling_point"
- assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True}
+ assert json.loads(call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True}
assert call.error is None
assert "-100" in call.output
- message = response.output[2]
+ # sometimes the model will call the tool again, so we need to get the last message
+ message = response.output[-1]
text_content = message.content[0].text
assert "boiling point" in text_content.lower()
@@ -393,3 +395,154 @@ def test_response_non_streaming_multi_turn_image(request, openai_client, model,
previous_response_id = response.id
output_text = response.output_text.lower()
assert turn["output"].lower() in output_text
+
+
+@pytest.mark.parametrize(
+ "case",
+ responses_test_cases["test_response_multi_turn_tool_execution"]["test_params"]["case"],
+ ids=case_id_generator,
+)
+def test_response_non_streaming_multi_turn_tool_execution(
+ request, openai_client, model, provider, verification_config, case
+):
+ """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
+ test_name_base = get_base_test_name(request)
+ if should_skip_test(verification_config, provider, model, test_name_base):
+ pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+ with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
+ tools = case["tools"]
+ # Replace the placeholder URL with the actual server URL
+ for tool in tools:
+ if tool["type"] == "mcp" and tool["server_url"] == "":
+ tool["server_url"] = mcp_server_info["server_url"]
+
+ response = openai_client.responses.create(
+ input=case["input"],
+ model=model,
+ tools=tools,
+ )
+
+ # Verify we have MCP tool calls in the output
+ mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
+ mcp_calls = [output for output in response.output if output.type == "mcp_call"]
+ message_outputs = [output for output in response.output if output.type == "message"]
+
+ # Should have exactly 1 MCP list tools message (at the beginning)
+ assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
+ assert mcp_list_tools[0].server_label == "localmcp"
+ assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
+ expected_tool_names = {
+ "get_user_id",
+ "get_user_permissions",
+ "check_file_access",
+ "get_experiment_id",
+ "get_experiment_results",
+ }
+ assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names
+
+ assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
+ for mcp_call in mcp_calls:
+ assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
+
+ assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
+
+ final_message = message_outputs[-1]
+ assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
+ assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
+ assert len(final_message.content) > 0, "Final message should have content"
+
+ expected_output = case["output"]
+ assert expected_output.lower() in response.output_text.lower(), (
+ f"Expected '{expected_output}' to appear in response: {response.output_text}"
+ )
+
+
+@pytest.mark.parametrize(
+ "case",
+ responses_test_cases["test_response_multi_turn_tool_execution_streaming"]["test_params"]["case"],
+ ids=case_id_generator,
+)
+def test_response_streaming_multi_turn_tool_execution(
+ request, openai_client, model, provider, verification_config, case
+):
+ """Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
+ test_name_base = get_base_test_name(request)
+ if should_skip_test(verification_config, provider, model, test_name_base):
+ pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+ with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
+ tools = case["tools"]
+ # Replace the placeholder URL with the actual server URL
+ for tool in tools:
+ if tool["type"] == "mcp" and tool["server_url"] == "":
+ tool["server_url"] = mcp_server_info["server_url"]
+
+ stream = openai_client.responses.create(
+ input=case["input"],
+ model=model,
+ tools=tools,
+ stream=True,
+ )
+
+ chunks = []
+        for chunk in stream:
+ chunks.append(chunk)
+
+ # Should have at least response.created and response.completed
+ assert len(chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(chunks)}"
+
+ # First chunk should be response.created
+ assert chunks[0].type == "response.created", f"First chunk should be response.created, got {chunks[0].type}"
+
+ # Last chunk should be response.completed
+ assert chunks[-1].type == "response.completed", (
+ f"Last chunk should be response.completed, got {chunks[-1].type}"
+ )
+
+ # Get the final response from the last chunk
+ final_chunk = chunks[-1]
+ if hasattr(final_chunk, "response"):
+ final_response = final_chunk.response
+
+ # Verify multi-turn MCP tool execution results
+ mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
+ mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
+ message_outputs = [output for output in final_response.output if output.type == "message"]
+
+ # Should have exactly 1 MCP list tools message (at the beginning)
+ assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
+ assert mcp_list_tools[0].server_label == "localmcp"
+ assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
+ expected_tool_names = {
+ "get_user_id",
+ "get_user_permissions",
+ "check_file_access",
+ "get_experiment_id",
+ "get_experiment_results",
+ }
+ assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names
+
+ # Should have at least 1 MCP call (the model should call at least one tool)
+ assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
+
+ # All MCP calls should be completed (verifies our tool execution works)
+ for mcp_call in mcp_calls:
+ assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
+
+ # Should have at least one final message response
+ assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
+
+ # Final message should be from assistant and completed
+ final_message = message_outputs[-1]
+ assert final_message.role == "assistant", (
+ f"Final message should be from assistant, got {final_message.role}"
+ )
+ assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
+ assert len(final_message.content) > 0, "Final message should have content"
+
+ # Check that the expected output appears in the response
+ expected_output = case["output"]
+ assert expected_output.lower() in final_response.output_text.lower(), (
+ f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
+ )