mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
add unit tests
Signed-off-by: Jaideep Rao <jrao@redhat.com>
This commit is contained in:
parent
d2d2c88921
commit
36d7abd4d5
11 changed files with 392 additions and 85 deletions
|
|
@ -12002,18 +12002,9 @@ components:
|
||||||
- web_search_preview
|
- web_search_preview
|
||||||
- web_search_preview_2025_03_11
|
- web_search_preview_2025_03_11
|
||||||
- web_search_2025_08_26
|
- web_search_2025_08_26
|
||||||
search_context_size:
|
|
||||||
anyOf:
|
|
||||||
- type: string
|
|
||||||
pattern: ^low|medium|high$
|
|
||||||
- type: 'null'
|
|
||||||
default: medium
|
|
||||||
type: object
|
type: object
|
||||||
title: OpenAIResponseInputToolChoiceWebSearch
|
title: OpenAIResponseInputToolChoiceWebSearch
|
||||||
description: |-
|
description: Indicates that the model should use web search to generate a response
|
||||||
Indicates that the model should use web search to generate a response.
|
|
||||||
|
|
||||||
This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
|
|
||||||
OpenAIResponseMessage-Input:
|
OpenAIResponseMessage-Input:
|
||||||
properties:
|
properties:
|
||||||
content:
|
content:
|
||||||
|
|
@ -12420,6 +12411,7 @@ components:
|
||||||
title: AllowedToolsConfig
|
title: AllowedToolsConfig
|
||||||
type: object
|
type: object
|
||||||
CustomToolConfig:
|
CustomToolConfig:
|
||||||
|
description: Custom tool configuration for OpenAI-compatible chat completion requests.
|
||||||
properties:
|
properties:
|
||||||
name:
|
name:
|
||||||
title: Name
|
title: Name
|
||||||
|
|
|
||||||
12
docs/static/deprecated-llama-stack-spec.yaml
vendored
12
docs/static/deprecated-llama-stack-spec.yaml
vendored
|
|
@ -8996,18 +8996,9 @@ components:
|
||||||
- web_search_preview
|
- web_search_preview
|
||||||
- web_search_preview_2025_03_11
|
- web_search_preview_2025_03_11
|
||||||
- web_search_2025_08_26
|
- web_search_2025_08_26
|
||||||
search_context_size:
|
|
||||||
anyOf:
|
|
||||||
- type: string
|
|
||||||
pattern: ^low|medium|high$
|
|
||||||
- type: 'null'
|
|
||||||
default: medium
|
|
||||||
type: object
|
type: object
|
||||||
title: OpenAIResponseInputToolChoiceWebSearch
|
title: OpenAIResponseInputToolChoiceWebSearch
|
||||||
description: |-
|
description: Indicates that the model should use web search to generate a response
|
||||||
Indicates that the model should use web search to generate a response.
|
|
||||||
|
|
||||||
This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
|
|
||||||
OpenAIResponseMessage-Input:
|
OpenAIResponseMessage-Input:
|
||||||
properties:
|
properties:
|
||||||
content:
|
content:
|
||||||
|
|
@ -9414,6 +9405,7 @@ components:
|
||||||
title: AllowedToolsConfig
|
title: AllowedToolsConfig
|
||||||
type: object
|
type: object
|
||||||
CustomToolConfig:
|
CustomToolConfig:
|
||||||
|
description: Custom tool configuration for OpenAI-compatible chat completion requests.
|
||||||
properties:
|
properties:
|
||||||
name:
|
name:
|
||||||
title: Name
|
title: Name
|
||||||
|
|
|
||||||
12
docs/static/experimental-llama-stack-spec.yaml
vendored
12
docs/static/experimental-llama-stack-spec.yaml
vendored
|
|
@ -7799,18 +7799,9 @@ components:
|
||||||
- web_search_preview
|
- web_search_preview
|
||||||
- web_search_preview_2025_03_11
|
- web_search_preview_2025_03_11
|
||||||
- web_search_2025_08_26
|
- web_search_2025_08_26
|
||||||
search_context_size:
|
|
||||||
anyOf:
|
|
||||||
- type: string
|
|
||||||
pattern: ^low|medium|high$
|
|
||||||
- type: 'null'
|
|
||||||
default: medium
|
|
||||||
type: object
|
type: object
|
||||||
title: OpenAIResponseInputToolChoiceWebSearch
|
title: OpenAIResponseInputToolChoiceWebSearch
|
||||||
description: |-
|
description: Indicates that the model should use web search to generate a response
|
||||||
Indicates that the model should use web search to generate a response.
|
|
||||||
|
|
||||||
This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
|
|
||||||
OpenAIResponseMessage-Output:
|
OpenAIResponseMessage-Output:
|
||||||
properties:
|
properties:
|
||||||
content:
|
content:
|
||||||
|
|
@ -8148,6 +8139,7 @@ components:
|
||||||
title: AllowedToolsConfig
|
title: AllowedToolsConfig
|
||||||
type: object
|
type: object
|
||||||
CustomToolConfig:
|
CustomToolConfig:
|
||||||
|
description: Custom tool configuration for OpenAI-compatible chat completion requests.
|
||||||
properties:
|
properties:
|
||||||
name:
|
name:
|
||||||
title: Name
|
title: Name
|
||||||
|
|
|
||||||
12
docs/static/llama-stack-spec.yaml
vendored
12
docs/static/llama-stack-spec.yaml
vendored
|
|
@ -10438,18 +10438,9 @@ components:
|
||||||
- web_search_preview
|
- web_search_preview
|
||||||
- web_search_preview_2025_03_11
|
- web_search_preview_2025_03_11
|
||||||
- web_search_2025_08_26
|
- web_search_2025_08_26
|
||||||
search_context_size:
|
|
||||||
anyOf:
|
|
||||||
- type: string
|
|
||||||
pattern: ^low|medium|high$
|
|
||||||
- type: 'null'
|
|
||||||
default: medium
|
|
||||||
type: object
|
type: object
|
||||||
title: OpenAIResponseInputToolChoiceWebSearch
|
title: OpenAIResponseInputToolChoiceWebSearch
|
||||||
description: |-
|
description: Indicates that the model should use web search to generate a response
|
||||||
Indicates that the model should use web search to generate a response.
|
|
||||||
|
|
||||||
This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
|
|
||||||
OpenAIResponseMessage-Input:
|
OpenAIResponseMessage-Input:
|
||||||
properties:
|
properties:
|
||||||
content:
|
content:
|
||||||
|
|
@ -10856,6 +10847,7 @@ components:
|
||||||
title: AllowedToolsConfig
|
title: AllowedToolsConfig
|
||||||
type: object
|
type: object
|
||||||
CustomToolConfig:
|
CustomToolConfig:
|
||||||
|
description: Custom tool configuration for OpenAI-compatible chat completion requests.
|
||||||
properties:
|
properties:
|
||||||
name:
|
name:
|
||||||
title: Name
|
title: Name
|
||||||
|
|
|
||||||
12
docs/static/stainless-llama-stack-spec.yaml
vendored
12
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@ -12002,18 +12002,9 @@ components:
|
||||||
- web_search_preview
|
- web_search_preview
|
||||||
- web_search_preview_2025_03_11
|
- web_search_preview_2025_03_11
|
||||||
- web_search_2025_08_26
|
- web_search_2025_08_26
|
||||||
search_context_size:
|
|
||||||
anyOf:
|
|
||||||
- type: string
|
|
||||||
pattern: ^low|medium|high$
|
|
||||||
- type: 'null'
|
|
||||||
default: medium
|
|
||||||
type: object
|
type: object
|
||||||
title: OpenAIResponseInputToolChoiceWebSearch
|
title: OpenAIResponseInputToolChoiceWebSearch
|
||||||
description: |-
|
description: Indicates that the model should use web search to generate a response
|
||||||
Indicates that the model should use web search to generate a response.
|
|
||||||
|
|
||||||
This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
|
|
||||||
OpenAIResponseMessage-Input:
|
OpenAIResponseMessage-Input:
|
||||||
properties:
|
properties:
|
||||||
content:
|
content:
|
||||||
|
|
@ -12420,6 +12411,7 @@ components:
|
||||||
title: AllowedToolsConfig
|
title: AllowedToolsConfig
|
||||||
type: object
|
type: object
|
||||||
CustomToolConfig:
|
CustomToolConfig:
|
||||||
|
description: Custom tool configuration for OpenAI-compatible chat completion requests.
|
||||||
properties:
|
properties:
|
||||||
name:
|
name:
|
||||||
title: Name
|
title: Name
|
||||||
|
|
|
||||||
|
|
@ -473,7 +473,7 @@ class OpenAIResponsesImpl:
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
response_tools=tools,
|
response_tools=tools,
|
||||||
responses_tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
tool_context=tool_context,
|
tool_context=tool_context,
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ import uuid
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from opentelemetry import trace
|
|
||||||
from openai.types.chat import ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
|
from opentelemetry import trace
|
||||||
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||||
|
|
@ -217,7 +217,7 @@ class StreamingResponseOrchestrator:
|
||||||
output=self._clone_outputs(outputs),
|
output=self._clone_outputs(outputs),
|
||||||
text=self.text,
|
text=self.text,
|
||||||
tools=self.ctx.available_tools(),
|
tools=self.ctx.available_tools(),
|
||||||
tool_choice=self.ctx.responses_tool_choice,
|
tool_choice=self.ctx.tool_choice,
|
||||||
error=error,
|
error=error,
|
||||||
usage=self.accumulated_usage,
|
usage=self.accumulated_usage,
|
||||||
instructions=self.instructions,
|
instructions=self.instructions,
|
||||||
|
|
@ -253,17 +253,18 @@ class StreamingResponseOrchestrator:
|
||||||
async for stream_event in self._process_tools(output_messages):
|
async for stream_event in self._process_tools(output_messages):
|
||||||
yield stream_event
|
yield stream_event
|
||||||
|
|
||||||
if self.ctx.responses_tool_choice and len(self.ctx.chat_tools) > 0:
|
chat_tool_choice = None
|
||||||
chat_tool_choice = await _process_tool_choice(
|
if self.ctx.tool_choice and len(self.ctx.chat_tools) > 0:
|
||||||
|
processed_tool_choice = await _process_tool_choice(
|
||||||
self.ctx.chat_tools,
|
self.ctx.chat_tools,
|
||||||
self.ctx.responses_tool_choice,
|
self.ctx.tool_choice,
|
||||||
self.server_label_to_tools,
|
self.server_label_to_tools,
|
||||||
)
|
)
|
||||||
# chat_tool_choice can be str, dict-like object, or None
|
# chat_tool_choice can be str, dict-like object, or None
|
||||||
if isinstance(chat_tool_choice, str):
|
if isinstance(processed_tool_choice, str | type(None)):
|
||||||
self.ctx.chat_tool_choice = chat_tool_choice
|
chat_tool_choice = processed_tool_choice
|
||||||
else:
|
else:
|
||||||
self.ctx.chat_tool_choice = chat_tool_choice.model_dump()
|
chat_tool_choice = processed_tool_choice.model_dump()
|
||||||
|
|
||||||
n_iter = 0
|
n_iter = 0
|
||||||
messages = self.ctx.messages.copy()
|
messages = self.ctx.messages.copy()
|
||||||
|
|
@ -284,7 +285,7 @@ class StreamingResponseOrchestrator:
|
||||||
messages=messages,
|
messages=messages,
|
||||||
# Pydantic models are dict-compatible but mypy treats them as distinct types
|
# Pydantic models are dict-compatible but mypy treats them as distinct types
|
||||||
tools=self.ctx.chat_tools, # type: ignore[arg-type]
|
tools=self.ctx.chat_tools, # type: ignore[arg-type]
|
||||||
tool_choice=self.ctx.chat_tool_choice,
|
tool_choice=chat_tool_choice,
|
||||||
stream=True,
|
stream=True,
|
||||||
temperature=self.ctx.temperature,
|
temperature=self.ctx.temperature,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
|
|
@ -363,8 +364,8 @@ class StreamingResponseOrchestrator:
|
||||||
n_iter += 1
|
n_iter += 1
|
||||||
# After first iteration, reset tool_choice to "auto" to let model decide freely
|
# After first iteration, reset tool_choice to "auto" to let model decide freely
|
||||||
# based on tool results (prevents infinite loops when forcing specific tools)
|
# based on tool results (prevents infinite loops when forcing specific tools)
|
||||||
if n_iter == 1 and self.ctx.chat_tool_choice:
|
if n_iter == 1 and chat_tool_choice:
|
||||||
self.ctx.chat_tool_choice = "auto"
|
chat_tool_choice = "auto"
|
||||||
if n_iter >= self.max_infer_iters:
|
if n_iter >= self.max_infer_iters:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"
|
f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"
|
||||||
|
|
@ -1332,13 +1333,13 @@ class StreamingResponseOrchestrator:
|
||||||
|
|
||||||
async def _process_tool_choice(
|
async def _process_tool_choice(
|
||||||
chat_tools: list[ChatCompletionToolParam],
|
chat_tools: list[ChatCompletionToolParam],
|
||||||
responses_tool_choice: OpenAIResponseInputToolChoice,
|
tool_choice: OpenAIResponseInputToolChoice,
|
||||||
server_label_to_tools: dict[str, list[str]],
|
server_label_to_tools: dict[str, list[str]],
|
||||||
) -> str | OpenAIChatCompletionToolChoice | None:
|
) -> str | OpenAIChatCompletionToolChoice | None:
|
||||||
"""Process and validate the OpenAI Responses tool choice and return the appropriate chat completion tool choice object.
|
"""Process and validate the OpenAI Responses tool choice and return the appropriate chat completion tool choice object.
|
||||||
|
|
||||||
:param chat_tools: The list of chat tools to enforce tool choice against.
|
:param chat_tools: The list of chat tools to enforce tool choice against.
|
||||||
:param responses_tool_choice: The OpenAI Responses tool choice to process.
|
:param tool_choice: The OpenAI Responses tool choice to process.
|
||||||
:param server_label_to_tools: A dictionary mapping server labels to the list of tools available on that server.
|
:param server_label_to_tools: A dictionary mapping server labels to the list of tools available on that server.
|
||||||
:return: The appropriate chat completion tool choice object.
|
:return: The appropriate chat completion tool choice object.
|
||||||
"""
|
"""
|
||||||
|
|
@ -1347,8 +1348,8 @@ async def _process_tool_choice(
|
||||||
# Note: chat_tools contains dicts, not objects
|
# Note: chat_tools contains dicts, not objects
|
||||||
chat_tool_names = [tool["function"]["name"] for tool in chat_tools if tool["type"] == "function"]
|
chat_tool_names = [tool["function"]["name"] for tool in chat_tools if tool["type"] == "function"]
|
||||||
|
|
||||||
if isinstance(responses_tool_choice, OpenAIResponseInputToolChoiceMode):
|
if isinstance(tool_choice, OpenAIResponseInputToolChoiceMode):
|
||||||
if responses_tool_choice.value == "required":
|
if tool_choice.value == "required":
|
||||||
if len(chat_tool_names) == 0:
|
if len(chat_tool_names) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -1358,18 +1359,17 @@ async def _process_tool_choice(
|
||||||
mode="required",
|
mode="required",
|
||||||
)
|
)
|
||||||
# return other modes as is
|
# return other modes as is
|
||||||
return responses_tool_choice.value
|
return tool_choice.value
|
||||||
|
|
||||||
elif isinstance(responses_tool_choice, OpenAIResponseInputToolChoiceAllowedTools):
|
elif isinstance(tool_choice, OpenAIResponseInputToolChoiceAllowedTools):
|
||||||
# ensure that specified tool choices are available in the chat tools, if not, remove them from the list
|
# ensure that specified tool choices are available in the chat tools, if not, remove them from the list
|
||||||
final_tools = []
|
final_tools = []
|
||||||
for tool in responses_tool_choice.tools:
|
for tool in tool_choice.tools:
|
||||||
tool_name = tool.get("name")
|
|
||||||
match tool.get("type"):
|
match tool.get("type"):
|
||||||
case "function":
|
case "function":
|
||||||
final_tools.append({"type": "function", "function": {"name": tool_name}})
|
final_tools.append({"type": "function", "function": {"name": tool.get("name")}})
|
||||||
case "custom":
|
case "custom":
|
||||||
final_tools.append({"type": "custom", "custom": {"name": tool_name}})
|
final_tools.append({"type": "custom", "custom": {"name": tool.get("name")}})
|
||||||
case "mcp":
|
case "mcp":
|
||||||
mcp_tools = convert_mcp_tool_choice(
|
mcp_tools = convert_mcp_tool_choice(
|
||||||
chat_tool_names, tool.get("server_label"), server_label_to_tools, None
|
chat_tool_names, tool.get("server_label"), server_label_to_tools, None
|
||||||
|
|
@ -1390,14 +1390,14 @@ async def _process_tool_choice(
|
||||||
|
|
||||||
return OpenAIChatCompletionToolChoiceAllowedTools(
|
return OpenAIChatCompletionToolChoiceAllowedTools(
|
||||||
tools=final_tools,
|
tools=final_tools,
|
||||||
mode=responses_tool_choice.mode,
|
mode=tool_choice.mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Handle specific tool choice by type
|
# Handle specific tool choice by type
|
||||||
# Each case validates the tool exists in chat_tools before returning
|
# Each case validates the tool exists in chat_tools before returning
|
||||||
tool_name = responses_tool_choice.name if responses_tool_choice.name else None
|
tool_name = getattr(tool_choice, "name", None)
|
||||||
match responses_tool_choice:
|
match tool_choice:
|
||||||
case OpenAIResponseInputToolChoiceCustomTool():
|
case OpenAIResponseInputToolChoiceCustomTool():
|
||||||
if tool_name and tool_name not in chat_tool_names:
|
if tool_name and tool_name not in chat_tool_names:
|
||||||
logger.warning(f"Tool {tool_name} not found in chat tools")
|
logger.warning(f"Tool {tool_name} not found in chat tools")
|
||||||
|
|
@ -1425,7 +1425,7 @@ async def _process_tool_choice(
|
||||||
case OpenAIResponseInputToolChoiceMCPTool():
|
case OpenAIResponseInputToolChoiceMCPTool():
|
||||||
tool_choice = convert_mcp_tool_choice(
|
tool_choice = convert_mcp_tool_choice(
|
||||||
chat_tool_names,
|
chat_tool_names,
|
||||||
responses_tool_choice.server_label,
|
tool_choice.server_label,
|
||||||
server_label_to_tools,
|
server_label_to_tools,
|
||||||
tool_name,
|
tool_name,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, cast
|
from typing import cast
|
||||||
|
|
||||||
from openai.types.chat import ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -161,8 +161,7 @@ class ChatCompletionContext(BaseModel):
|
||||||
temperature: float | None
|
temperature: float | None
|
||||||
response_format: OpenAIResponseFormatParam
|
response_format: OpenAIResponseFormatParam
|
||||||
tool_context: ToolContext | None
|
tool_context: ToolContext | None
|
||||||
responses_tool_choice: OpenAIResponseInputToolChoice | None = None
|
tool_choice: OpenAIResponseInputToolChoice | None = None
|
||||||
chat_tool_choice: str | dict[str, Any] | None = None
|
|
||||||
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
|
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
|
||||||
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
|
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
|
||||||
|
|
||||||
|
|
@ -175,7 +174,7 @@ class ChatCompletionContext(BaseModel):
|
||||||
response_format: OpenAIResponseFormatParam,
|
response_format: OpenAIResponseFormatParam,
|
||||||
tool_context: ToolContext,
|
tool_context: ToolContext,
|
||||||
inputs: list[OpenAIResponseInput] | str,
|
inputs: list[OpenAIResponseInput] | str,
|
||||||
responses_tool_choice: OpenAIResponseInputToolChoice | None = None,
|
tool_choice: OpenAIResponseInputToolChoice | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -184,7 +183,7 @@ class ChatCompletionContext(BaseModel):
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
tool_context=tool_context,
|
tool_context=tool_context,
|
||||||
responses_tool_choice=responses_tool_choice,
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
if not isinstance(inputs, str):
|
if not isinstance(inputs, str):
|
||||||
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
|
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
|
||||||
|
|
|
||||||
|
|
@ -577,6 +577,11 @@ class OpenAIChatCompletionToolChoiceFunctionTool(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class CustomToolConfig(BaseModel):
|
class CustomToolConfig(BaseModel):
|
||||||
|
"""Custom tool configuration for OpenAI-compatible chat completion requests.
|
||||||
|
|
||||||
|
:param name: Name of the custom tool
|
||||||
|
"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -565,12 +565,19 @@ class OpenAIResponseInputToolChoiceFileSearch(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIResponseInputToolChoiceWebSearch(OpenAIResponseInputToolWebSearch):
|
class OpenAIResponseInputToolChoiceWebSearch(BaseModel):
|
||||||
"""Indicates that the model should use web search to generate a response.
|
"""Indicates that the model should use web search to generate a response
|
||||||
|
|
||||||
This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
|
:param type: Web search tool type variant to use
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
type: (
|
||||||
|
Literal["web_search"]
|
||||||
|
| Literal["web_search_preview"]
|
||||||
|
| Literal["web_search_preview_2025_03_11"]
|
||||||
|
| Literal["web_search_2025_08_26"]
|
||||||
|
) = "web_search"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIResponseInputToolChoiceFunctionTool(BaseModel):
|
class OpenAIResponseInputToolChoiceFunctionTool(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,22 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
|
||||||
|
_process_tool_choice,
|
||||||
|
)
|
||||||
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
|
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
|
||||||
from llama_stack_api.openai_responses import (
|
from llama_stack_api import (
|
||||||
MCPListToolsTool,
|
MCPListToolsTool,
|
||||||
|
OpenAIChatCompletionToolChoiceAllowedTools,
|
||||||
|
OpenAIChatCompletionToolChoiceCustomTool,
|
||||||
|
OpenAIChatCompletionToolChoiceFunctionTool,
|
||||||
|
OpenAIResponseInputToolChoiceAllowedTools,
|
||||||
|
OpenAIResponseInputToolChoiceCustomTool,
|
||||||
|
OpenAIResponseInputToolChoiceFileSearch,
|
||||||
|
OpenAIResponseInputToolChoiceFunctionTool,
|
||||||
|
OpenAIResponseInputToolChoiceMCPTool,
|
||||||
|
OpenAIResponseInputToolChoiceMode,
|
||||||
|
OpenAIResponseInputToolChoiceWebSearch,
|
||||||
OpenAIResponseInputToolFileSearch,
|
OpenAIResponseInputToolFileSearch,
|
||||||
OpenAIResponseInputToolFunction,
|
OpenAIResponseInputToolFunction,
|
||||||
OpenAIResponseInputToolMCP,
|
OpenAIResponseInputToolMCP,
|
||||||
|
|
@ -181,3 +194,326 @@ class TestToolContext:
|
||||||
assert len(context.previous_tool_listings) == 1
|
assert len(context.previous_tool_listings) == 1
|
||||||
assert len(context.previous_tool_listings[0].tools) == 1
|
assert len(context.previous_tool_listings[0].tools) == 1
|
||||||
assert context.previous_tool_listings[0].server_label == "anotherlabel"
|
assert context.previous_tool_listings[0].server_label == "anotherlabel"
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessToolChoice:
|
||||||
|
"""Comprehensive test suite for _process_tool_choice function."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up common test fixtures."""
|
||||||
|
self.chat_tools = [
|
||||||
|
{"type": "function", "function": {"name": "get_weather"}},
|
||||||
|
{"type": "function", "function": {"name": "calculate"}},
|
||||||
|
{"type": "function", "function": {"name": "file_search"}},
|
||||||
|
{"type": "function", "function": {"name": "web_search"}},
|
||||||
|
]
|
||||||
|
self.server_label_to_tools = {
|
||||||
|
"mcp_server_1": ["mcp_tool_1", "mcp_tool_2"],
|
||||||
|
"mcp_server_2": ["mcp_tool_3"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def test_mode_auto(self):
|
||||||
|
"""Test auto mode - should return 'auto' string."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceMode.auto
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
assert result == "auto"
|
||||||
|
|
||||||
|
async def test_mode_none(self):
|
||||||
|
"""Test none mode - should return 'none' string."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceMode.none
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
assert result == "none"
|
||||||
|
|
||||||
|
async def test_mode_required_with_tools(self):
|
||||||
|
"""Test required mode with available tools - should return AllowedTools with all function tools."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceMode.required
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
assert result.allowed_tools.mode == "required"
|
||||||
|
assert len(result.allowed_tools.tools) == 4
|
||||||
|
tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
|
||||||
|
assert "get_weather" in tool_names
|
||||||
|
assert "calculate" in tool_names
|
||||||
|
assert "file_search" in tool_names
|
||||||
|
assert "web_search" in tool_names
|
||||||
|
|
||||||
|
async def test_mode_required_without_tools(self):
|
||||||
|
"""Test required mode without available tools - should return None."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceMode.required
|
||||||
|
result = await _process_tool_choice([], tool_choice, self.server_label_to_tools)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_allowed_tools_function(self):
|
||||||
|
"""Test allowed_tools with function tool types."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||||
|
mode="required",
|
||||||
|
tools=[
|
||||||
|
{"type": "function", "name": "get_weather"},
|
||||||
|
{"type": "function", "name": "calculate"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
assert result.allowed_tools.mode == "required"
|
||||||
|
assert len(result.allowed_tools.tools) == 2
|
||||||
|
assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather"
|
||||||
|
assert result.allowed_tools.tools[1]["function"]["name"] == "calculate"
|
||||||
|
|
||||||
|
async def test_allowed_tools_custom(self):
|
||||||
|
"""Test allowed_tools with custom tool types."""
|
||||||
|
chat_tools = [{"type": "function", "function": {"name": "custom_tool_1"}}]
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||||
|
mode="auto",
|
||||||
|
tools=[{"type": "custom", "name": "custom_tool_1"}],
|
||||||
|
)
|
||||||
|
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
assert result.allowed_tools.mode == "auto"
|
||||||
|
assert len(result.allowed_tools.tools) == 1
|
||||||
|
assert result.allowed_tools.tools[0]["type"] == "custom"
|
||||||
|
assert result.allowed_tools.tools[0]["custom"]["name"] == "custom_tool_1"
|
||||||
|
|
||||||
|
async def test_allowed_tools_file_search(self):
|
||||||
|
"""Test allowed_tools with file_search."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||||
|
mode="required",
|
||||||
|
tools=[{"type": "file_search"}],
|
||||||
|
)
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
assert len(result.allowed_tools.tools) == 1
|
||||||
|
assert result.allowed_tools.tools[0]["function"]["name"] == "file_search"
|
||||||
|
|
||||||
|
async def test_allowed_tools_web_search(self):
|
||||||
|
"""Test allowed_tools with web_search."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||||
|
mode="required",
|
||||||
|
tools=[
|
||||||
|
{"type": "web_search_preview_2025_03_11"},
|
||||||
|
{"type": "web_search_2025_08_26"},
|
||||||
|
{"type": "web_search_preview"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
assert len(result.allowed_tools.tools) == 3
|
||||||
|
assert result.allowed_tools.tools[0]["function"]["name"] == "web_search"
|
||||||
|
assert result.allowed_tools.tools[0]["type"] == "function"
|
||||||
|
assert result.allowed_tools.tools[1]["function"]["name"] == "web_search"
|
||||||
|
assert result.allowed_tools.tools[1]["type"] == "function"
|
||||||
|
assert result.allowed_tools.tools[2]["function"]["name"] == "web_search"
|
||||||
|
assert result.allowed_tools.tools[2]["type"] == "function"
|
||||||
|
|
||||||
|
async def test_allowed_tools_mcp_server_label(self):
|
||||||
|
"""Test allowed_tools with MCP server label (no specific tool name)."""
|
||||||
|
chat_tools = [
|
||||||
|
{"type": "function", "function": {"name": "mcp_tool_1"}},
|
||||||
|
{"type": "function", "function": {"name": "mcp_tool_2"}},
|
||||||
|
]
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||||
|
mode="required",
|
||||||
|
tools=[{"type": "mcp", "server_label": "mcp_server_1"}],
|
||||||
|
)
|
||||||
|
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
assert len(result.allowed_tools.tools) == 2
|
||||||
|
tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
|
||||||
|
assert "mcp_tool_1" in tool_names
|
||||||
|
assert "mcp_tool_2" in tool_names
|
||||||
|
|
||||||
|
async def test_allowed_tools_mixed_types(self):
|
||||||
|
"""Test allowed_tools with mixed tool types."""
|
||||||
|
chat_tools = [
|
||||||
|
{"type": "function", "function": {"name": "get_weather"}},
|
||||||
|
{"type": "function", "function": {"name": "file_search"}},
|
||||||
|
{"type": "function", "function": {"name": "web_search"}},
|
||||||
|
{"type": "function", "function": {"name": "mcp_tool_1"}},
|
||||||
|
]
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||||
|
mode="auto",
|
||||||
|
tools=[
|
||||||
|
{"type": "function", "name": "get_weather"},
|
||||||
|
{"type": "file_search"},
|
||||||
|
{"type": "web_search"},
|
||||||
|
{"type": "mcp", "server_label": "mcp_server_1"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
# Should have: get_weather, file_search, web_search, mcp_tool_1, mcp_tool_2
|
||||||
|
assert len(result.allowed_tools.tools) >= 3
|
||||||
|
|
||||||
|
async def test_allowed_tools_invalid_type(self):
|
||||||
|
"""Test allowed_tools with an unsupported tool type - should skip it."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||||
|
mode="required",
|
||||||
|
tools=[
|
||||||
|
{"type": "function", "name": "get_weather"},
|
||||||
|
{"type": "unsupported_type", "name": "bad_tool"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
# Should only include the valid function tool
|
||||||
|
assert len(result.allowed_tools.tools) == 1
|
||||||
|
assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather"
|
||||||
|
|
||||||
|
async def test_specific_custom_tool_valid(self):
|
||||||
|
"""Test specific custom tool choice when tool exists."""
|
||||||
|
chat_tools = [{"type": "function", "function": {"name": "custom_tool"}}]
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="custom_tool")
|
||||||
|
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceCustomTool)
|
||||||
|
assert result.custom.name == "custom_tool"
|
||||||
|
|
||||||
|
async def test_specific_custom_tool_invalid(self):
|
||||||
|
"""Test specific custom tool choice when tool doesn't exist - should return None."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="nonexistent_tool")
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_specific_function_tool_valid(self):
|
||||||
|
"""Test specific function tool choice when tool exists."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather")
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
|
||||||
|
assert result.function.name == "get_weather"
|
||||||
|
|
||||||
|
async def test_specific_function_tool_invalid(self):
|
||||||
|
"""Test specific function tool choice when tool doesn't exist - should return None."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="nonexistent_function")
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_specific_file_search_valid(self):
|
||||||
|
"""Test file_search tool choice when available."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceFileSearch()
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
|
||||||
|
assert result.function.name == "file_search"
|
||||||
|
|
||||||
|
async def test_specific_file_search_invalid(self):
|
||||||
|
"""Test file_search tool choice when not available - should return None."""
|
||||||
|
chat_tools = [{"type": "function", "function": {"name": "get_weather"}}]
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceFileSearch()
|
||||||
|
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_specific_web_search_valid(self):
|
||||||
|
"""Test web_search tool choice when available."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceWebSearch()
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
|
||||||
|
assert result.function.name == "web_search"
|
||||||
|
|
||||||
|
async def test_specific_web_search_invalid(self):
|
||||||
|
"""Test web_search tool choice when not available - should return None."""
|
||||||
|
chat_tools = [{"type": "function", "function": {"name": "get_weather"}}]
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceWebSearch()
|
||||||
|
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_specific_mcp_tool_with_name(self):
|
||||||
|
"""Test MCP tool choice with specific tool name."""
|
||||||
|
chat_tools = [{"type": "function", "function": {"name": "mcp_tool_1"}}]
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceMCPTool(
|
||||||
|
server_label="mcp_server_1",
|
||||||
|
name="mcp_tool_1",
|
||||||
|
)
|
||||||
|
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
|
||||||
|
assert result.function.name == "mcp_tool_1"
|
||||||
|
|
||||||
|
async def test_specific_mcp_tool_with_name_not_in_chat_tools(self):
|
||||||
|
"""Test MCP tool choice with specific tool name that doesn't exist in chat_tools."""
|
||||||
|
chat_tools = [{"type": "function", "function": {"name": "other_tool"}}]
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceMCPTool(
|
||||||
|
server_label="mcp_server_1",
|
||||||
|
name="mcp_tool_1",
|
||||||
|
)
|
||||||
|
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_specific_mcp_tool_server_label_only(self):
|
||||||
|
"""Test MCP tool choice with only server label (no specific tool name)."""
|
||||||
|
chat_tools = [
|
||||||
|
{"type": "function", "function": {"name": "mcp_tool_1"}},
|
||||||
|
{"type": "function", "function": {"name": "mcp_tool_2"}},
|
||||||
|
]
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1")
|
||||||
|
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
assert result.allowed_tools.mode == "required"
|
||||||
|
assert len(result.allowed_tools.tools) == 2
|
||||||
|
tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
|
||||||
|
assert "mcp_tool_1" in tool_names
|
||||||
|
assert "mcp_tool_2" in tool_names
|
||||||
|
|
||||||
|
async def test_specific_mcp_tool_unknown_server(self):
|
||||||
|
"""Test MCP tool choice with unknown server label."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceMCPTool(
|
||||||
|
server_label="unknown_server",
|
||||||
|
name="some_tool",
|
||||||
|
)
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
# Should return None because server not found
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_empty_chat_tools(self):
|
||||||
|
"""Test with empty chat_tools list."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather")
|
||||||
|
result = await _process_tool_choice([], tool_choice, self.server_label_to_tools)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_empty_server_label_to_tools(self):
|
||||||
|
"""Test with empty server_label_to_tools mapping."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1")
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, {})
|
||||||
|
# Should handle gracefully
|
||||||
|
assert result is None or isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
|
||||||
|
async def test_allowed_tools_empty_list(self):
|
||||||
|
"""Test allowed_tools with empty tools list."""
|
||||||
|
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(mode="auto", tools=[])
|
||||||
|
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||||
|
|
||||||
|
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
assert len(result.allowed_tools.tools) == 0
|
||||||
|
|
||||||
|
async def test_mcp_tool_multiple_servers(self):
|
||||||
|
"""Test MCP tool choice with multiple server labels."""
|
||||||
|
chat_tools = [
|
||||||
|
{"type": "function", "function": {"name": "mcp_tool_1"}},
|
||||||
|
{"type": "function", "function": {"name": "mcp_tool_2"}},
|
||||||
|
{"type": "function", "function": {"name": "mcp_tool_3"}},
|
||||||
|
]
|
||||||
|
server_label_to_tools = {
|
||||||
|
"server_a": ["mcp_tool_1"],
|
||||||
|
"server_b": ["mcp_tool_2", "mcp_tool_3"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test server_a
|
||||||
|
tool_choice_a = OpenAIResponseInputToolChoiceMCPTool(server_label="server_a")
|
||||||
|
result_a = await _process_tool_choice(chat_tools, tool_choice_a, server_label_to_tools)
|
||||||
|
assert isinstance(result_a, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
assert len(result_a.allowed_tools.tools) == 1
|
||||||
|
|
||||||
|
# Test server_b
|
||||||
|
tool_choice_b = OpenAIResponseInputToolChoiceMCPTool(server_label="server_b")
|
||||||
|
result_b = await _process_tool_choice(chat_tools, tool_choice_b, server_label_to_tools)
|
||||||
|
assert isinstance(result_b, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||||
|
assert len(result_b.allowed_tools.tools) == 2
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue