mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
99 lines
3.1 KiB
Python
99 lines
3.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import pytest
|
|
from pydantic import TypeAdapter, ValidationError
|
|
|
|
from llama_stack.apis.agents.openai_responses import OpenAIResponsesToolChoice
|
|
from llama_stack.apis.tools.openai_tool_choice import (
|
|
ToolChoiceAllowed,
|
|
ToolChoiceCustom,
|
|
ToolChoiceFunction,
|
|
ToolChoiceMcp,
|
|
ToolChoiceOptions,
|
|
ToolChoiceTypes,
|
|
)
|
|
|
|
|
|
def test_tool_choice_discriminated_options():
    """Check that each dict-shaped tool choice resolves to its model class.

    Every payload is validated through a ``TypeAdapter`` built on
    ``OpenAIResponsesToolChoice``, checked for the expected class and ``type``
    value, then dumped and validated a second time to confirm the dumped form
    is accepted and dumps to the same dict.
    """
    tool_choice_adapter = TypeAdapter(OpenAIResponsesToolChoice)

    allowed_tools_payload = {
        "type": "allowed_tools",
        "mode": "auto",
        "tools": [{"type": "function", "name": "foo"}],
    }
    # Each row: (input payload, class it should resolve to, expected ``type``).
    scenarios = (
        ({"type": "function", "name": "search"}, ToolChoiceFunction, "function"),
        ({"type": "mcp", "server_label": "deepwiki"}, ToolChoiceMcp, "mcp"),
        ({"type": "custom", "name": "my_tool"}, ToolChoiceCustom, "custom"),
        (allowed_tools_payload, ToolChoiceAllowed, "allowed_tools"),
    )

    for raw, model_cls, type_tag in scenarios:
        parsed = tool_choice_adapter.validate_python(raw)
        assert isinstance(parsed, model_cls)
        assert parsed.type == type_tag

        # Round trip: the dumped dict must validate back to the same class
        # and produce an identical dump.
        snapshot = parsed.model_dump()
        round_tripped = tool_choice_adapter.validate_python(snapshot)
        assert isinstance(round_tripped, model_cls)
        assert round_tripped.model_dump() == snapshot
|
|
|
|
|
|
def test_tool_choice_literal_options():
    """Check that the bare string options validate to themselves.

    Each of "none", "auto" and "required" is validated twice: once with an
    adapter built on ``ToolChoiceOptions`` alone, and once with an adapter
    built on ``OpenAIResponsesToolChoice``.
    """
    union_adapter = TypeAdapter(OpenAIResponsesToolChoice)
    literal_adapter = TypeAdapter(ToolChoiceOptions)

    for option in ("none", "auto", "required"):
        # First through the ToolChoiceOptions adapter on its own...
        via_literal = literal_adapter.validate_python(option)
        assert via_literal == option

        # ...then through the top-level OpenAIResponsesToolChoice adapter.
        via_union = union_adapter.validate_python(option)
        assert via_union == option
|
|
|
|
|
|
def test_tool_choice_rejects_invalid_value():
    """Check that unrecognised inputs raise ``ValidationError``.

    Two inputs are tried in order: the string "invalid", and a dict whose
    ``type`` is "unknown_variant".
    """
    tool_choice_adapter = TypeAdapter(OpenAIResponsesToolChoice)

    rejected_inputs = ("invalid", {"type": "unknown_variant"})
    for bad_input in rejected_inputs:
        with pytest.raises(ValidationError):
            tool_choice_adapter.validate_python(bad_input)
|
|
|
|
|
|
def test_tool_choice_types_accepts_each_variant_value():
    """Check that each listed ``type`` value yields a ``ToolChoiceTypes``.

    For every value, a ``{"type": value}`` payload is validated through the
    ``OpenAIResponsesToolChoice`` adapter. The result must be a
    ``ToolChoiceTypes`` instance carrying that ``type`` and must dump back to
    exactly the input payload.
    """
    tool_choice_adapter = TypeAdapter(OpenAIResponsesToolChoice)

    tool_type_names = (
        "file_search",
        "web_search_preview",
        "computer_use_preview",
        "web_search_preview_2025_03_11",
        "image_generation",
        "code_interpreter",
    )

    for type_name in tool_type_names:
        payload = {"type": type_name}
        parsed = tool_choice_adapter.validate_python(payload)

        assert isinstance(parsed, ToolChoiceTypes)
        assert parsed.type == type_name
        # Compared against a fresh literal rather than the dict handed to the
        # validator.
        assert parsed.model_dump() == {"type": type_name}
|
|
|
|
|
|
def test_tool_choice_rejects_invalid_discriminator_value():
    """Check that a dict whose ``type`` is "unknown_variant" fails validation."""
    tool_choice_adapter = TypeAdapter(OpenAIResponsesToolChoice)
    unknown_payload = {"type": "unknown_variant"}

    with pytest.raises(ValidationError):
        tool_choice_adapter.validate_python(unknown_payload)
|
|
|
|
|
|
def test_tool_choice_rejects_missing_required_fields():
    """Check that a "function" tool choice without a ``name`` fails validation."""
    tool_choice_adapter = TypeAdapter(OpenAIResponsesToolChoice)
    nameless_function = {"type": "function"}  # deliberately has no "name" key

    with pytest.raises(ValidationError):
        tool_choice_adapter.validate_python(nameless_function)
|