mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
feat(responses): implement full multi-turn support
This commit is contained in:
parent
6bb174bb05
commit
fd15a6832c
5 changed files with 577 additions and 125 deletions
|
@ -258,6 +258,19 @@ class OpenAIResponsesImpl:
|
|||
"""
|
||||
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
|
||||
|
||||
def _is_function_tool_call(
|
||||
self,
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
) -> bool:
|
||||
"""Check if a tool call is a function tool call (client-side) vs non-function (server-side)."""
|
||||
if not tools:
|
||||
return False
|
||||
|
||||
# If the first tool is a function, assume all tools are functions
|
||||
# This matches the logic in _process_response_choices
|
||||
return tools[0].type == "function"
|
||||
|
||||
async def _process_response_choices(
|
||||
self,
|
||||
chat_response: OpenAIChatCompletion,
|
||||
|
@ -270,7 +283,7 @@ class OpenAIResponsesImpl:
|
|||
for choice in chat_response.choices:
|
||||
if choice.message.tool_calls and tools:
|
||||
# Assume if the first tool is a function, all tools are functions
|
||||
if tools[0].type == "function":
|
||||
if self._is_function_tool_call(choice.message.tool_calls[0], tools):
|
||||
for tool_call in choice.message.tool_calls:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
|
@ -332,6 +345,7 @@ class OpenAIResponsesImpl:
|
|||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
):
|
||||
stream = False if stream is None else stream
|
||||
|
||||
|
@ -358,58 +372,100 @@ class OpenAIResponsesImpl:
|
|||
temperature=temperature,
|
||||
)
|
||||
|
||||
inference_result = await self.inference_api.openai_chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=chat_tools,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
# Fork to streaming vs non-streaming - let each handle ALL inference rounds
|
||||
if stream:
|
||||
return self._create_streaming_response(
|
||||
inference_result=inference_result,
|
||||
ctx=ctx,
|
||||
output_messages=output_messages,
|
||||
input=input,
|
||||
model=model,
|
||||
store=store,
|
||||
tools=tools,
|
||||
max_infer_iters=max_infer_iters,
|
||||
)
|
||||
else:
|
||||
return await self._create_non_streaming_response(
|
||||
inference_result=inference_result,
|
||||
ctx=ctx,
|
||||
output_messages=output_messages,
|
||||
input=input,
|
||||
model=model,
|
||||
store=store,
|
||||
tools=tools,
|
||||
max_infer_iters=max_infer_iters,
|
||||
)
|
||||
|
||||
async def _create_non_streaming_response(
|
||||
self,
|
||||
inference_result: Any,
|
||||
ctx: ChatCompletionContext,
|
||||
output_messages: list[OpenAIResponseOutput],
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
store: bool | None,
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
max_infer_iters: int | None,
|
||||
) -> OpenAIResponseObject:
|
||||
chat_response = OpenAIChatCompletion(**inference_result.model_dump())
|
||||
# Implement tool execution loop - handle ALL inference rounds including the first
|
||||
n_iter = 0
|
||||
messages = ctx.messages.copy()
|
||||
current_response = None
|
||||
|
||||
# Process response choices (tool execution and message creation)
|
||||
output_messages.extend(
|
||||
await self._process_response_choices(
|
||||
chat_response=chat_response,
|
||||
ctx=ctx,
|
||||
tools=tools,
|
||||
while True:
|
||||
# Do inference (including the first one)
|
||||
inference_result = await self.inference_api.openai_chat_completion(
|
||||
model=ctx.model,
|
||||
messages=messages,
|
||||
tools=ctx.tools,
|
||||
stream=False,
|
||||
temperature=ctx.temperature,
|
||||
)
|
||||
)
|
||||
current_response = OpenAIChatCompletion(**inference_result.model_dump())
|
||||
|
||||
# Separate function vs non-function tool calls
|
||||
function_tool_calls = []
|
||||
non_function_tool_calls = []
|
||||
|
||||
for choice in current_response.choices:
|
||||
if choice.message.tool_calls and tools:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if self._is_function_tool_call(tool_call, tools):
|
||||
function_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
|
||||
# Process response choices based on tool call types
|
||||
if function_tool_calls:
|
||||
# For function tool calls, use existing logic and return immediately
|
||||
current_output_messages = await self._process_response_choices(
|
||||
chat_response=current_response,
|
||||
ctx=ctx,
|
||||
tools=tools,
|
||||
)
|
||||
output_messages.extend(current_output_messages)
|
||||
break
|
||||
elif non_function_tool_calls:
|
||||
# For non-function tool calls, execute them and continue loop
|
||||
for choice in current_response.choices:
|
||||
tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
|
||||
output_messages.extend(tool_outputs)
|
||||
|
||||
# Add assistant message and tool responses to messages for next iteration
|
||||
messages.append(choice.message)
|
||||
messages.extend(tool_response_messages)
|
||||
|
||||
n_iter += 1
|
||||
if n_iter >= (max_infer_iters or 10):
|
||||
break
|
||||
|
||||
# Continue with next iteration of the loop
|
||||
continue
|
||||
else:
|
||||
# No tool calls - convert response to message and we're done
|
||||
for choice in current_response.choices:
|
||||
output_messages.append(await _convert_chat_choice_to_response_message(choice))
|
||||
break
|
||||
|
||||
response = OpenAIResponseObject(
|
||||
created_at=chat_response.created,
|
||||
created_at=current_response.created,
|
||||
id=f"resp-{uuid.uuid4()}",
|
||||
model=model,
|
||||
object="response",
|
||||
|
@ -429,13 +485,13 @@ class OpenAIResponsesImpl:
|
|||
|
||||
async def _create_streaming_response(
|
||||
self,
|
||||
inference_result: Any,
|
||||
ctx: ChatCompletionContext,
|
||||
output_messages: list[OpenAIResponseOutput],
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
store: bool | None,
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
max_infer_iters: int | None,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Create initial response and emit response.created immediately
|
||||
response_id = f"resp-{uuid.uuid4()}"
|
||||
|
@ -453,87 +509,135 @@ class OpenAIResponsesImpl:
|
|||
# Emit response.created immediately
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||
|
||||
# For streaming, inference_result is an async iterator of chunks
|
||||
# Stream chunks and emit delta events as they arrive
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
sequence_number = 0
|
||||
# Implement tool execution loop for streaming - handle ALL inference rounds including the first
|
||||
n_iter = 0
|
||||
messages = ctx.messages.copy()
|
||||
|
||||
# Create a placeholder message item for delta events
|
||||
message_item_id = f"msg_{uuid.uuid4()}"
|
||||
|
||||
async for chunk in inference_result:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=0,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=0,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
# Collect content for final response
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
|
||||
# Aggregate tool call arguments across chunks, using their index as the aggregation key
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
if response_tool_call:
|
||||
# Don't attempt to concatenate arguments if we don't have any new arguments
|
||||
if tool_call.function.arguments:
|
||||
# Guard against an initial None argument before we concatenate
|
||||
response_tool_call.function.arguments = (
|
||||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
else:
|
||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||
tool_call_dict.pop("type", None)
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Convert collected chunks to complete response
|
||||
if chat_response_tool_calls:
|
||||
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content="".join(chat_response_content),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
chat_response_obj = OpenAIChatCompletion(
|
||||
id=chat_response_id,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=assistant_message,
|
||||
finish_reason=chunk_finish_reason,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=chunk_created,
|
||||
model=chunk_model,
|
||||
)
|
||||
|
||||
# Process response choices (tool execution and message creation)
|
||||
output_messages.extend(
|
||||
await self._process_response_choices(
|
||||
chat_response=chat_response_obj,
|
||||
ctx=ctx,
|
||||
tools=tools,
|
||||
while True:
|
||||
# Do inference (including the first one) - streaming
|
||||
current_inference_result = await self.inference_api.openai_chat_completion(
|
||||
model=ctx.model,
|
||||
messages=messages,
|
||||
tools=ctx.tools,
|
||||
stream=True,
|
||||
temperature=ctx.temperature,
|
||||
)
|
||||
)
|
||||
|
||||
# Process streaming chunks and build complete response
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
sequence_number = 0
|
||||
|
||||
# Create a placeholder message item for delta events
|
||||
message_item_id = f"msg_{uuid.uuid4()}"
|
||||
|
||||
async for chunk in current_inference_result:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=0,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=0,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
# Collect content for final response
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
|
||||
# Aggregate tool call arguments across chunks
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
if response_tool_call:
|
||||
# Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions
|
||||
if tool_call.function.arguments:
|
||||
# Guard against an initial None argument before we concatenate
|
||||
response_tool_call.function.arguments = (
|
||||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
else:
|
||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||
tool_call_dict.pop("type", None)
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Convert collected chunks to complete response
|
||||
if chat_response_tool_calls:
|
||||
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content="".join(chat_response_content),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
current_response = OpenAIChatCompletion(
|
||||
id=chat_response_id,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=assistant_message,
|
||||
finish_reason=chunk_finish_reason,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=chunk_created,
|
||||
model=chunk_model,
|
||||
)
|
||||
|
||||
# Separate function vs non-function tool calls
|
||||
function_tool_calls = []
|
||||
non_function_tool_calls = []
|
||||
|
||||
for choice in current_response.choices:
|
||||
if choice.message.tool_calls and tools:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if self._is_function_tool_call(tool_call, tools):
|
||||
function_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
|
||||
# Process response choices based on tool call types
|
||||
if function_tool_calls:
|
||||
# For function tool calls, use existing logic and break
|
||||
current_output_messages = await self._process_response_choices(
|
||||
chat_response=current_response,
|
||||
ctx=ctx,
|
||||
tools=tools,
|
||||
)
|
||||
output_messages.extend(current_output_messages)
|
||||
break
|
||||
elif non_function_tool_calls:
|
||||
# For non-function tool calls, execute them and continue loop
|
||||
for choice in current_response.choices:
|
||||
tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
|
||||
output_messages.extend(tool_outputs)
|
||||
|
||||
# Add assistant message and tool responses to messages for next iteration
|
||||
messages.append(choice.message)
|
||||
messages.extend(tool_response_messages)
|
||||
|
||||
n_iter += 1
|
||||
if n_iter >= (max_infer_iters or 10):
|
||||
break
|
||||
|
||||
# Continue with next iteration of the loop
|
||||
continue
|
||||
else:
|
||||
# No tool calls - convert response to message and we're done
|
||||
for choice in current_response.choices:
|
||||
output_messages.append(await _convert_chat_choice_to_response_message(choice))
|
||||
break
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
|
@ -646,6 +750,30 @@ class OpenAIResponsesImpl:
|
|||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
return chat_tools, mcp_tool_to_server, mcp_list_message
|
||||
|
||||
async def _execute_tool_calls_only(
|
||||
self,
|
||||
choice: OpenAIChoice,
|
||||
ctx: ChatCompletionContext,
|
||||
) -> tuple[list[OpenAIResponseOutput], list[OpenAIMessageParam]]:
|
||||
"""Execute tool calls and return output messages and tool response messages for next inference."""
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
tool_response_messages: list[OpenAIMessageParam] = []
|
||||
|
||||
if not isinstance(choice.message, OpenAIAssistantMessageParam):
|
||||
return output_messages, tool_response_messages
|
||||
|
||||
if not choice.message.tool_calls:
|
||||
return output_messages, tool_response_messages
|
||||
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
if further_input:
|
||||
tool_response_messages.append(further_input)
|
||||
|
||||
return output_messages, tool_response_messages
|
||||
|
||||
async def _execute_tool_and_return_final_output(
|
||||
self,
|
||||
choice: OpenAIChoice,
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
# we want the mcp server to be authenticated OR not, depends
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
|
||||
# Unfortunately the toolgroup id must be tied to the tool names because the registry
|
||||
|
@ -13,29 +14,16 @@ from contextlib import contextmanager
|
|||
MCP_TOOLGROUP_ID = "mcp::localmcp"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def make_mcp_server(required_auth_token: str | None = None):
|
||||
import threading
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
def default_tools():
|
||||
"""Default tools for backward compatibility."""
|
||||
from mcp import types
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount, Route
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
server = FastMCP("FastMCP Test Server", log_level="WARNING")
|
||||
|
||||
@server.tool()
|
||||
async def greet_everyone(
|
||||
url: str, ctx: Context
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
return [types.TextContent(type="text", text="Hello, world!")]
|
||||
|
||||
@server.tool()
|
||||
async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
|
||||
"""
|
||||
Returns the boiling point of a liquid in Celcius or Fahrenheit.
|
||||
|
@ -52,12 +40,151 @@ def make_mcp_server(required_auth_token: str | None = None):
|
|||
else:
|
||||
return -1
|
||||
|
||||
return {"greet_everyone": greet_everyone, "get_boiling_point": get_boiling_point}
|
||||
|
||||
|
||||
def dependency_tools():
|
||||
"""Tools with natural dependencies for multi-turn testing."""
|
||||
from mcp import types
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
async def get_user_id(username: str, ctx: Context) -> str:
|
||||
"""
|
||||
Get the user ID for a given username. This ID is needed for other operations.
|
||||
|
||||
:param username: The username to look up
|
||||
:return: The user ID for the username
|
||||
"""
|
||||
# Simple mapping for testing
|
||||
user_mapping = {"alice": "user_12345", "bob": "user_67890", "charlie": "user_11111", "admin": "user_00000"}
|
||||
return user_mapping.get(username.lower(), "user_99999")
|
||||
|
||||
async def get_user_permissions(user_id: str, ctx: Context) -> str:
|
||||
"""
|
||||
Get the permissions for a user ID. Requires a valid user ID from get_user_id.
|
||||
|
||||
:param user_id: The user ID to check permissions for
|
||||
:return: The permissions for the user
|
||||
"""
|
||||
# Permission mapping based on user IDs
|
||||
permission_mapping = {
|
||||
"user_12345": "read,write", # alice
|
||||
"user_67890": "read", # bob
|
||||
"user_11111": "admin", # charlie
|
||||
"user_00000": "superadmin", # admin
|
||||
"user_99999": "none", # unknown users
|
||||
}
|
||||
return permission_mapping.get(user_id, "none")
|
||||
|
||||
async def check_file_access(user_id: str, filename: str, ctx: Context) -> str:
|
||||
"""
|
||||
Check if a user can access a specific file. Requires a valid user ID.
|
||||
|
||||
:param user_id: The user ID to check access for
|
||||
:param filename: The filename to check access to
|
||||
:return: Whether the user can access the file (yes/no)
|
||||
"""
|
||||
# Get permissions first
|
||||
permission_mapping = {
|
||||
"user_12345": "read,write", # alice
|
||||
"user_67890": "read", # bob
|
||||
"user_11111": "admin", # charlie
|
||||
"user_00000": "superadmin", # admin
|
||||
"user_99999": "none", # unknown users
|
||||
}
|
||||
permissions = permission_mapping.get(user_id, "none")
|
||||
|
||||
# Check file access based on permissions and filename
|
||||
if permissions == "superadmin":
|
||||
access = "yes"
|
||||
elif permissions == "admin":
|
||||
access = "yes" if not filename.startswith("secret_") else "no"
|
||||
elif "write" in permissions:
|
||||
access = "yes" if filename.endswith(".txt") else "no"
|
||||
elif "read" in permissions:
|
||||
access = "yes" if filename.endswith(".txt") or filename.endswith(".md") else "no"
|
||||
else:
|
||||
access = "no"
|
||||
|
||||
return [types.TextContent(type="text", text=access)]
|
||||
|
||||
async def get_experiment_id(experiment_name: str, ctx: Context) -> str:
|
||||
"""
|
||||
Get the experiment ID for a given experiment name. This ID is needed to get results.
|
||||
|
||||
:param experiment_name: The name of the experiment
|
||||
:return: The experiment ID
|
||||
"""
|
||||
# Simple mapping for testing
|
||||
experiment_mapping = {
|
||||
"temperature_test": "exp_001",
|
||||
"pressure_test": "exp_002",
|
||||
"chemical_reaction": "exp_003",
|
||||
"boiling_point": "exp_004",
|
||||
}
|
||||
exp_id = experiment_mapping.get(experiment_name.lower(), "exp_999")
|
||||
return exp_id
|
||||
|
||||
async def get_experiment_results(experiment_id: str, ctx: Context) -> str:
|
||||
"""
|
||||
Get the results for an experiment ID. Requires a valid experiment ID from get_experiment_id.
|
||||
|
||||
:param experiment_id: The experiment ID to get results for
|
||||
:return: The experiment results
|
||||
"""
|
||||
# Results mapping based on experiment IDs
|
||||
results_mapping = {
|
||||
"exp_001": "Temperature: 25°C, Status: Success",
|
||||
"exp_002": "Pressure: 1.2 atm, Status: Success",
|
||||
"exp_003": "Yield: 85%, Status: Complete",
|
||||
"exp_004": "Boiling Point: 100°C, Status: Verified",
|
||||
"exp_999": "No results found",
|
||||
}
|
||||
results = results_mapping.get(experiment_id, "Invalid experiment ID")
|
||||
return results
|
||||
|
||||
return {
|
||||
"get_user_id": get_user_id,
|
||||
"get_user_permissions": get_user_permissions,
|
||||
"check_file_access": check_file_access,
|
||||
"get_experiment_id": get_experiment_id,
|
||||
"get_experiment_results": get_experiment_results,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Callable] | None = None):
|
||||
"""
|
||||
Create an MCP server with the specified tools.
|
||||
|
||||
:param required_auth_token: Optional auth token required for access
|
||||
:param tools: Dictionary of tool_name -> tool_function. If None, uses default tools.
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
server = FastMCP("FastMCP Test Server", log_level="WARNING")
|
||||
|
||||
tools = tools or default_tools()
|
||||
|
||||
# Register all tools with the server
|
||||
for tool_func in tools.values():
|
||||
server.tool()(tool_func)
|
||||
|
||||
sse = SseServerTransport("/messages/")
|
||||
|
||||
async def handle_sse(request):
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
auth_header: str | None = request.headers.get("Authorization")
|
||||
auth_token = None
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
auth_token = auth_header.split(" ")[1]
|
||||
|
|
|
@ -224,16 +224,16 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
|||
],
|
||||
)
|
||||
|
||||
# Verify
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
assert len(chunks) == 2 # Should have response.created and response.completed
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
assert len(chunks) == 2 # Should have response.created and response.completed
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
|
|
@ -94,3 +94,43 @@ test_response_multi_turn_image:
|
|||
output: "llama"
|
||||
- input: "What country do you find this animal primarily in? What continent?"
|
||||
output: "peru"
|
||||
|
||||
test_response_multi_turn_tool_execution:
|
||||
test_name: test_response_multi_turn_tool_execution
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "user_file_access_check"
|
||||
input: "I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Tell me the final result."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
output: "yes"
|
||||
- case_id: "experiment_results_lookup"
|
||||
input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me what you found."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
output: "100°C"
|
||||
|
||||
test_response_multi_turn_tool_execution_streaming:
|
||||
test_name: test_response_multi_turn_tool_execution_streaming
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "user_permissions_workflow"
|
||||
input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
stream: true
|
||||
output: "no"
|
||||
- case_id: "experiment_analysis_streaming"
|
||||
input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Please stream your analysis process."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
stream: true
|
||||
output: "85%"
|
||||
|
|
|
@ -12,7 +12,7 @@ import pytest
|
|||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
||||
from tests.common.mcp import make_mcp_server
|
||||
from tests.common.mcp import dependency_tools, make_mcp_server
|
||||
from tests.verifications.openai_api.fixtures.fixtures import (
|
||||
case_id_generator,
|
||||
get_base_test_name,
|
||||
|
@ -280,6 +280,7 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
|
|||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert len(response.output) >= 3
|
||||
list_tools = response.output[0]
|
||||
assert list_tools.type == "mcp_list_tools"
|
||||
|
@ -294,7 +295,12 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
|
|||
assert call.error is None
|
||||
assert "-100" in call.output
|
||||
|
||||
message = response.output[2]
|
||||
from rich.pretty import pprint
|
||||
|
||||
pprint(response)
|
||||
|
||||
# sometimes the model will call the tool again, so we need to get the last message
|
||||
message = response.output[-1]
|
||||
text_content = message.content[0].text
|
||||
assert "boiling point" in text_content.lower()
|
||||
|
||||
|
@ -393,3 +399,154 @@ def test_response_non_streaming_multi_turn_image(request, openai_client, model,
|
|||
previous_response_id = response.id
|
||||
output_text = response.output_text.lower()
|
||||
assert turn["output"].lower() in output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_multi_turn_tool_execution"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_multi_turn_tool_execution(
|
||||
request, openai_client, model, provider, verification_config, case
|
||||
):
|
||||
"""Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
|
||||
tools = case["tools"]
|
||||
# Replace the placeholder URL with the actual server URL
|
||||
for tool in tools:
|
||||
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
|
||||
tool["server_url"] = mcp_server_info["server_url"]
|
||||
|
||||
response = openai_client.responses.create(
|
||||
input=case["input"],
|
||||
model=model,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# Verify we have MCP tool calls in the output
|
||||
mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
|
||||
mcp_calls = [output for output in response.output if output.type == "mcp_call"]
|
||||
message_outputs = [output for output in response.output if output.type == "message"]
|
||||
|
||||
# Should have exactly 1 MCP list tools message (at the beginning)
|
||||
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
|
||||
assert mcp_list_tools[0].server_label == "localmcp"
|
||||
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
|
||||
expected_tool_names = {
|
||||
"get_user_id",
|
||||
"get_user_permissions",
|
||||
"check_file_access",
|
||||
"get_experiment_id",
|
||||
"get_experiment_results",
|
||||
}
|
||||
assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names
|
||||
|
||||
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
|
||||
for mcp_call in mcp_calls:
|
||||
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
|
||||
|
||||
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
|
||||
|
||||
final_message = message_outputs[-1]
|
||||
assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
|
||||
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
|
||||
assert len(final_message.content) > 0, "Final message should have content"
|
||||
|
||||
expected_output = case["output"]
|
||||
assert expected_output.lower() in response.output_text.lower(), (
|
||||
f"Expected '{expected_output}' to appear in response: {response.output_text}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_multi_turn_tool_execution_streaming"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
async def test_response_streaming_multi_turn_tool_execution(
|
||||
request, openai_client, model, provider, verification_config, case
|
||||
):
|
||||
"""Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
|
||||
tools = case["tools"]
|
||||
# Replace the placeholder URL with the actual server URL
|
||||
for tool in tools:
|
||||
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
|
||||
tool["server_url"] = mcp_server_info["server_url"]
|
||||
|
||||
stream = openai_client.responses.create(
|
||||
input=case["input"],
|
||||
model=model,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
|
||||
# Should have at least response.created and response.completed
|
||||
assert len(chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(chunks)}"
|
||||
|
||||
# First chunk should be response.created
|
||||
assert chunks[0].type == "response.created", f"First chunk should be response.created, got {chunks[0].type}"
|
||||
|
||||
# Last chunk should be response.completed
|
||||
assert chunks[-1].type == "response.completed", (
|
||||
f"Last chunk should be response.completed, got {chunks[-1].type}"
|
||||
)
|
||||
|
||||
# Get the final response from the last chunk
|
||||
final_chunk = chunks[-1]
|
||||
if hasattr(final_chunk, "response"):
|
||||
final_response = final_chunk.response
|
||||
|
||||
# Verify multi-turn MCP tool execution results
|
||||
mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
|
||||
mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
|
||||
message_outputs = [output for output in final_response.output if output.type == "message"]
|
||||
|
||||
# Should have exactly 1 MCP list tools message (at the beginning)
|
||||
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
|
||||
assert mcp_list_tools[0].server_label == "localmcp"
|
||||
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
|
||||
expected_tool_names = {
|
||||
"get_user_id",
|
||||
"get_user_permissions",
|
||||
"check_file_access",
|
||||
"get_experiment_id",
|
||||
"get_experiment_results",
|
||||
}
|
||||
assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names
|
||||
|
||||
# Should have at least 1 MCP call (the model should call at least one tool)
|
||||
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
|
||||
|
||||
# All MCP calls should be completed (verifies our tool execution works)
|
||||
for mcp_call in mcp_calls:
|
||||
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
|
||||
|
||||
# Should have at least one final message response
|
||||
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
|
||||
|
||||
# Final message should be from assistant and completed
|
||||
final_message = message_outputs[-1]
|
||||
assert final_message.role == "assistant", (
|
||||
f"Final message should be from assistant, got {final_message.role}"
|
||||
)
|
||||
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
|
||||
assert len(final_message.content) > 0, "Final message should have content"
|
||||
|
||||
# Check that the expected output appears in the response
|
||||
expected_output = case["output"]
|
||||
assert expected_output.lower() in final_response.output_text.lower(), (
|
||||
f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue