Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)
feat(tools)!: substantial clean up of "Tool" related datatypes (#3627)
This is a sweeping change that cleans up some gunk around our "Tool" definitions.

First, we had two types, `Tool` and `ToolDef`. The former was a "Resource" type for the registry, but we stopped registering tools in the registry long ago (we only register ToolGroups); the latter was used to specify tools for the Agents API. This PR removes the former and adds an optional `toolgroup_id` field to the latter.

Secondly, as pointed out by @bbrowning in https://github.com/llamastack/llama-stack/pull/3003#issuecomment-3245270132, we were doing a lossy conversion from the full JSON schema in the MCP tool specification into our ToolDefinition before sending it to the model. There is no need to do this -- we aren't executing anything ourselves, we merely pass the tools through to the chat completions API, which supports full schemas. Because of that conversion (and because we did it poorly), we ran into limitations such as not supporting array items and not resolving $refs. To fix this, we replaced the `parameters` field with `{ input_schema, output_schema }`, both of which can be full-blown JSON Schemas.

Finally, some types in our llama-related chat format conversion needed cleanup, and we took this opportunity to clean those up as well.

This PR is a substantial breaking change to the API. However, given our window for introducing breaking changes, that suits us just fine. I will land a concurrent `llama-stack-client` change as well, since the API shapes are changing.
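To make the new shapes concrete, here is a rough sketch (illustrative only: the tool name, toolgroup id, and schema contents below are made up and not taken from this diff; only the field names reflect the new API):

```python
from llama_stack.apis.tools import ToolDef
from llama_stack.models.llama.datatypes import ToolCall  # import path assumed

# Hypothetical tool; previously this would have been expressed with
# parameters=[ToolParameter(...)] and gone through a lossy schema conversion.
weather_tool = ToolDef(
    name="get_weather",
    toolgroup_id="mcp::weather",  # optional, new field; made-up toolgroup id
    description="Get weather information",
    input_schema={
        # A full JSON Schema, passed through to the chat completions API untouched
        "type": "object",
        "properties": {"location": {"type": "string", "description": "City name"}},
        "required": ["location"],
    },
    output_schema={
        # Also a full JSON Schema; providers that cannot express it may drop it
        "type": "object",
        "properties": {"temperature": {"type": "number"}},
    },
)

# Per the updated tests, ToolCall.arguments now carries the raw JSON string
# produced by the model rather than a parsed dict.
call = ToolCall(call_id="123", tool_name="get_weather", arguments='{"location": "Sligo"}')
```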
parent 1f5003d50e
commit ef0736527d
179 changed files with 34186 additions and 9171 deletions
@@ -16,7 +16,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS
 from llama_stack.apis.datatypes import Api
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.shields.shields import Shield
-from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
+from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.core.datatypes import RegistryEntrySource
 from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable

@@ -137,7 +137,10 @@ class ToolGroupsImpl(Impl):
                 ToolDef(
                     name="test-tool",
                     description="Test tool",
-                    parameters=[ToolParameter(name="test-param", description="Test param", parameter_type="string")],
+                    input_schema={
+                        "type": "object",
+                        "properties": {"test-param": {"type": "string", "description": "Test param"}},
+                    },
                 )
             ]
         )
@@ -18,7 +18,6 @@ from llama_stack.apis.inference import (
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     ToolDefinition,
-    ToolParamDefinition,
     ToolPromptFormat,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (

@@ -75,12 +74,15 @@ async def test_system_custom_only():
             ToolDefinition(
                 tool_name="custom1",
                 description="custom1 tool",
-                parameters={
-                    "param1": ToolParamDefinition(
-                        param_type="str",
-                        description="param1 description",
-                        required=True,
-                    ),
+                input_schema={
+                    "type": "object",
+                    "properties": {
+                        "param1": {
+                            "type": "str",
+                            "description": "param1 description",
+                        },
+                    },
+                    "required": ["param1"],
                 },
             )
         ],

@@ -107,12 +109,15 @@ async def test_system_custom_and_builtin():
             ToolDefinition(
                 tool_name="custom1",
                 description="custom1 tool",
-                parameters={
-                    "param1": ToolParamDefinition(
-                        param_type="str",
-                        description="param1 description",
-                        required=True,
-                    ),
+                input_schema={
+                    "type": "object",
+                    "properties": {
+                        "param1": {
+                            "type": "str",
+                            "description": "param1 description",
+                        },
+                    },
+                    "required": ["param1"],
                 },
             ),
         ],

@@ -138,7 +143,7 @@ async def test_completion_message_encoding():
         tool_calls=[
             ToolCall(
                 tool_name="custom1",
-                arguments={"param1": "value1"},
+                arguments='{"param1": "value1"}',  # arguments must be a JSON string
                 call_id="123",
             )
         ],

@@ -148,12 +153,15 @@ async def test_completion_message_encoding():
             ToolDefinition(
                 tool_name="custom1",
                 description="custom1 tool",
-                parameters={
-                    "param1": ToolParamDefinition(
-                        param_type="str",
-                        description="param1 description",
-                        required=True,
-                    ),
+                input_schema={
+                    "type": "object",
+                    "properties": {
+                        "param1": {
+                            "type": "str",
+                            "description": "param1 description",
+                        },
+                    },
+                    "required": ["param1"],
                 },
             ),
         ],

@@ -227,12 +235,15 @@ async def test_replace_system_message_behavior_custom_tools():
             ToolDefinition(
                 tool_name="custom1",
                 description="custom1 tool",
-                parameters={
-                    "param1": ToolParamDefinition(
-                        param_type="str",
-                        description="param1 description",
-                        required=True,
-                    ),
+                input_schema={
+                    "type": "object",
+                    "properties": {
+                        "param1": {
+                            "type": "str",
+                            "description": "param1 description",
+                        },
+                    },
+                    "required": ["param1"],
                 },
             ),
         ],

@@ -264,12 +275,15 @@ async def test_replace_system_message_behavior_custom_tools_with_template():
             ToolDefinition(
                 tool_name="custom1",
                 description="custom1 tool",
-                parameters={
-                    "param1": ToolParamDefinition(
-                        param_type="str",
-                        description="param1 description",
-                        required=True,
-                    ),
+                input_schema={
+                    "type": "object",
+                    "properties": {
+                        "param1": {
+                            "type": "str",
+                            "description": "param1 description",
+                        },
+                    },
+                    "required": ["param1"],
                 },
             ),
         ],
@@ -16,9 +16,8 @@ from llama_stack.apis.agents import (
 )
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.inference import Inference
-from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import ListToolsResponse, Tool, ToolGroups, ToolParameter, ToolRuntime
+from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
 from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl

@@ -232,32 +231,26 @@ async def test_delete_agent(agents_impl, sample_agent_config):
 
 async def test__initialize_tools(agents_impl, sample_agent_config):
     # Mock tool_groups_api.list_tools()
-    agents_impl.tool_groups_api.list_tools.return_value = ListToolsResponse(
+    agents_impl.tool_groups_api.list_tools.return_value = ListToolDefsResponse(
         data=[
-            Tool(
-                identifier="story_maker",
-                provider_id="model-context-protocol",
-                type=ResourceType.tool,
+            ToolDef(
+                name="story_maker",
+                toolgroup_id="mcp::my_mcp_server",
                 description="Make a story",
-                parameters=[
-                    ToolParameter(
-                        name="story_title",
-                        parameter_type="string",
-                        description="Title of the story",
-                        required=True,
-                        title="Story Title",
-                    ),
-                    ToolParameter(
-                        name="input_words",
-                        parameter_type="array",
-                        description="Input words",
-                        required=False,
-                        items={"type": "string"},
-                        title="Input Words",
-                        default=[],
-                    ),
-                ],
+                input_schema={
+                    "type": "object",
+                    "properties": {
+                        "story_title": {"type": "string", "description": "Title of the story", "title": "Story Title"},
+                        "input_words": {
+                            "type": "array",
+                            "description": "Input words",
+                            "items": {"type": "string"},
+                            "title": "Input Words",
+                            "default": [],
+                        },
+                    },
+                    "required": ["story_title"],
+                },
             )
         ]
     )

@@ -284,27 +277,27 @@ async def test__initialize_tools(agents_impl, sample_agent_config):
     assert second_tool.tool_name == "story_maker"
     assert second_tool.description == "Make a story"
 
-    parameters = second_tool.parameters
-    assert len(parameters) == 2
+    # Verify the input schema
+    input_schema = second_tool.input_schema
+    assert input_schema is not None
+    assert input_schema["type"] == "object"
+
+    properties = input_schema["properties"]
+    assert len(properties) == 2
 
     # Verify a string property
-    story_title = parameters.get("story_title")
-    assert story_title is not None
-    assert story_title.param_type == "string"
-    assert story_title.description == "Title of the story"
-    assert story_title.required
-    assert story_title.items is None
-    assert story_title.title == "Story Title"
-    assert story_title.default is None
+    story_title = properties["story_title"]
+    assert story_title["type"] == "string"
+    assert story_title["description"] == "Title of the story"
+    assert story_title["title"] == "Story Title"
 
     # Verify an array property
-    input_words = parameters.get("input_words")
-    assert input_words is not None
-    assert input_words.param_type == "array"
-    assert input_words.description == "Input words"
-    assert not input_words.required
-    assert input_words.items is not None
-    assert len(input_words.items) == 1
-    assert input_words.items.get("type") == "string"
-    assert input_words.title == "Input Words"
-    assert input_words.default == []
+    input_words = properties["input_words"]
+    assert input_words["type"] == "array"
+    assert input_words["description"] == "Input words"
+    assert input_words["items"]["type"] == "string"
+    assert input_words["title"] == "Input Words"
+    assert input_words["default"] == []
 
+    # Verify required fields
+    assert input_schema["required"] == ["story_title"]
@@ -39,7 +39,7 @@ from llama_stack.apis.inference import (
     OpenAIResponseFormatJSONSchema,
     OpenAIUserMessageParam,
 )
-from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
+from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.core.access_control.access_control import default_policy
 from llama_stack.core.datatypes import ResponsesStoreConfig
 from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (

@@ -186,14 +186,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
     input_text = "What is the capital of Ireland?"
     model = "meta-llama/Llama-3.1-8B-Instruct"
 
-    openai_responses_impl.tool_groups_api.get_tool.return_value = Tool(
-        identifier="web_search",
-        provider_id="client",
+    openai_responses_impl.tool_groups_api.get_tool.return_value = ToolDef(
+        name="web_search",
+        toolgroup_id="web_search",
         description="Search the web for information",
-        parameters=[
-            ToolParameter(name="query", parameter_type="string", description="The query to search for", required=True)
-        ],
+        input_schema={
+            "type": "object",
+            "properties": {"query": {"type": "string", "description": "The query to search for"}},
+            "required": ["query"],
+        },
     )
 
     openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(
@@ -138,8 +138,7 @@ async def test_tool_call_response(vllm_inference_adapter):
                 ToolCall(
                     call_id="foo",
                     tool_name="knowledge_search",
-                    arguments={"query": "How many?"},
-                    arguments_json='{"query": "How many?"}',
+                    arguments='{"query": "How many?"}',
                 )
             ],
         ),

@@ -263,7 +262,7 @@ async def test_tool_call_delta_streaming_arguments_dict():
     assert chunks[1].event.event_type.value == "progress"
     assert chunks[1].event.delta.type == "tool_call"
     assert chunks[1].event.delta.parse_status.value == "succeeded"
-    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
     assert chunks[2].event.event_type.value == "complete"


@@ -339,11 +338,11 @@ async def test_multiple_tool_calls():
     assert chunks[1].event.event_type.value == "progress"
     assert chunks[1].event.delta.type == "tool_call"
     assert chunks[1].event.delta.parse_status.value == "succeeded"
-    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
     assert chunks[2].event.event_type.value == "progress"
     assert chunks[2].event.delta.type == "tool_call"
     assert chunks[2].event.delta.parse_status.value == "succeeded"
-    assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
+    assert chunks[2].event.delta.tool_call.arguments == '{"first_number": 4, "second_number": 7}'
    assert chunks[3].event.event_type.value == "complete"


@@ -456,7 +455,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
-    assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
+    assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str


 async def test_process_vllm_chat_completion_stream_response_no_finish_reason():

@@ -468,7 +467,7 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
 
     mock_tool_name = "mock_tool"
     mock_tool_arguments = {"arg1": 0, "arg2": 100}
-    mock_tool_arguments_str = '"{\\"arg1\\": 0, \\"arg2\\": 100}"'
+    mock_tool_arguments_str = json.dumps(mock_tool_arguments)
 
     async def mock_stream():
         mock_chunks = [

@@ -508,7 +507,7 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
-    assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
+    assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str


 async def test_process_vllm_chat_completion_stream_response_tool_without_args():

@@ -556,7 +555,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
-    assert chunks[-2].event.delta.tool_call.arguments == {}
+    assert chunks[-2].event.delta.tool_call.arguments == "{}"


 async def test_health_status_success(vllm_inference_adapter):
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.apis.tools import ToolDef, ToolParameter
+from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
     convert_tooldef_to_chat_tool,
 )

@@ -20,15 +20,11 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
     tool_def = ToolDef(
         name="test_tool",
         description="A test tool with array parameter",
-        parameters=[
-            ToolParameter(
-                name="tags",
-                parameter_type="array",
-                description="List of tags",
-                required=True,
-                items={"type": "string"},
-            )
-        ],
+        input_schema={
+            "type": "object",
+            "properties": {"tags": {"type": "array", "description": "List of tags", "items": {"type": "string"}}},
+            "required": ["tags"],
+        },
     )
 
     result = convert_tooldef_to_chat_tool(tool_def)
@@ -41,9 +41,7 @@ async def test_convert_message_to_openai_dict():
 async def test_convert_message_to_openai_dict_with_tool_call():
     message = CompletionMessage(
         content="",
-        tool_calls=[
-            ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"})
-        ],
+        tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
         stop_reason=StopReason.end_of_turn,
     )
 

@@ -65,8 +63,7 @@ async def test_convert_message_to_openai_dict_with_builtin_tool_call():
             ToolCall(
                 call_id="123",
                 tool_name=BuiltinTool.brave_search,
-                arguments_json='{"foo": "bar"}',
-                arguments={"foo": "bar"},
+                arguments='{"foo": "bar"}',
             )
         ],
         stop_reason=StopReason.end_of_turn,

@@ -202,8 +199,7 @@ async def test_convert_message_to_openai_dict_new_completion_message_with_tool_c
             ToolCall(
                 call_id="call_123",
                 tool_name="get_weather",
-                arguments={"city": "Sligo"},
-                arguments_json='{"city": "Sligo"}',
+                arguments='{"city": "Sligo"}',
             )
         ],
         stop_reason=StopReason.end_of_turn,
tests/unit/providers/utils/test_openai_compat_conversion.py (new file, 381 lines)
@@ -0,0 +1,381 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Unit tests for OpenAI compatibility tool conversion.

Tests convert_tooldef_to_openai_tool with new JSON Schema approach.
"""

from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool


class TestSimpleSchemaConversion:
    """Test basic schema conversions to OpenAI format."""

    def test_simple_tool_conversion(self):
        """Test conversion of simple tool with basic input schema."""
        tool = ToolDefinition(
            tool_name="get_weather",
            description="Get weather information",
            input_schema={
                "type": "object",
                "properties": {"location": {"type": "string", "description": "City name"}},
                "required": ["location"],
            },
        )

        result = convert_tooldef_to_openai_tool(tool)

        # Check OpenAI structure
        assert result["type"] == "function"
        assert "function" in result

        function = result["function"]
        assert function["name"] == "get_weather"
        assert function["description"] == "Get weather information"

        # Check parameters are passed through
        assert "parameters" in function
        assert function["parameters"] == tool.input_schema
        assert function["parameters"]["type"] == "object"
        assert "location" in function["parameters"]["properties"]

    def test_tool_without_description(self):
        """Test tool conversion without description."""
        tool = ToolDefinition(tool_name="test_tool", input_schema={"type": "object", "properties": {}})

        result = convert_tooldef_to_openai_tool(tool)

        assert result["function"]["name"] == "test_tool"
        assert "description" not in result["function"]
        assert "parameters" in result["function"]

    def test_builtin_tool_conversion(self):
        """Test conversion of BuiltinTool enum."""
        tool = ToolDefinition(
            tool_name=BuiltinTool.code_interpreter,
            description="Run Python code",
            input_schema={"type": "object", "properties": {"code": {"type": "string"}}},
        )

        result = convert_tooldef_to_openai_tool(tool)

        # BuiltinTool should be converted to its value
        assert result["function"]["name"] == "code_interpreter"


class TestComplexSchemaConversion:
    """Test conversion of complex JSON Schema features."""

    def test_schema_with_refs_and_defs(self):
        """Test that $ref and $defs are passed through to OpenAI."""
        tool = ToolDefinition(
            tool_name="book_flight",
            description="Book a flight",
            input_schema={
                "type": "object",
                "properties": {
                    "flight": {"$ref": "#/$defs/FlightInfo"},
                    "passengers": {"type": "array", "items": {"$ref": "#/$defs/Passenger"}},
                    "payment": {"$ref": "#/$defs/Payment"},
                },
                "required": ["flight", "passengers", "payment"],
                "$defs": {
                    "FlightInfo": {
                        "type": "object",
                        "properties": {
                            "from": {"type": "string", "description": "Departure airport"},
                            "to": {"type": "string", "description": "Arrival airport"},
                            "date": {"type": "string", "format": "date"},
                        },
                        "required": ["from", "to", "date"],
                    },
                    "Passenger": {
                        "type": "object",
                        "properties": {"name": {"type": "string"}, "age": {"type": "integer", "minimum": 0}},
                        "required": ["name", "age"],
                    },
                    "Payment": {
                        "type": "object",
                        "properties": {
                            "method": {"type": "string", "enum": ["credit_card", "debit_card"]},
                            "amount": {"type": "number", "minimum": 0},
                        },
                    },
                },
            },
        )

        result = convert_tooldef_to_openai_tool(tool)

        params = result["function"]["parameters"]

        # Verify $defs are preserved
        assert "$defs" in params
        assert "FlightInfo" in params["$defs"]
        assert "Passenger" in params["$defs"]
        assert "Payment" in params["$defs"]

        # Verify $ref are preserved
        assert params["properties"]["flight"]["$ref"] == "#/$defs/FlightInfo"
        assert params["properties"]["passengers"]["items"]["$ref"] == "#/$defs/Passenger"
        assert params["properties"]["payment"]["$ref"] == "#/$defs/Payment"

        # Verify nested schema details are preserved
        assert params["$defs"]["FlightInfo"]["properties"]["date"]["format"] == "date"
        assert params["$defs"]["Passenger"]["properties"]["age"]["minimum"] == 0
        assert params["$defs"]["Payment"]["properties"]["method"]["enum"] == ["credit_card", "debit_card"]

    def test_anyof_schema_conversion(self):
        """Test conversion of anyOf schemas."""
        tool = ToolDefinition(
            tool_name="flexible_input",
            input_schema={
                "type": "object",
                "properties": {
                    "contact": {
                        "anyOf": [
                            {"type": "string", "format": "email"},
                            {"type": "string", "pattern": "^\\+?[0-9]{10,15}$"},
                        ],
                        "description": "Email or phone number",
                    }
                },
            },
        )

        result = convert_tooldef_to_openai_tool(tool)

        contact_schema = result["function"]["parameters"]["properties"]["contact"]
        assert "anyOf" in contact_schema
        assert len(contact_schema["anyOf"]) == 2
        assert contact_schema["anyOf"][0]["format"] == "email"
        assert "pattern" in contact_schema["anyOf"][1]

    def test_nested_objects_conversion(self):
        """Test conversion of deeply nested objects."""
        tool = ToolDefinition(
            tool_name="nested_data",
            input_schema={
                "type": "object",
                "properties": {
                    "user": {
                        "type": "object",
                        "properties": {
                            "profile": {
                                "type": "object",
                                "properties": {
                                    "name": {"type": "string"},
                                    "settings": {
                                        "type": "object",
                                        "properties": {"theme": {"type": "string", "enum": ["light", "dark"]}},
                                    },
                                },
                            }
                        },
                    }
                },
            },
        )

        result = convert_tooldef_to_openai_tool(tool)

        # Navigate deep structure
        user_schema = result["function"]["parameters"]["properties"]["user"]
        profile_schema = user_schema["properties"]["profile"]
        settings_schema = profile_schema["properties"]["settings"]
        theme_schema = settings_schema["properties"]["theme"]

        assert theme_schema["enum"] == ["light", "dark"]

    def test_array_schemas_with_constraints(self):
        """Test conversion of array schemas with constraints."""
        tool = ToolDefinition(
            tool_name="list_processor",
            input_schema={
                "type": "object",
                "properties": {
                    "items": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {"id": {"type": "integer"}, "name": {"type": "string"}},
                            "required": ["id"],
                        },
                        "minItems": 1,
                        "maxItems": 100,
                        "uniqueItems": True,
                    }
                },
            },
        )

        result = convert_tooldef_to_openai_tool(tool)

        items_schema = result["function"]["parameters"]["properties"]["items"]
        assert items_schema["type"] == "array"
        assert items_schema["minItems"] == 1
        assert items_schema["maxItems"] == 100
        assert items_schema["uniqueItems"] is True
        assert items_schema["items"]["type"] == "object"


class TestOutputSchemaHandling:
    """Test that output_schema is correctly handled (or dropped) for OpenAI."""

    def test_output_schema_is_dropped(self):
        """Test that output_schema is NOT included in OpenAI format (API limitation)."""
        tool = ToolDefinition(
            tool_name="calculator",
            description="Perform calculation",
            input_schema={"type": "object", "properties": {"x": {"type": "number"}, "y": {"type": "number"}}},
            output_schema={"type": "object", "properties": {"result": {"type": "number"}}, "required": ["result"]},
        )

        result = convert_tooldef_to_openai_tool(tool)

        # OpenAI doesn't support output schema
        assert "outputSchema" not in result["function"]
        assert "responseSchema" not in result["function"]
        assert "output_schema" not in result["function"]

        # But input schema should be present
        assert "parameters" in result["function"]
        assert result["function"]["parameters"] == tool.input_schema

    def test_only_output_schema_no_input(self):
        """Test tool with only output_schema (unusual but valid)."""
        tool = ToolDefinition(
            tool_name="no_input_tool",
            description="Tool with no inputs",
            output_schema={"type": "object", "properties": {"timestamp": {"type": "string"}}},
        )

        result = convert_tooldef_to_openai_tool(tool)

        # No parameters should be set if input_schema is None
        # (or we might set an empty object schema - implementation detail)
        assert "outputSchema" not in result["function"]


class TestEdgeCases:
    """Test edge cases and error conditions."""

    def test_tool_with_no_schemas(self):
        """Test tool with neither input nor output schema."""
        tool = ToolDefinition(tool_name="schemaless_tool", description="Tool without schemas")

        result = convert_tooldef_to_openai_tool(tool)

        assert result["function"]["name"] == "schemaless_tool"
        assert result["function"]["description"] == "Tool without schemas"
        # Implementation detail: might have no parameters or empty object

    def test_empty_input_schema(self):
        """Test tool with empty object schema."""
        tool = ToolDefinition(tool_name="no_params", input_schema={"type": "object", "properties": {}})

        result = convert_tooldef_to_openai_tool(tool)

        assert result["function"]["parameters"]["type"] == "object"
        assert result["function"]["parameters"]["properties"] == {}

    def test_schema_with_additional_properties(self):
        """Test that additionalProperties is preserved."""
        tool = ToolDefinition(
            tool_name="flexible_tool",
            input_schema={
                "type": "object",
                "properties": {"known_field": {"type": "string"}},
                "additionalProperties": True,
            },
        )

        result = convert_tooldef_to_openai_tool(tool)

        assert result["function"]["parameters"]["additionalProperties"] is True

    def test_schema_with_pattern_properties(self):
        """Test that patternProperties is preserved."""
        tool = ToolDefinition(
            tool_name="pattern_tool",
            input_schema={"type": "object", "patternProperties": {"^[a-z]+$": {"type": "string"}}},
        )

        result = convert_tooldef_to_openai_tool(tool)

        assert "patternProperties" in result["function"]["parameters"]

    def test_schema_identity(self):
        """Test that converted schema is identical to input (no lossy conversion)."""
        original_schema = {
            "type": "object",
            "properties": {"complex": {"$ref": "#/$defs/Complex"}},
            "$defs": {
                "Complex": {
                    "type": "object",
                    "properties": {"nested": {"anyOf": [{"type": "string"}, {"type": "number"}]}},
                }
            },
            "required": ["complex"],
            "additionalProperties": False,
        }

        tool = ToolDefinition(tool_name="test", input_schema=original_schema)

        result = convert_tooldef_to_openai_tool(tool)

        # Converted parameters should be EXACTLY the same as input
        assert result["function"]["parameters"] == original_schema


class TestConversionConsistency:
    """Test consistency across multiple conversions."""

    def test_multiple_tools_with_shared_defs(self):
        """Test converting multiple tools that could share definitions."""
        tool1 = ToolDefinition(
            tool_name="tool1",
            input_schema={
                "type": "object",
                "properties": {"data": {"$ref": "#/$defs/Data"}},
                "$defs": {"Data": {"type": "object", "properties": {"x": {"type": "number"}}}},
            },
        )

        tool2 = ToolDefinition(
            tool_name="tool2",
            input_schema={
                "type": "object",
                "properties": {"info": {"$ref": "#/$defs/Data"}},
                "$defs": {"Data": {"type": "object", "properties": {"y": {"type": "string"}}}},
            },
        )

        result1 = convert_tooldef_to_openai_tool(tool1)
        result2 = convert_tooldef_to_openai_tool(tool2)

        # Each tool maintains its own $defs independently
        assert result1["function"]["parameters"]["$defs"]["Data"]["properties"]["x"]["type"] == "number"
        assert result2["function"]["parameters"]["$defs"]["Data"]["properties"]["y"]["type"] == "string"

    def test_conversion_is_pure(self):
        """Test that conversion doesn't modify the original tool."""
        original_schema = {
            "type": "object",
            "properties": {"x": {"type": "string"}},
            "$defs": {"T": {"type": "number"}},
        }

        tool = ToolDefinition(tool_name="test", input_schema=original_schema.copy())

        # Convert
        convert_tooldef_to_openai_tool(tool)

        # Original tool should be unchanged
        assert tool.input_schema == original_schema
        assert "$defs" in tool.input_schema
tests/unit/tools/test_tools_json_schema.py (new file, 297 lines)
@@ -0,0 +1,297 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Unit tests for JSON Schema-based tool definitions.

Tests the new input_schema and output_schema fields.
"""

from pydantic import ValidationError

from llama_stack.apis.tools import ToolDef
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition


class TestToolDefValidation:
    """Test ToolDef validation with JSON Schema."""

    def test_simple_input_schema(self):
        """Test ToolDef with simple input schema."""
        tool = ToolDef(
            name="get_weather",
            description="Get weather information",
            input_schema={
                "type": "object",
                "properties": {"location": {"type": "string", "description": "City name"}},
                "required": ["location"],
            },
        )

        assert tool.name == "get_weather"
        assert tool.input_schema["type"] == "object"
        assert "location" in tool.input_schema["properties"]
        assert tool.output_schema is None

    def test_input_and_output_schema(self):
        """Test ToolDef with both input and output schemas."""
        tool = ToolDef(
            name="calculate",
            description="Perform calculation",
            input_schema={
                "type": "object",
                "properties": {"x": {"type": "number"}, "y": {"type": "number"}},
                "required": ["x", "y"],
            },
            output_schema={"type": "object", "properties": {"result": {"type": "number"}}, "required": ["result"]},
        )

        assert tool.input_schema is not None
        assert tool.output_schema is not None
        assert "result" in tool.output_schema["properties"]

    def test_schema_with_refs_and_defs(self):
        """Test that $ref and $defs are preserved in schemas."""
        tool = ToolDef(
            name="book_flight",
            description="Book a flight",
            input_schema={
                "type": "object",
                "properties": {
                    "flight": {"$ref": "#/$defs/FlightInfo"},
                    "passengers": {"type": "array", "items": {"$ref": "#/$defs/Passenger"}},
                },
                "$defs": {
                    "FlightInfo": {
                        "type": "object",
                        "properties": {"from": {"type": "string"}, "to": {"type": "string"}},
                    },
                    "Passenger": {
                        "type": "object",
                        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
                    },
                },
            },
        )

        # Verify $defs are preserved
        assert "$defs" in tool.input_schema
        assert "FlightInfo" in tool.input_schema["$defs"]
        assert "Passenger" in tool.input_schema["$defs"]

        # Verify $ref are preserved
        assert tool.input_schema["properties"]["flight"]["$ref"] == "#/$defs/FlightInfo"
        assert tool.input_schema["properties"]["passengers"]["items"]["$ref"] == "#/$defs/Passenger"

    def test_output_schema_with_refs(self):
        """Test that output_schema also supports $ref and $defs."""
        tool = ToolDef(
            name="search",
            description="Search for items",
            input_schema={"type": "object", "properties": {"query": {"type": "string"}}},
            output_schema={
                "type": "object",
                "properties": {"results": {"type": "array", "items": {"$ref": "#/$defs/SearchResult"}}},
                "$defs": {
                    "SearchResult": {
                        "type": "object",
                        "properties": {"title": {"type": "string"}, "score": {"type": "number"}},
                    }
                },
            },
        )

        assert "$defs" in tool.output_schema
        assert "SearchResult" in tool.output_schema["$defs"]

    def test_complex_json_schema_features(self):
        """Test various JSON Schema features are preserved."""
        tool = ToolDef(
            name="complex_tool",
            description="Tool with complex schema",
            input_schema={
                "type": "object",
                "properties": {
                    # anyOf
                    "contact": {
                        "anyOf": [
                            {"type": "string", "format": "email"},
                            {"type": "string", "pattern": "^\\+?[0-9]{10,15}$"},
                        ]
                    },
                    # enum
                    "status": {"type": "string", "enum": ["pending", "approved", "rejected"]},
                    # nested objects
                    "address": {
                        "type": "object",
                        "properties": {
                            "street": {"type": "string"},
                            "city": {"type": "string"},
                            "zipcode": {"type": "string", "pattern": "^[0-9]{5}$"},
                        },
                        "required": ["street", "city"],
                    },
                    # array with constraints
                    "tags": {
                        "type": "array",
                        "items": {"type": "string"},
                        "minItems": 1,
                        "maxItems": 10,
                        "uniqueItems": True,
                    },
                },
            },
        )

        # Verify anyOf
        assert "anyOf" in tool.input_schema["properties"]["contact"]

        # Verify enum
        assert tool.input_schema["properties"]["status"]["enum"] == ["pending", "approved", "rejected"]

        # Verify nested object
        assert tool.input_schema["properties"]["address"]["type"] == "object"
        assert "zipcode" in tool.input_schema["properties"]["address"]["properties"]

        # Verify array constraints
        tags_schema = tool.input_schema["properties"]["tags"]
        assert tags_schema["minItems"] == 1
        assert tags_schema["maxItems"] == 10
        assert tags_schema["uniqueItems"] is True

    def test_invalid_json_schema_raises_error(self):
        """Test that invalid JSON Schema raises validation error."""
        # TODO: This test will pass once we add schema validation
        # For now, Pydantic accepts any dict, so this is a placeholder

        # This should eventually raise an error due to invalid schema
        try:
            ToolDef(
                name="bad_tool",
                input_schema={
                    "type": "invalid_type",  # Not a valid JSON Schema type
                    "properties": "not_an_object",  # Should be an object
                },
            )
            # For now this passes, but shouldn't after we add validation
        except ValidationError:
            pass  # Expected once validation is added


class TestToolDefinitionValidation:
    """Test ToolDefinition (internal) validation with JSON Schema."""

    def test_simple_tool_definition(self):
        """Test ToolDefinition with simple schema."""
        tool = ToolDefinition(
            tool_name="get_time",
            description="Get current time",
            input_schema={"type": "object", "properties": {"timezone": {"type": "string"}}},
        )

        assert tool.tool_name == "get_time"
        assert tool.input_schema is not None

    def test_builtin_tool_with_schema(self):
        """Test ToolDefinition with BuiltinTool enum."""
        tool = ToolDefinition(
            tool_name=BuiltinTool.code_interpreter,
            description="Run Python code",
            input_schema={"type": "object", "properties": {"code": {"type": "string"}}, "required": ["code"]},
            output_schema={"type": "object", "properties": {"output": {"type": "string"}, "error": {"type": "string"}}},
        )

        assert isinstance(tool.tool_name, BuiltinTool)
        assert tool.input_schema is not None
        assert tool.output_schema is not None

    def test_tool_definition_with_refs(self):
        """Test ToolDefinition preserves $ref/$defs."""
        tool = ToolDefinition(
            tool_name="process_data",
            input_schema={
                "type": "object",
                "properties": {"data": {"$ref": "#/$defs/DataObject"}},
                "$defs": {
                    "DataObject": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "integer"},
                            "values": {"type": "array", "items": {"type": "number"}},
                        },
                    }
                },
            },
        )

        assert "$defs" in tool.input_schema
        assert tool.input_schema["properties"]["data"]["$ref"] == "#/$defs/DataObject"


class TestSchemaEquivalence:
    """Test that schemas remain unchanged through serialization."""

    def test_schema_roundtrip(self):
        """Test that schemas survive model_dump/model_validate roundtrip."""
        original = ToolDef(
            name="test",
            input_schema={
                "type": "object",
                "properties": {"x": {"$ref": "#/$defs/X"}},
                "$defs": {"X": {"type": "string"}},
            },
        )

        # Serialize and deserialize
        dumped = original.model_dump()
        restored = ToolDef(**dumped)

        # Schemas should be identical
        assert restored.input_schema == original.input_schema
        assert "$defs" in restored.input_schema
        assert restored.input_schema["properties"]["x"]["$ref"] == "#/$defs/X"

    def test_json_serialization(self):
        """Test JSON serialization preserves schema."""
        import json

        tool = ToolDef(
            name="test",
            input_schema={
                "type": "object",
                "properties": {"a": {"type": "string"}},
                "$defs": {"T": {"type": "number"}},
            },
            output_schema={"type": "object", "properties": {"b": {"$ref": "#/$defs/T"}}},
        )

        # Serialize to JSON and back
        json_str = tool.model_dump_json()
        parsed = json.loads(json_str)
        restored = ToolDef(**parsed)

        assert restored.input_schema == tool.input_schema
        assert restored.output_schema == tool.output_schema
        assert "$defs" in restored.input_schema


class TestBackwardsCompatibility:
    """Test handling of legacy code patterns."""

    def test_none_schemas(self):
        """Test tools with no schemas (legacy case)."""
        tool = ToolDef(name="legacy_tool", description="Tool without schemas", input_schema=None, output_schema=None)

        assert tool.input_schema is None
        assert tool.output_schema is None

    def test_metadata_preserved(self):
        """Test that metadata field still works."""
        tool = ToolDef(
            name="test", input_schema={"type": "object"}, metadata={"endpoint": "http://example.com", "version": "1.0"}
        )

        assert tool.metadata["endpoint"] == "http://example.com"
        assert tool.metadata["version"] == "1.0"