mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 10:46:41 +00:00)
feat(responses): implement full multi-turn support (#2295)
I think the implementation needs more simplification. Spent way too much time trying to get the tests to pass with models not co-operating :( Finally had to switch to claude-sonnet to get things to pass reliably.

### Test Plan

```
export TAVILY_SEARCH_API_KEY=...
export OPENAI_API_KEY=...
uv run pytest -p no:warnings \
  -s -v tests/verifications/openai_api/test_responses.py \
  --provider=stack:starter \
  --model openai/gpt-4o
```
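For context, a minimal sketch of exercising the new multi-turn tool execution path through the OpenAI-compatible Responses API. This is illustrative only and not part of the change: the base URL, API key, model name, and MCP server URL are placeholders, and the `mcp` tool definition mirrors the test cases added in this commit.

```python
# Hypothetical usage sketch: drive multi-turn MCP tool execution via the Responses API.
# Assumes a local Llama Stack server exposing an OpenAI-compatible endpoint and a local
# MCP server; all URLs, keys, and the model name below are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="openai/gpt-4o",
    input="First get the user ID for 'alice', then check if that user can access 'document.txt'.",
    tools=[
        {
            "type": "mcp",
            "server_label": "localmcp",
            "server_url": "http://localhost:8000/sse",  # placeholder MCP server
        }
    ],
)

# With multi-turn support, intermediate mcp_list_tools / mcp_call outputs and the
# final assistant message all appear in response.output.
for output in response.output:
    print(output.type)
```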
This commit is contained in:
parent cac7d404a2
commit dbe4e84aca
9 changed files with 593 additions and 136 deletions
docs/_static/llama-stack-spec.html (vendored, 3 changes)
@@ -7283,6 +7283,9 @@
            "items": {
              "$ref": "#/components/schemas/OpenAIResponseInputTool"
            }
          },
          "max_infer_iters": {
            "type": "integer"
          }
        },
        "additionalProperties": false,
docs/_static/llama-stack-spec.yaml (vendored, 2 changes)
@@ -5149,6 +5149,8 @@ components:
          type: array
          items:
            $ref: '#/components/schemas/OpenAIResponseInputTool'
        max_infer_iters:
          type: integer
      additionalProperties: false
      required:
        - input
@@ -604,6 +604,7 @@ class Agents(Protocol):
        stream: bool | None = False,
        temperature: float | None = None,
        tools: list[OpenAIResponseInputTool] | None = None,
        max_infer_iters: int | None = 10,  # this is an extension to the OpenAI API
    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
        """Create a new OpenAI response.
@@ -325,9 +325,10 @@ class MetaReferenceAgentsImpl(Agents):
        stream: bool | None = False,
        temperature: float | None = None,
        tools: list[OpenAIResponseInputTool] | None = None,
        max_infer_iters: int | None = 10,
    ) -> OpenAIResponseObject:
        return await self.openai_responses_impl.create_openai_response(
            input, model, instructions, previous_response_id, store, stream, temperature, tools
            input, model, instructions, previous_response_id, store, stream, temperature, tools, max_infer_iters
        )

    async def list_openai_responses(
@@ -258,6 +258,18 @@ class OpenAIResponsesImpl:
        """
        return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)

    def _is_function_tool_call(
        self,
        tool_call: OpenAIChatCompletionToolCall,
        tools: list[OpenAIResponseInputTool],
    ) -> bool:
        if not tool_call.function:
            return False
        for t in tools:
            if t.type == "function" and t.name == tool_call.function.name:
                return True
        return False

    async def _process_response_choices(
        self,
        chat_response: OpenAIChatCompletion,
@@ -270,7 +282,7 @@
        for choice in chat_response.choices:
            if choice.message.tool_calls and tools:
                # Assume if the first tool is a function, all tools are functions
                if tools[0].type == "function":
                if self._is_function_tool_call(choice.message.tool_calls[0], tools):
                    for tool_call in choice.message.tool_calls:
                        output_messages.append(
                            OpenAIResponseOutputMessageFunctionToolCall(
@@ -332,6 +344,7 @@
        stream: bool | None = False,
        temperature: float | None = None,
        tools: list[OpenAIResponseInputTool] | None = None,
        max_infer_iters: int | None = 10,
    ):
        stream = False if stream is None else stream

@@ -358,58 +371,100 @@
            temperature=temperature,
        )

        inference_result = await self.inference_api.openai_chat_completion(
            model=model,
            messages=messages,
            tools=chat_tools,
            stream=stream,
            temperature=temperature,
        )

        # Fork to streaming vs non-streaming - let each handle ALL inference rounds
        if stream:
            return self._create_streaming_response(
                inference_result=inference_result,
                ctx=ctx,
                output_messages=output_messages,
                input=input,
                model=model,
                store=store,
                tools=tools,
                max_infer_iters=max_infer_iters,
            )
        else:
            return await self._create_non_streaming_response(
                inference_result=inference_result,
                ctx=ctx,
                output_messages=output_messages,
                input=input,
                model=model,
                store=store,
                tools=tools,
                max_infer_iters=max_infer_iters,
            )

    async def _create_non_streaming_response(
        self,
        inference_result: Any,
        ctx: ChatCompletionContext,
        output_messages: list[OpenAIResponseOutput],
        input: str | list[OpenAIResponseInput],
        model: str,
        store: bool | None,
        tools: list[OpenAIResponseInputTool] | None,
        max_infer_iters: int | None,
    ) -> OpenAIResponseObject:
        chat_response = OpenAIChatCompletion(**inference_result.model_dump())
        # Implement tool execution loop - handle ALL inference rounds including the first
        n_iter = 0
        messages = ctx.messages.copy()
        current_response = None

        # Process response choices (tool execution and message creation)
        output_messages.extend(
            await self._process_response_choices(
                chat_response=chat_response,
                ctx=ctx,
                tools=tools,
        while True:
            # Do inference (including the first one)
            inference_result = await self.inference_api.openai_chat_completion(
                model=ctx.model,
                messages=messages,
                tools=ctx.tools,
                stream=False,
                temperature=ctx.temperature,
            )
        )
            current_response = OpenAIChatCompletion(**inference_result.model_dump())

            # Separate function vs non-function tool calls
            function_tool_calls = []
            non_function_tool_calls = []

            for choice in current_response.choices:
                if choice.message.tool_calls and tools:
                    for tool_call in choice.message.tool_calls:
                        if self._is_function_tool_call(tool_call, tools):
                            function_tool_calls.append(tool_call)
                        else:
                            non_function_tool_calls.append(tool_call)

            # Process response choices based on tool call types
            if function_tool_calls:
                # For function tool calls, use existing logic and return immediately
                current_output_messages = await self._process_response_choices(
                    chat_response=current_response,
                    ctx=ctx,
                    tools=tools,
                )
                output_messages.extend(current_output_messages)
                break
            elif non_function_tool_calls:
                # For non-function tool calls, execute them and continue loop
                for choice in current_response.choices:
                    tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
                    output_messages.extend(tool_outputs)

                    # Add assistant message and tool responses to messages for next iteration
                    messages.append(choice.message)
                    messages.extend(tool_response_messages)

                n_iter += 1
                if n_iter >= (max_infer_iters or 10):
                    break

                # Continue with next iteration of the loop
                continue
            else:
                # No tool calls - convert response to message and we're done
                for choice in current_response.choices:
                    output_messages.append(await _convert_chat_choice_to_response_message(choice))
                break

        response = OpenAIResponseObject(
            created_at=chat_response.created,
            created_at=current_response.created,
            id=f"resp-{uuid.uuid4()}",
            model=model,
            object="response",
@@ -429,13 +484,13 @@

    async def _create_streaming_response(
        self,
        inference_result: Any,
        ctx: ChatCompletionContext,
        output_messages: list[OpenAIResponseOutput],
        input: str | list[OpenAIResponseInput],
        model: str,
        store: bool | None,
        tools: list[OpenAIResponseInputTool] | None,
        max_infer_iters: int | None,
    ) -> AsyncIterator[OpenAIResponseObjectStream]:
        # Create initial response and emit response.created immediately
        response_id = f"resp-{uuid.uuid4()}"
@@ -453,87 +508,135 @@
        # Emit response.created immediately
        yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)

        # For streaming, inference_result is an async iterator of chunks
        # Stream chunks and emit delta events as they arrive
        chat_response_id = ""
        chat_response_content = []
        chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
        chunk_created = 0
        chunk_model = ""
        chunk_finish_reason = ""
        sequence_number = 0
        # Implement tool execution loop for streaming - handle ALL inference rounds including the first
        n_iter = 0
        messages = ctx.messages.copy()

        # Create a placeholder message item for delta events
        message_item_id = f"msg_{uuid.uuid4()}"

        async for chunk in inference_result:
            chat_response_id = chunk.id
            chunk_created = chunk.created
            chunk_model = chunk.model
            for chunk_choice in chunk.choices:
                # Emit incremental text content as delta events
                if chunk_choice.delta.content:
                    sequence_number += 1
                    yield OpenAIResponseObjectStreamResponseOutputTextDelta(
                        content_index=0,
                        delta=chunk_choice.delta.content,
                        item_id=message_item_id,
                        output_index=0,
                        sequence_number=sequence_number,
                    )

                # Collect content for final response
                chat_response_content.append(chunk_choice.delta.content or "")
                if chunk_choice.finish_reason:
                    chunk_finish_reason = chunk_choice.finish_reason

                # Aggregate tool call arguments across chunks, using their index as the aggregation key
                if chunk_choice.delta.tool_calls:
                    for tool_call in chunk_choice.delta.tool_calls:
                        response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
                        if response_tool_call:
                            # Don't attempt to concatenate arguments if we don't have any new arguments
                            if tool_call.function.arguments:
                                # Guard against an initial None argument before we concatenate
                                response_tool_call.function.arguments = (
                                    response_tool_call.function.arguments or ""
                                ) + tool_call.function.arguments
                        else:
                            tool_call_dict: dict[str, Any] = tool_call.model_dump()
                            tool_call_dict.pop("type", None)
                            response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
                            chat_response_tool_calls[tool_call.index] = response_tool_call

        # Convert collected chunks to complete response
        if chat_response_tool_calls:
            tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
        else:
            tool_calls = None
        assistant_message = OpenAIAssistantMessageParam(
            content="".join(chat_response_content),
            tool_calls=tool_calls,
        )
        chat_response_obj = OpenAIChatCompletion(
            id=chat_response_id,
            choices=[
                OpenAIChoice(
                    message=assistant_message,
                    finish_reason=chunk_finish_reason,
                    index=0,
                )
            ],
            created=chunk_created,
            model=chunk_model,
        )

        # Process response choices (tool execution and message creation)
        output_messages.extend(
            await self._process_response_choices(
                chat_response=chat_response_obj,
                ctx=ctx,
                tools=tools,
        while True:
            # Do inference (including the first one) - streaming
            current_inference_result = await self.inference_api.openai_chat_completion(
                model=ctx.model,
                messages=messages,
                tools=ctx.tools,
                stream=True,
                temperature=ctx.temperature,
            )
        )

            # Process streaming chunks and build complete response
            chat_response_id = ""
            chat_response_content = []
            chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
            chunk_created = 0
            chunk_model = ""
            chunk_finish_reason = ""
            sequence_number = 0

            # Create a placeholder message item for delta events
            message_item_id = f"msg_{uuid.uuid4()}"

            async for chunk in current_inference_result:
                chat_response_id = chunk.id
                chunk_created = chunk.created
                chunk_model = chunk.model
                for chunk_choice in chunk.choices:
                    # Emit incremental text content as delta events
                    if chunk_choice.delta.content:
                        sequence_number += 1
                        yield OpenAIResponseObjectStreamResponseOutputTextDelta(
                            content_index=0,
                            delta=chunk_choice.delta.content,
                            item_id=message_item_id,
                            output_index=0,
                            sequence_number=sequence_number,
                        )

                    # Collect content for final response
                    chat_response_content.append(chunk_choice.delta.content or "")
                    if chunk_choice.finish_reason:
                        chunk_finish_reason = chunk_choice.finish_reason

                    # Aggregate tool call arguments across chunks
                    if chunk_choice.delta.tool_calls:
                        for tool_call in chunk_choice.delta.tool_calls:
                            response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
                            if response_tool_call:
                                # Don't attempt to concatenate arguments if we don't have any new arguments
                                if tool_call.function.arguments:
                                    # Guard against an initial None argument before we concatenate
                                    response_tool_call.function.arguments = (
                                        response_tool_call.function.arguments or ""
                                    ) + tool_call.function.arguments
                            else:
                                tool_call_dict: dict[str, Any] = tool_call.model_dump()
                                tool_call_dict.pop("type", None)
                                response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
                                chat_response_tool_calls[tool_call.index] = response_tool_call

            # Convert collected chunks to complete response
            if chat_response_tool_calls:
                tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
            else:
                tool_calls = None
            assistant_message = OpenAIAssistantMessageParam(
                content="".join(chat_response_content),
                tool_calls=tool_calls,
            )
            current_response = OpenAIChatCompletion(
                id=chat_response_id,
                choices=[
                    OpenAIChoice(
                        message=assistant_message,
                        finish_reason=chunk_finish_reason,
                        index=0,
                    )
                ],
                created=chunk_created,
                model=chunk_model,
            )

            # Separate function vs non-function tool calls
            function_tool_calls = []
            non_function_tool_calls = []

            for choice in current_response.choices:
                if choice.message.tool_calls and tools:
                    for tool_call in choice.message.tool_calls:
                        if self._is_function_tool_call(tool_call, tools):
                            function_tool_calls.append(tool_call)
                        else:
                            non_function_tool_calls.append(tool_call)

            # Process response choices based on tool call types
            if function_tool_calls:
                # For function tool calls, use existing logic and break
                current_output_messages = await self._process_response_choices(
                    chat_response=current_response,
                    ctx=ctx,
                    tools=tools,
                )
                output_messages.extend(current_output_messages)
                break
            elif non_function_tool_calls:
                # For non-function tool calls, execute them and continue loop
                for choice in current_response.choices:
                    tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
                    output_messages.extend(tool_outputs)

                    # Add assistant message and tool responses to messages for next iteration
                    messages.append(choice.message)
                    messages.extend(tool_response_messages)

                n_iter += 1
                if n_iter >= (max_infer_iters or 10):
                    break

                # Continue with next iteration of the loop
                continue
            else:
                # No tool calls - convert response to message and we're done
                for choice in current_response.choices:
                    output_messages.append(await _convert_chat_choice_to_response_message(choice))
                break

        # Create final response
        final_response = OpenAIResponseObject(
@@ -646,6 +749,30 @@
            raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
        return chat_tools, mcp_tool_to_server, mcp_list_message

    async def _execute_tool_calls_only(
        self,
        choice: OpenAIChoice,
        ctx: ChatCompletionContext,
    ) -> tuple[list[OpenAIResponseOutput], list[OpenAIMessageParam]]:
        """Execute tool calls and return output messages and tool response messages for next inference."""
        output_messages: list[OpenAIResponseOutput] = []
        tool_response_messages: list[OpenAIMessageParam] = []

        if not isinstance(choice.message, OpenAIAssistantMessageParam):
            return output_messages, tool_response_messages

        if not choice.message.tool_calls:
            return output_messages, tool_response_messages

        for tool_call in choice.message.tool_calls:
            tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
            if tool_call_log:
                output_messages.append(tool_call_log)
            if further_input:
                tool_response_messages.append(further_input)

        return output_messages, tool_response_messages

    async def _execute_tool_and_return_final_output(
        self,
        choice: OpenAIChoice,
@@ -772,5 +899,8 @@
            else:
                raise ValueError(f"Unknown result content type: {type(result.content)}")
            input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
        else:
            text = str(error_exc)
            input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)

        return message, input_message
@@ -5,6 +5,7 @@
# the root directory of this source tree.

# we want the mcp server to be authenticated OR not, depends
from collections.abc import Callable
from contextlib import contextmanager

# Unfortunately the toolgroup id must be tied to the tool names because the registry
@@ -13,15 +14,158 @@ from contextlib import contextmanager
MCP_TOOLGROUP_ID = "mcp::localmcp"


def default_tools():
    """Default tools for backward compatibility."""
    from mcp import types
    from mcp.server.fastmcp import Context

    async def greet_everyone(
        url: str, ctx: Context
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        return [types.TextContent(type="text", text="Hello, world!")]

    async def get_boiling_point(liquid_name: str, celsius: bool = True) -> int:
        """
        Returns the boiling point of a liquid in Celsius or Fahrenheit.

        :param liquid_name: The name of the liquid
        :param celsius: Whether to return the boiling point in Celsius
        :return: The boiling point of the liquid in Celcius or Fahrenheit
        """
        if liquid_name.lower() == "myawesomeliquid":
            if celsius:
                return -100
            else:
                return -212
        else:
            return -1

    return {"greet_everyone": greet_everyone, "get_boiling_point": get_boiling_point}


def dependency_tools():
    """Tools with natural dependencies for multi-turn testing."""
    from mcp import types
    from mcp.server.fastmcp import Context

    async def get_user_id(username: str, ctx: Context) -> str:
        """
        Get the user ID for a given username. This ID is needed for other operations.

        :param username: The username to look up
        :return: The user ID for the username
        """
        # Simple mapping for testing
        user_mapping = {"alice": "user_12345", "bob": "user_67890", "charlie": "user_11111", "admin": "user_00000"}
        return user_mapping.get(username.lower(), "user_99999")

    async def get_user_permissions(user_id: str, ctx: Context) -> str:
        """
        Get the permissions for a user ID. Requires a valid user ID from get_user_id.

        :param user_id: The user ID to check permissions for
        :return: The permissions for the user
        """
        # Permission mapping based on user IDs
        permission_mapping = {
            "user_12345": "read,write",  # alice
            "user_67890": "read",  # bob
            "user_11111": "admin",  # charlie
            "user_00000": "superadmin",  # admin
            "user_99999": "none",  # unknown users
        }
        return permission_mapping.get(user_id, "none")

    async def check_file_access(user_id: str, filename: str, ctx: Context) -> str:
        """
        Check if a user can access a specific file. Requires a valid user ID.

        :param user_id: The user ID to check access for
        :param filename: The filename to check access to
        :return: Whether the user can access the file (yes/no)
        """
        # Get permissions first
        permission_mapping = {
            "user_12345": "read,write",  # alice
            "user_67890": "read",  # bob
            "user_11111": "admin",  # charlie
            "user_00000": "superadmin",  # admin
            "user_99999": "none",  # unknown users
        }
        permissions = permission_mapping.get(user_id, "none")

        # Check file access based on permissions and filename
        if permissions == "superadmin":
            access = "yes"
        elif permissions == "admin":
            access = "yes" if not filename.startswith("secret_") else "no"
        elif "write" in permissions:
            access = "yes" if filename.endswith(".txt") else "no"
        elif "read" in permissions:
            access = "yes" if filename.endswith(".txt") or filename.endswith(".md") else "no"
        else:
            access = "no"

        return [types.TextContent(type="text", text=access)]

    async def get_experiment_id(experiment_name: str, ctx: Context) -> str:
        """
        Get the experiment ID for a given experiment name. This ID is needed to get results.

        :param experiment_name: The name of the experiment
        :return: The experiment ID
        """
        # Simple mapping for testing
        experiment_mapping = {
            "temperature_test": "exp_001",
            "pressure_test": "exp_002",
            "chemical_reaction": "exp_003",
            "boiling_point": "exp_004",
        }
        exp_id = experiment_mapping.get(experiment_name.lower(), "exp_999")
        return exp_id

    async def get_experiment_results(experiment_id: str, ctx: Context) -> str:
        """
        Get the results for an experiment ID. Requires a valid experiment ID from get_experiment_id.

        :param experiment_id: The experiment ID to get results for
        :return: The experiment results
        """
        # Results mapping based on experiment IDs
        results_mapping = {
            "exp_001": "Temperature: 25°C, Status: Success",
            "exp_002": "Pressure: 1.2 atm, Status: Success",
            "exp_003": "Yield: 85%, Status: Complete",
            "exp_004": "Boiling Point: 100°C, Status: Verified",
            "exp_999": "No results found",
        }
        results = results_mapping.get(experiment_id, "Invalid experiment ID")
        return results

    return {
        "get_user_id": get_user_id,
        "get_user_permissions": get_user_permissions,
        "check_file_access": check_file_access,
        "get_experiment_id": get_experiment_id,
        "get_experiment_results": get_experiment_results,
    }


@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Callable] | None = None):
    """
    Create an MCP server with the specified tools.

    :param required_auth_token: Optional auth token required for access
    :param tools: Dictionary of tool_name -> tool_function. If None, uses default tools.
    """
    import threading
    import time

    import httpx
    import uvicorn
    from mcp import types
    from mcp.server.fastmcp import Context, FastMCP
    from mcp.server.fastmcp import FastMCP
    from mcp.server.sse import SseServerTransport
    from starlette.applications import Starlette
    from starlette.responses import Response
@@ -29,35 +173,18 @@ def make_mcp_server(required_auth_token: str | None = None):

    server = FastMCP("FastMCP Test Server", log_level="WARNING")

    @server.tool()
    async def greet_everyone(
        url: str, ctx: Context
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        return [types.TextContent(type="text", text="Hello, world!")]
    tools = tools or default_tools()

    @server.tool()
    async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
        """
        Returns the boiling point of a liquid in Celcius or Fahrenheit.

        :param liquid_name: The name of the liquid
        :param celcius: Whether to return the boiling point in Celcius
        :return: The boiling point of the liquid in Celcius or Fahrenheit
        """
        if liquid_name.lower() == "polyjuice":
            if celcius:
                return -100
            else:
                return -212
        else:
            return -1
    # Register all tools with the server
    for tool_func in tools.values():
        server.tool()(tool_func)

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        from starlette.exceptions import HTTPException

        auth_header = request.headers.get("Authorization")
        auth_header: str | None = request.headers.get("Authorization")
        auth_token = None
        if auth_header and auth_header.startswith("Bearer "):
            auth_token = auth_header.split(" ")[1]
@@ -224,16 +224,16 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
        ],
    )

    # Verify
    # Check that we got the content from our mocked tool execution result
    chunks = [chunk async for chunk in result]
    assert len(chunks) == 2  # Should have response.created and response.completed

    # Verify inference API was called correctly (after iterating over result)
    first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
    assert first_call.kwargs["messages"][0].content == input_text
    assert first_call.kwargs["tools"] is not None
    assert first_call.kwargs["temperature"] == 0.1

    # Check that we got the content from our mocked tool execution result
    chunks = [chunk async for chunk in result]
    assert len(chunks) == 2  # Should have response.created and response.completed

    # Check response.created event (should have empty output)
    assert chunks[0].type == "response.created"
    assert len(chunks[0].response.output) == 0
@@ -36,7 +36,7 @@ test_response_mcp_tool:
  test_params:
    case:
      - case_id: "boiling_point_tool"
        input: "What is the boiling point of polyjuice?"
        input: "What is the boiling point of myawesomeliquid in Celsius?"
        tools:
          - type: mcp
            server_label: "localmcp"
@@ -94,3 +94,43 @@ test_response_multi_turn_image:
        output: "llama"
      - input: "What country do you find this animal primarily in? What continent?"
        output: "peru"

test_response_multi_turn_tool_execution:
  test_name: test_response_multi_turn_tool_execution
  test_params:
    case:
      - case_id: "user_file_access_check"
        input: "I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        output: "yes"
      - case_id: "experiment_results_lookup"
        input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me what you found."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        output: "100°C"

test_response_multi_turn_tool_execution_streaming:
  test_name: test_response_multi_turn_tool_execution_streaming
  test_params:
    case:
      - case_id: "user_permissions_workflow"
        input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        stream: true
        output: "no"
      - case_id: "experiment_analysis_streaming"
        input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Please stream your analysis process."
        tools:
          - type: mcp
            server_label: "localmcp"
            server_url: "<FILLED_BY_TEST_RUNNER>"
        stream: true
        output: "85%"
@@ -12,7 +12,7 @@ import pytest

from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from tests.common.mcp import make_mcp_server
from tests.common.mcp import dependency_tools, make_mcp_server
from tests.verifications.openai_api.fixtures.fixtures import (
    case_id_generator,
    get_base_test_name,
@@ -280,6 +280,7 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
            tools=tools,
            stream=False,
        )

        assert len(response.output) >= 3
        list_tools = response.output[0]
        assert list_tools.type == "mcp_list_tools"
@@ -290,11 +291,12 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
        call = response.output[1]
        assert call.type == "mcp_call"
        assert call.name == "get_boiling_point"
        assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True}
        assert json.loads(call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True}
        assert call.error is None
        assert "-100" in call.output

        message = response.output[2]
        # sometimes the model will call the tool again, so we need to get the last message
        message = response.output[-1]
        text_content = message.content[0].text
        assert "boiling point" in text_content.lower()
@@ -393,3 +395,154 @@ def test_response_non_streaming_multi_turn_image(request, openai_client, model,
        previous_response_id = response.id
        output_text = response.output_text.lower()
        assert turn["output"].lower() in output_text


@pytest.mark.parametrize(
    "case",
    responses_test_cases["test_response_multi_turn_tool_execution"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_response_non_streaming_multi_turn_tool_execution(
    request, openai_client, model, provider, verification_config, case
):
    """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
        tools = case["tools"]
        # Replace the placeholder URL with the actual server URL
        for tool in tools:
            if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
                tool["server_url"] = mcp_server_info["server_url"]

        response = openai_client.responses.create(
            input=case["input"],
            model=model,
            tools=tools,
        )

        # Verify we have MCP tool calls in the output
        mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
        mcp_calls = [output for output in response.output if output.type == "mcp_call"]
        message_outputs = [output for output in response.output if output.type == "message"]

        # Should have exactly 1 MCP list tools message (at the beginning)
        assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
        assert mcp_list_tools[0].server_label == "localmcp"
        assert len(mcp_list_tools[0].tools) == 5  # Updated for dependency tools
        expected_tool_names = {
            "get_user_id",
            "get_user_permissions",
            "check_file_access",
            "get_experiment_id",
            "get_experiment_results",
        }
        assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names

        assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
        for mcp_call in mcp_calls:
            assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"

        assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"

        final_message = message_outputs[-1]
        assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
        assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
        assert len(final_message.content) > 0, "Final message should have content"

        expected_output = case["output"]
        assert expected_output.lower() in response.output_text.lower(), (
            f"Expected '{expected_output}' to appear in response: {response.output_text}"
        )


@pytest.mark.parametrize(
    "case",
    responses_test_cases["test_response_multi_turn_tool_execution_streaming"]["test_params"]["case"],
    ids=case_id_generator,
)
async def test_response_streaming_multi_turn_tool_execution(
    request, openai_client, model, provider, verification_config, case
):
    """Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
        tools = case["tools"]
        # Replace the placeholder URL with the actual server URL
        for tool in tools:
            if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
                tool["server_url"] = mcp_server_info["server_url"]

        stream = openai_client.responses.create(
            input=case["input"],
            model=model,
            tools=tools,
            stream=True,
        )

        chunks = []
        async for chunk in stream:
            chunks.append(chunk)

        # Should have at least response.created and response.completed
        assert len(chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(chunks)}"

        # First chunk should be response.created
        assert chunks[0].type == "response.created", f"First chunk should be response.created, got {chunks[0].type}"

        # Last chunk should be response.completed
        assert chunks[-1].type == "response.completed", (
            f"Last chunk should be response.completed, got {chunks[-1].type}"
        )

        # Get the final response from the last chunk
        final_chunk = chunks[-1]
        if hasattr(final_chunk, "response"):
            final_response = final_chunk.response

            # Verify multi-turn MCP tool execution results
            mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
            mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
            message_outputs = [output for output in final_response.output if output.type == "message"]

            # Should have exactly 1 MCP list tools message (at the beginning)
            assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
            assert mcp_list_tools[0].server_label == "localmcp"
            assert len(mcp_list_tools[0].tools) == 5  # Updated for dependency tools
            expected_tool_names = {
                "get_user_id",
                "get_user_permissions",
                "check_file_access",
                "get_experiment_id",
                "get_experiment_results",
            }
            assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names

            # Should have at least 1 MCP call (the model should call at least one tool)
            assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"

            # All MCP calls should be completed (verifies our tool execution works)
            for mcp_call in mcp_calls:
                assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"

            # Should have at least one final message response
            assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"

            # Final message should be from assistant and completed
            final_message = message_outputs[-1]
            assert final_message.role == "assistant", (
                f"Final message should be from assistant, got {final_message.role}"
            )
            assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
            assert len(final_message.content) > 0, "Final message should have content"

            # Check that the expected output appears in the response
            expected_output = case["output"]
            assert expected_output.lower() in final_response.output_text.lower(), (
                f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
            )