Merge branch 'llamastack:main' into model_unregisteration_error_message

This commit is contained in:
Omar Abdelwahab 2025-10-06 10:14:30 -07:00 committed by GitHub
commit aa09a44c94
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
1036 changed files with 314835 additions and 114394 deletions

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationCreateRequest,
ConversationItem,
ConversationItemList,
)
def test_conversation_create_request_defaults():
request = ConversationCreateRequest()
assert request.items == []
assert request.metadata == {}
def test_conversation_model_defaults():
conversation = Conversation(
id="conv_123456789",
created_at=1234567890,
metadata=None,
object="conversation",
)
assert conversation.id == "conv_123456789"
assert conversation.object == "conversation"
assert conversation.metadata is None
def test_openai_client_compatibility():
from openai.types.conversations.message import Message
from pydantic import TypeAdapter
openai_message = Message(
id="msg_123",
content=[{"type": "input_text", "text": "Hello"}],
role="user",
status="in_progress",
type="message",
object="message",
)
adapter = TypeAdapter(ConversationItem)
validated_item = adapter.validate_python(openai_message.model_dump())
assert validated_item.id == "msg_123"
assert validated_item.type == "message"
def test_conversation_item_list():
item_list = ConversationItemList(data=[])
assert item_list.object == "list"
assert item_list.data == []
assert item_list.first_id is None
assert item_list.last_id is None
assert item_list.has_more is False

View file

@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from pathlib import Path
import pytest
from openai.types.conversations.conversation import Conversation as OpenAIConversation
from openai.types.conversations.conversation_item import ConversationItem as OpenAIConversationItem
from pydantic import TypeAdapter
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentText,
OpenAIResponseMessage,
)
from llama_stack.core.conversations.conversations import (
ConversationServiceConfig,
ConversationServiceImpl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@pytest.fixture
async def service():
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_conversations.db"
config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
service = ConversationServiceImpl(config, {})
await service.initialize()
yield service
async def test_conversation_lifecycle(service):
conversation = await service.create_conversation(metadata={"test": "data"})
assert conversation.id.startswith("conv_")
assert conversation.metadata == {"test": "data"}
retrieved = await service.get_conversation(conversation.id)
assert retrieved.id == conversation.id
deleted = await service.openai_delete_conversation(conversation.id)
assert deleted.id == conversation.id
async def test_conversation_items(service):
conversation = await service.create_conversation()
items = [
OpenAIResponseMessage(
type="message",
role="user",
content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
id="msg_test123",
status="completed",
)
]
item_list = await service.add_items(conversation.id, items)
assert len(item_list.data) == 1
assert item_list.data[0].id == "msg_test123"
items = await service.list(conversation.id)
assert len(items.data) == 1
async def test_invalid_conversation_id(service):
with pytest.raises(ValueError, match="Expected an ID that begins with 'conv_'"):
await service._get_validated_conversation("invalid_id")
async def test_empty_parameter_validation(service):
with pytest.raises(ValueError, match="Expected a non-empty value"):
await service.retrieve("", "item_123")
async def test_openai_type_compatibility(service):
conversation = await service.create_conversation(metadata={"test": "value"})
conversation_dict = conversation.model_dump()
openai_conversation = OpenAIConversation.model_validate(conversation_dict)
for attr in ["id", "object", "created_at", "metadata"]:
assert getattr(openai_conversation, attr) == getattr(conversation, attr)
items = [
OpenAIResponseMessage(
type="message",
role="user",
content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
id="msg_test456",
status="completed",
)
]
item_list = await service.add_items(conversation.id, items)
for attr in ["object", "data", "first_id", "last_id", "has_more"]:
assert hasattr(item_list, attr)
assert item_list.object == "list"
items = await service.list(conversation.id)
item = await service.retrieve(conversation.id, items.data[0].id)
item_dict = item.model_dump()
openai_item_adapter = TypeAdapter(OpenAIConversationItem)
openai_item_adapter.validate_python(item_dict)
async def test_policy_configuration():
from llama_stack.core.access_control.datatypes import Action, Scope
from llama_stack.core.datatypes import AccessRule
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_conversations_policy.db"
restrictive_policy = [
AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
]
config = ConversationServiceConfig(
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
)
service = ConversationServiceImpl(config, {})
await service.initialize()
assert service.policy == restrictive_policy
assert len(service.policy) == 1
assert service.policy[0].forbid is not None

View file

@ -16,7 +16,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.core.datatypes import RegistryEntrySource
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
@ -137,7 +137,10 @@ class ToolGroupsImpl(Impl):
ToolDef(
name="test-tool",
description="Test tool",
parameters=[ToolParameter(name="test-param", description="Test param", parameter_type="string")],
input_schema={
"type": "object",
"properties": {"test-param": {"type": "string", "description": "Test param"}},
},
)
]
)

View file

@ -390,3 +390,467 @@ pip_packages:
assert provider.is_external is True
# config_class is empty string in partial spec
assert provider.config_class == ""
class TestGetExternalProvidersFromModule:
"""Test suite for installing external providers from module."""
def test_stackrunconfig_provider_without_module(self, mock_providers):
"""Test that providers without module attribute are skipped."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
import_module_side_effect = make_import_module_side_effect()
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="no_module",
provider_type="no_module",
config={},
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Should not add anything to registry
assert len(result[Api.inference]) == 0
def test_stackrunconfig_with_version_spec(self, mock_providers):
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
fake_spec = ProviderSpec(
api=Api.inference,
provider_type="versioned_test",
config_class="versioned_test.config.VersionedTestConfig",
module="versioned_test==1.0.0",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)
def import_side_effect(name):
if name == "versioned_test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="versioned",
provider_type="versioned_test",
config={},
module="versioned_test==1.0.0",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
assert "versioned_test" in result[Api.inference]
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
def test_buildconfig_does_not_import_module(self, mock_providers):
"""Test that BuildConfig does not import the module (building=True)."""
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
build_config = BuildConfig(
version=2,
image_type="container",
image_name="test_image",
distribution_spec=DistributionSpec(
description="test",
providers={
"inference": [
BuildProvider(
provider_type="build_test",
module="build_test==1.0.0",
)
]
},
),
)
# Should not call import_module at all when building
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, build_config, building=True)
# Verify module was NOT imported
mock_import.assert_not_called()
# Verify partial spec was created
assert "build_test" in result[Api.inference]
provider = result[Api.inference]["build_test"]
assert provider.module == "build_test==1.0.0"
assert provider.is_external is True
assert provider.config_class == ""
assert provider.api == Api.inference
def test_buildconfig_multiple_providers(self, mock_providers):
"""Test BuildConfig with multiple providers for the same API."""
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
build_config = BuildConfig(
version=2,
image_type="container",
image_name="test_image",
distribution_spec=DistributionSpec(
description="test",
providers={
"inference": [
BuildProvider(provider_type="provider1", module="provider1"),
BuildProvider(provider_type="provider2", module="provider2"),
]
},
),
)
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, build_config, building=True)
mock_import.assert_not_called()
assert "provider1" in result[Api.inference]
assert "provider2" in result[Api.inference]
def test_distributionspec_does_not_import_module(self, mock_providers):
"""Test that DistributionSpec does not import the module (building=True)."""
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
dist_spec = DistributionSpec(
description="test distribution",
providers={
"inference": [
BuildProvider(
provider_type="dist_test",
module="dist_test==2.0.0",
)
]
},
)
# Should not call import_module at all when building
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, dist_spec, building=True)
# Verify module was NOT imported
mock_import.assert_not_called()
# Verify partial spec was created
assert "dist_test" in result[Api.inference]
provider = result[Api.inference]["dist_test"]
assert provider.module == "dist_test==2.0.0"
assert provider.is_external is True
assert provider.config_class == ""
def test_list_return_from_get_provider_spec(self, mock_providers):
"""Test when get_provider_spec returns a list of specs."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
spec1 = ProviderSpec(
api=Api.inference,
provider_type="list_test",
config_class="list_test.config.Config1",
module="list_test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="list_test_remote",
config_class="list_test.config.Config2",
module="list_test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "list_test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="list_test",
provider_type="list_test",
config={},
module="list_test",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Only the matching provider_type should be added
assert "list_test" in result[Api.inference]
assert result[Api.inference]["list_test"].config_class == "list_test.config.Config1"
def test_list_return_filters_by_provider_type(self, mock_providers):
"""Test that list return filters specs by provider_type."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
spec1 = ProviderSpec(
api=Api.inference,
provider_type="wanted",
config_class="test.Config1",
module="test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="unwanted",
config_class="test.Config2",
module="test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="wanted",
provider_type="wanted",
config={},
module="test",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Only the matching provider_type should be added
assert "wanted" in result[Api.inference]
assert "unwanted" not in result[Api.inference]
def test_list_return_adds_multiple_provider_types(self, mock_providers):
"""Test that list return adds multiple different provider_types when config requests them."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
# Module returns both inline and remote variants
spec1 = ProviderSpec(
api=Api.inference,
provider_type="remote::ollama",
config_class="test.RemoteConfig",
module="test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="inline::ollama",
config_class="test.InlineConfig",
module="test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="remote_ollama",
provider_type="remote::ollama",
config={},
module="test",
),
Provider(
provider_id="inline_ollama",
provider_type="inline::ollama",
config={},
module="test",
),
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Both provider types should be added to registry
assert "remote::ollama" in result[Api.inference]
assert "inline::ollama" in result[Api.inference]
assert result[Api.inference]["remote::ollama"].config_class == "test.RemoteConfig"
assert result[Api.inference]["inline::ollama"].config_class == "test.InlineConfig"
def test_module_not_found_raises_value_error(self, mock_providers):
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def import_side_effect(name):
if name == "missing_module.provider":
raise ModuleNotFoundError(name)
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="missing",
provider_type="missing",
config={},
module="missing_module",
)
]
},
)
registry = {Api.inference: {}}
with pytest.raises(ValueError) as exc_info:
get_external_providers_from_module(registry, config, building=False)
assert "get_provider_spec not found" in str(exc_info.value)
def test_generic_exception_is_raised(self, mock_providers):
"""Test that generic exceptions are properly raised."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def bad_spec():
raise RuntimeError("Something went wrong")
fake_module = SimpleNamespace(get_provider_spec=bad_spec)
def import_side_effect(name):
if name == "error_module.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="error",
provider_type="error",
config={},
module="error_module",
)
]
},
)
registry = {Api.inference: {}}
with pytest.raises(RuntimeError) as exc_info:
get_external_providers_from_module(registry, config, building=False)
assert "Something went wrong" in str(exc_info.value)
def test_empty_provider_list(self, mock_providers):
"""Test with empty provider list."""
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
config = StackRunConfig(
image_name="test_image",
providers={},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Should return registry unchanged
assert result == registry
assert len(result[Api.inference]) == 0
def test_multiple_apis_with_providers(self, mock_providers):
"""Test multiple APIs with providers."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
inference_spec = ProviderSpec(
api=Api.inference,
provider_type="inf_test",
config_class="inf.Config",
module="inf_test",
)
safety_spec = ProviderSpec(
api=Api.safety,
provider_type="safe_test",
config_class="safe.Config",
module="safe_test",
)
def import_side_effect(name):
if name == "inf_test.provider":
return SimpleNamespace(get_provider_spec=lambda: inference_spec)
elif name == "safe_test.provider":
return SimpleNamespace(get_provider_spec=lambda: safety_spec)
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="inf",
provider_type="inf_test",
config={},
module="inf_test",
)
],
"safety": [
Provider(
provider_id="safe",
provider_type="safe_test",
config={},
module="safe_test",
)
],
},
)
registry = {Api.inference: {}, Api.safety: {}}
result = get_external_providers_from_module(registry, config, building=False)
assert "inf_test" in result[Api.inference]
assert "safe_test" in result[Api.safety]

View file

@ -131,10 +131,6 @@ class TestInferenceRecording:
temp_storage_dir = temp_storage_dir / "test_response_storage"
storage = ResponseStorage(temp_storage_dir)
# Test directory creation
assert storage.test_dir.exists()
assert storage.responses_dir.exists()
# Test storing and retrieving a recording
request_hash = "test_hash_123"
request_data = {
@ -174,7 +170,8 @@ class TestInferenceRecording:
# Verify recording was stored
storage = ResponseStorage(temp_storage_dir)
assert storage.responses_dir.exists()
dir = storage._get_test_dir()
assert dir.exists()
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that replay mode returns stored responses without making real calls."""

View file

@ -18,7 +18,6 @@ from llama_stack.apis.inference import (
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -75,12 +74,15 @@ async def test_system_custom_only():
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
)
],
@ -107,12 +109,15 @@ async def test_system_custom_and_builtin():
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
@ -138,7 +143,7 @@ async def test_completion_message_encoding():
tool_calls=[
ToolCall(
tool_name="custom1",
arguments={"param1": "value1"},
arguments='{"param1": "value1"}', # arguments must be a JSON string
call_id="123",
)
],
@ -148,12 +153,15 @@ async def test_completion_message_encoding():
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
@ -227,12 +235,15 @@ async def test_replace_system_message_behavior_custom_tools():
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
@ -264,12 +275,15 @@ async def test_replace_system_message_behavior_custom_tools_with_template():
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],

View file

@ -16,9 +16,8 @@ from llama_stack.apis.agents import (
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import Inference
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ListToolsResponse, Tool, ToolGroups, ToolParameter, ToolRuntime
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
@ -232,32 +231,26 @@ async def test_delete_agent(agents_impl, sample_agent_config):
async def test__initialize_tools(agents_impl, sample_agent_config):
# Mock tool_groups_api.list_tools()
agents_impl.tool_groups_api.list_tools.return_value = ListToolsResponse(
agents_impl.tool_groups_api.list_tools.return_value = ListToolDefsResponse(
data=[
Tool(
identifier="story_maker",
provider_id="model-context-protocol",
type=ResourceType.tool,
ToolDef(
name="story_maker",
toolgroup_id="mcp::my_mcp_server",
description="Make a story",
parameters=[
ToolParameter(
name="story_title",
parameter_type="string",
description="Title of the story",
required=True,
title="Story Title",
),
ToolParameter(
name="input_words",
parameter_type="array",
description="Input words",
required=False,
items={"type": "string"},
title="Input Words",
default=[],
),
],
input_schema={
"type": "object",
"properties": {
"story_title": {"type": "string", "description": "Title of the story", "title": "Story Title"},
"input_words": {
"type": "array",
"description": "Input words",
"items": {"type": "string"},
"title": "Input Words",
"default": [],
},
},
"required": ["story_title"],
},
)
]
)
@ -284,27 +277,27 @@ async def test__initialize_tools(agents_impl, sample_agent_config):
assert second_tool.tool_name == "story_maker"
assert second_tool.description == "Make a story"
parameters = second_tool.parameters
assert len(parameters) == 2
# Verify the input schema
input_schema = second_tool.input_schema
assert input_schema is not None
assert input_schema["type"] == "object"
properties = input_schema["properties"]
assert len(properties) == 2
# Verify a string property
story_title = parameters.get("story_title")
assert story_title is not None
assert story_title.param_type == "string"
assert story_title.description == "Title of the story"
assert story_title.required
assert story_title.items is None
assert story_title.title == "Story Title"
assert story_title.default is None
story_title = properties["story_title"]
assert story_title["type"] == "string"
assert story_title["description"] == "Title of the story"
assert story_title["title"] == "Story Title"
# Verify an array property
input_words = parameters.get("input_words")
assert input_words is not None
assert input_words.param_type == "array"
assert input_words.description == "Input words"
assert not input_words.required
assert input_words.items is not None
assert len(input_words.items) == 1
assert input_words.items.get("type") == "string"
assert input_words.title == "Input Words"
assert input_words.default == []
input_words = properties["input_words"]
assert input_words["type"] == "array"
assert input_words["description"] == "Input words"
assert input_words["items"]["type"] == "string"
assert input_words["title"] == "Input Words"
assert input_words["default"] == []
# Verify required fields
assert input_schema["required"] == ["story_title"]

View file

@ -22,7 +22,6 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseObjectWithInput,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
@ -37,16 +36,18 @@ from llama_stack.apis.inference import (
OpenAIJSONSchema,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatText,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.datatypes import ResponsesStoreConfig
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
@ -148,7 +149,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
mock_inference_api.openai_chat_completion.assert_called_once_with(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
response_format=OpenAIResponseFormatText(),
response_format=None,
tools=None,
stream=True,
temperature=0.1,
@ -187,14 +188,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
input_text = "What is the capital of Ireland?"
model = "meta-llama/Llama-3.1-8B-Instruct"
openai_responses_impl.tool_groups_api.get_tool.return_value = Tool(
identifier="web_search",
provider_id="client",
openai_responses_impl.tool_groups_api.get_tool.return_value = ToolDef(
name="web_search",
toolgroup_id="web_search",
description="Search the web for information",
parameters=[
ToolParameter(name="query", parameter_type="string", description="The query to search for", required=True)
],
input_schema={
"type": "object",
"properties": {"query": {"type": "string", "description": "The query to search for"}},
"required": ["query"],
},
)
openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(
@ -328,6 +330,132 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
assert chunks[5].response.output[0].name == "get_weather"
async def test_create_openai_response_with_tool_call_function_arguments_none(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a tool call response that has a function that does not accept arguments, or arguments set to None when they are not mandatory."""
# Setup
input_text = "What is the time right now?"
model = "meta-llama/Llama-3.1-8B-Instruct"
async def fake_stream_toolcall():
yield ChatCompletionChunk(
id="123",
choices=[
Choice(
index=0,
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
index=0,
id="tc_123",
function=ChoiceDeltaToolCallFunction(name="get_current_time", arguments=None),
type=None,
)
]
),
),
],
created=1,
model=model,
object="chat.completion.chunk",
)
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
# Function does not accept arguments
result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
stream=True,
temperature=0.1,
tools=[
OpenAIResponseInputToolFunction(
name="get_current_time",
description="Get current time for system's timezone",
parameters={},
)
],
)
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
# Verify event types
# Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 5
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check streaming events
assert chunks[1].type == "response.output_item.added"
assert chunks[2].type == "response.function_call_arguments.done"
assert chunks[3].type == "response.output_item.done"
# Check response.completed event (should have the tool call with arguments set to "{}")
assert chunks[4].type == "response.completed"
assert len(chunks[4].response.output) == 1
assert chunks[4].response.output[0].type == "function_call"
assert chunks[4].response.output[0].name == "get_current_time"
assert chunks[4].response.output[0].arguments == "{}"
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
# Function accepts optional arguments
result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
stream=True,
temperature=0.1,
tools=[
OpenAIResponseInputToolFunction(
name="get_current_time",
description="Get current time for system's timezone",
parameters={
"timezone": "string",
},
)
],
)
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
# Verify event types
# Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 5
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check streaming events
assert chunks[1].type == "response.output_item.added"
assert chunks[2].type == "response.function_call_arguments.done"
assert chunks[3].type == "response.output_item.done"
# Check response.completed event (should have the tool call with arguments set to "{}")
assert chunks[4].type == "response.completed"
assert len(chunks[4].response.output) == 1
assert chunks[4].response.output[0].type == "function_call"
assert chunks[4].response.output[0].name == "get_current_time"
assert chunks[4].response.output[0].arguments == "{}"
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with multiple messages."""
# Setup
@ -373,13 +501,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
async def test_prepend_previous_response_none(openai_responses_impl):
"""Test prepending no previous response to a new response."""
input = await openai_responses_impl._prepend_previous_response("fake_input", None)
assert input == "fake_input"
async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
"""Test prepending a basic previous response to a new response."""
@ -394,7 +515,7 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
status="completed",
role="assistant",
)
previous_response = OpenAIResponseObjectWithInput(
previous_response = _OpenAIResponseObjectWithInputAndMessages(
created_at=1,
id="resp_123",
model="fake_model",
@ -402,10 +523,11 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
messages=[OpenAIUserMessageParam(content="fake_previous_input")],
)
mock_responses_store.get_response_object.return_value = previous_response
input = await openai_responses_impl._prepend_previous_response("fake_input", "resp_123")
input = await openai_responses_impl._prepend_previous_response("fake_input", previous_response)
assert len(input) == 3
# Check for previous input
@ -436,7 +558,7 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
status="completed",
role="assistant",
)
response = OpenAIResponseObjectWithInput(
response = _OpenAIResponseObjectWithInputAndMessages(
created_at=1,
id="resp_123",
model="fake_model",
@ -444,11 +566,12 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
messages=[OpenAIUserMessageParam(content="test input")],
)
mock_responses_store.get_response_object.return_value = response
input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
input = await openai_responses_impl._prepend_previous_response(input_messages, response)
assert len(input) == 4
# Check for previous input
@ -483,7 +606,7 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mo
status="completed",
role="assistant",
)
response = OpenAIResponseObjectWithInput(
response = _OpenAIResponseObjectWithInputAndMessages(
created_at=1,
id="resp_123",
model="fake_model",
@ -491,11 +614,12 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mo
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
messages=[OpenAIUserMessageParam(content="test input")],
)
mock_responses_store.get_response_object.return_value = response
input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
input = await openai_responses_impl._prepend_previous_response(input_messages, response)
assert len(input) == 4
# Check for previous input
@ -599,7 +723,7 @@ async def test_create_openai_response_with_instructions_and_previous_response(
status="completed",
role="assistant",
)
response = OpenAIResponseObjectWithInput(
response = _OpenAIResponseObjectWithInputAndMessages(
created_at=1,
id="resp_123",
model="fake_model",
@ -607,6 +731,10 @@ async def test_create_openai_response_with_instructions_and_previous_response(
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
messages=[
OpenAIUserMessageParam(content="Name some towns in Ireland"),
OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
],
)
mock_responses_store.get_response_object.return_value = response
@ -692,7 +820,7 @@ async def test_responses_store_list_input_items_logic():
OpenAIResponseMessage(id="msg_4", content="Fourth message", role="user"),
]
response_with_input = OpenAIResponseObjectWithInput(
response_with_input = _OpenAIResponseObjectWithInputAndMessages(
id="resp_123",
model="test_model",
created_at=1234567890,
@ -701,6 +829,7 @@ async def test_responses_store_list_input_items_logic():
output=[],
text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))),
input=input_items,
messages=[OpenAIUserMessageParam(content="First message")],
)
# Mock the get_response_object method to return our test data
@ -761,7 +890,7 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
rather than just the original input when previous_response_id is provided."""
# Setup - Create a previous response that should be included in the stored input
previous_response = OpenAIResponseObjectWithInput(
previous_response = _OpenAIResponseObjectWithInputAndMessages(
id="resp-previous-123",
object="response",
created_at=1234567890,
@ -780,6 +909,10 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
)
],
messages=[
OpenAIUserMessageParam(content="What is 2+2?"),
OpenAIAssistantMessageParam(content="2+2 equals 4."),
],
)
mock_responses_store.get_response_object.return_value = previous_response
@ -823,16 +956,16 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
@pytest.mark.parametrize(
"text_format, response_format",
[
(OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), OpenAIResponseFormatText()),
(OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), None),
(
OpenAIResponseText(format=OpenAIResponseTextFormat(name="Test", schema={"foo": "bar"}, type="json_schema")),
OpenAIResponseFormatJSONSchema(json_schema=OpenAIJSONSchema(name="Test", schema={"foo": "bar"})),
),
(OpenAIResponseText(format=OpenAIResponseTextFormat(type="json_object")), OpenAIResponseFormatJSONObject()),
# ensure text param with no format specified defaults to text
(OpenAIResponseText(format=None), OpenAIResponseFormatText()),
# ensure text param of None defaults to text
(None, OpenAIResponseFormatText()),
# ensure text param with no format specified defaults to None
(OpenAIResponseText(format=None), None),
# ensure text param of None defaults to None
(None, None),
],
)
async def test_create_openai_response_with_text_format(
@ -855,7 +988,6 @@ async def test_create_openai_response_with_text_format(
# Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["response_format"] is not None
assert first_call.kwargs["response_format"] == response_format

View file

@ -7,6 +7,8 @@
import json
from unittest.mock import MagicMock
import pytest
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
@ -18,72 +20,41 @@ from llama_stack.providers.remote.inference.together.config import TogetherImplC
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
def test_groq_provider_openai_client_caching():
"""Ensure the Groq provider does not cache api keys across client requests"""
config = GroqConfig()
inference_adapter = GroqInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
assert inference_adapter.client.api_key == api_key
def test_openai_provider_openai_client_caching():
@pytest.mark.parametrize(
"config_cls,adapter_cls,provider_data_validator",
[
(
GroqConfig,
GroqInferenceAdapter,
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
),
(
OpenAIConfig,
OpenAIInferenceAdapter,
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
),
(
TogetherImplConfig,
TogetherInferenceAdapter,
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
),
(
LlamaCompatConfig,
LlamaCompatInferenceAdapter,
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
),
],
)
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
"""Ensure the OpenAI provider does not cache api keys across client requests"""
config = OpenAIConfig()
inference_adapter = OpenAIInferenceAdapter(config)
inference_adapter = adapter_cls(config=config_cls())
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator"
)
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter.client
assert openai_client.api_key == api_key
def test_together_provider_openai_client_caching():
"""Ensure the Together provider does not cache api keys across client requests"""
config = TogetherImplConfig()
inference_adapter = TogetherInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"together_api_key": api_key})}):
together_client = inference_adapter._get_client()
assert together_client.client.api_key == api_key
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
def test_llama_compat_provider_openai_client_caching():
"""Ensure the LlamaCompat provider does not cache api keys across client requests"""
config = LlamaCompatConfig()
inference_adapter = LlamaCompatInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
assert inference_adapter.client.api_key == api_key

View file

@ -18,7 +18,8 @@ class TestOpenAIBaseURLConfig:
def test_default_base_url_without_env_var(self):
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
config = OpenAIConfig(api_key="test-key")
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == "https://api.openai.com/v1"
@ -26,7 +27,8 @@ class TestOpenAIBaseURLConfig:
"""Test that the adapter uses a custom base URL when provided in config."""
custom_url = "https://custom.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == custom_url
@ -37,7 +39,8 @@ class TestOpenAIBaseURLConfig:
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == "https://env.openai.com/v1"
@ -46,7 +49,8 @@ class TestOpenAIBaseURLConfig:
"""Test that explicit config value overrides environment variable."""
custom_url = "https://config.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Config should take precedence over environment variable
assert adapter.get_base_url() == custom_url
@ -56,7 +60,8 @@ class TestOpenAIBaseURLConfig:
"""Test that the OpenAI client is initialized with the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
adapter.get_api_key = MagicMock(return_value="test-key")
@ -75,7 +80,8 @@ class TestOpenAIBaseURLConfig:
"""Test that check_model_availability uses the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")
@ -116,7 +122,8 @@ class TestOpenAIBaseURLConfig:
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")

View file

@ -5,49 +5,21 @@
# the root directory of this source tree.
import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChoiceChunk,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
)
from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
CompletionMessage,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
SystemMessage,
ToolChoice,
ToolConfig,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
VLLMInferenceAdapter,
_process_vllm_chat_completion_stream_response,
)
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
# These are unit test for the remote vllm provider
# implementation. This should only contain tests which are specific to
@ -60,37 +32,15 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
# -v -s --tb=short --disable-warnings
@pytest.fixture(scope="module")
def mock_openai_models_list():
with patch("openai.resources.models.AsyncModels.list") as mock_list:
yield mock_list
@pytest.fixture(scope="function")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config)
inference_adapter = VLLMInferenceAdapter(config=config)
inference_adapter.model_store = AsyncMock()
# Mock the __provider_spec__ attribute that would normally be set by the resolver
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_type = "vllm-inference"
inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
await inference_adapter.initialize()
return inference_adapter
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
async def mock_openai_models():
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
mock_openai_models_list.return_value = mock_openai_models()
foo_model = Model(identifier="foo", provider_resource_id="foo", provider_id="vllm-inference")
await vllm_inference_adapter.register_model(foo_model)
mock_openai_models_list.assert_called()
async def test_old_vllm_tool_choice(vllm_inference_adapter):
"""
Test that we set tool_choice to none when no tools are in use
@ -99,464 +49,24 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
with patch.object(vllm_inference_adapter, "_nonstream_chat_completion") as mock_nonstream_completion:
# Patch the client property to avoid instantiating a real AsyncOpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock()
mock_client_property.return_value = mock_client
# No tools but auto tool choice
await vllm_inference_adapter.chat_completion(
await vllm_inference_adapter.openai_chat_completion(
"mock-model",
[],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
tool_choice=ToolChoice.auto.value,
)
mock_nonstream_completion.assert_called()
request = mock_nonstream_completion.call_args.args[0]
mock_client.chat.completions.create.assert_called()
call_args = mock_client.chat.completions.create.call_args
# Ensure tool_choice gets converted to none for older vLLM versions
assert request.tool_config.tool_choice == ToolChoice.none
async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format."""
# Patch the client property to avoid instantiating a real AsyncOpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock()
mock_create_client.return_value = mock_client
# Mock the model to return a proper provider_resource_id
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
messages = [
SystemMessage(content="You are a helpful assistant"),
UserMessage(content="How many?"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
call_id="foo",
tool_name="knowledge_search",
arguments={"query": "How many?"},
arguments_json='{"query": "How many?"}',
)
],
),
ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
]
await vllm_inference_adapter.chat_completion(
"mock-model",
messages,
stream=False,
tools=[],
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
assert mock_client.chat.completions.create.call_args.kwargs["messages"][2]["tool_calls"] == [
{
"id": "foo",
"type": "function",
"function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
}
]
async def test_tool_call_delta_empty_tool_call_buf():
"""
Test that we don't generate extra chunks when processing a
tool call response that didn't call any tools. Previously we would
emit chunks with spurious ToolCallParseStatus.succeeded or
ToolCallParseStatus.failed when processing chunks that didn't
actually make any tool calls.
"""
async def mock_stream():
delta = OpenAIChoiceDelta(content="", tool_calls=None)
choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "complete"
assert chunks[1].event.stop_reason == StopReason.end_of_turn
async def test_tool_call_delta_streaming_arguments_dict():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments="",
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "complete"
async def test_multiple_tool_calls():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=2,
function=OpenAIChoiceDeltaToolCallFunction(
name="multiple",
arguments='{"first_number": 4, "second_number": 7}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 4
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "progress"
assert chunks[2].event.delta.type == "tool_call"
assert chunks[2].event.delta.parse_status.value == "succeeded"
assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
assert chunks[3].event.event_type.value == "complete"
async def test_process_vllm_chat_completion_stream_response_no_choices():
"""
Test that we don't error out when vLLM returns no choices for a
completion request. This can happen when there's an error thrown
in vLLM for example.
"""
async def mock_stream():
choices = []
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 1
assert chunks[0].event.event_type.value == "start"
async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest(
tools=[],
model="test_model",
messages=[UserMessage(content="test")],
)
params = await vllm_inference_adapter._get_params(request)
assert "tools" not in params
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
"""
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
contains the finish reason (i.e., the last one).
We want to make sure the tool call is executed in this case, and the parameters are passed correctly.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": None,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {
"name": None,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": "tool_calls",
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
"""
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
finish reason.
We want to make sure that this case is recognized and handled correctly, i.e., as a valid end of message.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = '"{\\"arg1\\": 0, \\"arg2\\": 100}"'
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
"""
Tests the edge case where no arguments are provided for the tool call.
Tool calls with no arguments should be treated as regular tool calls, which was not the case until now.
"""
mock_tool_name = "mock_tool"
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": "",
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == {}
assert call_args.kwargs["tool_choice"] == ToolChoice.none.value
async def test_health_status_success(vllm_inference_adapter):
@ -689,96 +199,30 @@ async def test_should_refresh_models():
# Test case 1: refresh_models is True, api_token is None
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
adapter1 = VLLMInferenceAdapter(config1)
adapter1 = VLLMInferenceAdapter(config=config1)
result1 = await adapter1.should_refresh_models()
assert result1 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 2: refresh_models is True, api_token is empty string
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
adapter2 = VLLMInferenceAdapter(config2)
adapter2 = VLLMInferenceAdapter(config=config2)
result2 = await adapter2.should_refresh_models()
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 3: refresh_models is True, api_token is "fake" (default)
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
adapter3 = VLLMInferenceAdapter(config3)
adapter3 = VLLMInferenceAdapter(config=config3)
result3 = await adapter3.should_refresh_models()
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 4: refresh_models is True, api_token is real token
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
adapter4 = VLLMInferenceAdapter(config4)
adapter4 = VLLMInferenceAdapter(config=config4)
result4 = await adapter4.should_refresh_models()
assert result4 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 5: refresh_models is False, api_token is real token
config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
adapter5 = VLLMInferenceAdapter(config5)
adapter5 = VLLMInferenceAdapter(config=config5)
result5 = await adapter5.should_refresh_models()
assert result5 is False, "should_refresh_models should return False when refresh_models is False"
async def test_provider_data_var_context_propagation(vllm_inference_adapter):
"""
Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
This ensures that dynamic provider data (like API tokens) can be passed through context.
Note: The base URL is always taken from config.url, not from provider data.
"""
# Mock the AsyncOpenAI class to capture provider data
with (
patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
):
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock()
mock_openai_class.return_value = mock_client
# Mock provider data to return test data
mock_provider_data = MagicMock()
mock_provider_data.vllm_api_token = "test-token-123"
mock_provider_data.vllm_url = "http://test-server:8000/v1"
mock_get_provider_data.return_value = mock_provider_data
# Mock the model
mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
try:
# Execute chat completion
await vllm_inference_adapter.chat_completion(
"test-model",
[UserMessage(content="Hello")],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
# Verify that ALL client calls were made with the correct parameters
calls = mock_openai_class.call_args_list
incorrect_calls = []
for i, call in enumerate(calls):
api_key = call[1]["api_key"]
base_url = call[1]["base_url"]
if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
if incorrect_calls:
error_msg = (
f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
)
for incorrect_call in incorrect_calls:
error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
raise AssertionError(error_msg)
# Ensure at least one call was made
assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
# Verify that chat completion was called
mock_client.chat.completions.create.assert_called_once()
finally:
# Clean up context
pass

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.tools import ToolDef, ToolParameter
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
convert_tooldef_to_chat_tool,
)
@ -20,15 +20,11 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
tool_def = ToolDef(
name="test_tool",
description="A test tool with array parameter",
parameters=[
ToolParameter(
name="tags",
parameter_type="array",
description="List of tags",
required=True,
items={"type": "string"},
)
],
input_schema={
"type": "object",
"properties": {"tags": {"type": "array", "description": "List of tags", "items": {"type": "string"}}},
"required": ["tags"],
},
)
result = convert_tooldef_to_chat_tool(tool_def)

View file

@ -41,9 +41,7 @@ async def test_convert_message_to_openai_dict():
async def test_convert_message_to_openai_dict_with_tool_call():
message = CompletionMessage(
content="",
tool_calls=[
ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"})
],
tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
stop_reason=StopReason.end_of_turn,
)
@ -65,8 +63,7 @@ async def test_convert_message_to_openai_dict_with_builtin_tool_call():
ToolCall(
call_id="123",
tool_name=BuiltinTool.brave_search,
arguments_json='{"foo": "bar"}',
arguments={"foo": "bar"},
arguments='{"foo": "bar"}',
)
],
stop_reason=StopReason.end_of_turn,
@ -202,8 +199,7 @@ async def test_convert_message_to_openai_dict_new_completion_message_with_tool_c
ToolCall(
call_id="call_123",
tool_name="get_weather",
arguments={"city": "Sligo"},
arguments_json='{"city": "Sligo"}',
arguments='{"city": "Sligo"}',
)
],
stop_reason=StopReason.end_of_turn,

View file

@ -4,18 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import json
from collections.abc import Iterable
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
class OpenAIMixinImpl(OpenAIMixin):
def __init__(self):
self.__provider_id__ = "test-provider"
__provider_id__: str = "test-provider"
def get_api_key(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
@ -24,27 +29,20 @@ class OpenAIMixinImpl(OpenAIMixin):
raise NotImplementedError("This method should be mocked in tests")
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
"""Test implementation with embedding model metadata"""
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
__provider_id__ = "test-provider"
def get_api_key(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
def get_base_url(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
mixin_instance = OpenAIMixinImpl()
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinImpl(config=config)
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
@ -59,7 +57,8 @@ def mixin():
@pytest.fixture
def mixin_with_embeddings():
"""Create a test instance of OpenAIMixin with embedding model metadata"""
return OpenAIMixinWithEmbeddingsImpl()
config = RemoteInferenceProviderConfig()
return OpenAIMixinWithEmbeddingsImpl(config=config)
@pytest.fixture
@ -366,3 +365,328 @@ class TestOpenAIMixinAllowedModels:
assert await mixin.check_model_availability("final-mock-model-id")
assert not await mixin.check_model_availability("some-mock-model-id")
assert not await mixin.check_model_availability("another-mock-model-id")
class TestOpenAIMixinModelRegistration:
"""Test cases for model registration functionality"""
async def test_register_model_success(self, mixin, mock_client_with_models, mock_client_context):
"""Test successful model registration when model is available"""
model = Model(
provider_id="test-provider",
provider_resource_id="some-mock-model-id",
identifier="test-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.register_model(model)
assert result == model
assert result.provider_id == "test-provider"
assert result.provider_resource_id == "some-mock-model-id"
assert result.identifier == "test-model"
assert result.model_type == ModelType.llm
mock_client_with_models.models.list.assert_called_once()
async def test_register_model_not_available(self, mixin, mock_client_with_models, mock_client_context):
"""Test model registration failure when model is not available from provider"""
model = Model(
provider_id="test-provider",
provider_resource_id="non-existent-model",
identifier="test-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_models):
with pytest.raises(
ValueError, match="Model non-existent-model is not available from provider test-provider"
):
await mixin.register_model(model)
mock_client_with_models.models.list.assert_called_once()
async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
"""Test model registration with allowed_models filtering"""
mixin.allowed_models = {"some-mock-model-id"}
# Test with allowed model
allowed_model = Model(
provider_id="test-provider",
provider_resource_id="some-mock-model-id",
identifier="allowed-model",
model_type=ModelType.llm,
)
# Test with disallowed model
disallowed_model = Model(
provider_id="test-provider",
provider_resource_id="final-mock-model-id",
identifier="disallowed-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.register_model(allowed_model)
assert result == allowed_model
with pytest.raises(
ValueError, match="Model final-mock-model-id is not available from provider test-provider"
):
await mixin.register_model(disallowed_model)
mock_client_with_models.models.list.assert_called_once()
async def test_register_embedding_model(self, mixin_with_embeddings, mock_client_context):
"""Test registration of embedding models with metadata"""
mock_embedding_model = MagicMock(id="text-embedding-3-small")
mock_models = [mock_embedding_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
embedding_model = Model(
provider_id="test-provider",
provider_resource_id="text-embedding-3-small",
identifier="embedding-test",
model_type=ModelType.embedding,
)
with mock_client_context(mixin_with_embeddings, mock_client):
result = await mixin_with_embeddings.register_model(embedding_model)
assert result == embedding_model
assert result.model_type == ModelType.embedding
async def test_unregister_model(self, mixin):
"""Test model unregistration (should be no-op)"""
# unregister_model should not raise any exceptions and return None
result = await mixin.unregister_model("any-model-id")
assert result is None
async def test_should_refresh_models(self, mixin):
"""Test should_refresh_models method (should always return False)"""
result = await mixin.should_refresh_models()
assert result is False
async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
"""Test that errors from provider API are properly propagated during registration"""
model = Model(
provider_id="test-provider",
provider_resource_id="some-model",
identifier="test-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_exception):
# The exception from the API should be propagated
with pytest.raises(Exception, match="API Error"):
await mixin.register_model(model)
class ProviderDataValidator(BaseModel):
"""Validator for provider data in tests"""
test_api_key: str | None = Field(default=None)
class OpenAIMixinWithProviderData(OpenAIMixinImpl):
"""Test implementation that supports provider data API key field"""
provider_data_api_key_field: str = "test_api_key"
def get_api_key(self) -> str:
return "default-api-key"
def get_base_url(self):
return "default-base-url"
class CustomListProviderModelIdsImplementation(OpenAIMixinImpl):
"""Test implementation with custom list_provider_model_ids override"""
custom_model_ids: Any
async def list_provider_model_ids(self) -> Iterable[str]:
"""Return custom model IDs list"""
return self.custom_model_ids
class TestOpenAIMixinCustomListProviderModelIds:
"""Test cases for custom list_provider_model_ids() implementation functionality"""
@pytest.fixture
def custom_model_ids_list(self):
"""Create a list of custom model ID strings"""
return ["custom-model-1", "custom-model-2", "custom-embedding"]
@pytest.fixture
def config(self):
"""Create RemoteInferenceProviderConfig instance"""
return RemoteInferenceProviderConfig()
@pytest.fixture
def adapter(self, custom_model_ids_list, config):
"""Create mixin instance with custom list_provider_model_ids implementation"""
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=custom_model_ids_list)
mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}}
return mixin
async def test_is_used(self, adapter, custom_model_ids_list):
"""Test that custom list_provider_model_ids() implementation is used instead of client.models.list()"""
result = await adapter.list_models()
assert result is not None
assert len(result) == 3
assert set(custom_model_ids_list) == {m.identifier for m in result}
async def test_populates_cache(self, adapter, custom_model_ids_list):
"""Test that custom list_provider_model_ids() results are cached"""
assert len(adapter._model_cache) == 0
await adapter.list_models()
assert set(custom_model_ids_list) == set(adapter._model_cache.keys())
async def test_respects_allowed_models(self, config):
"""Test that custom list_provider_model_ids() respects allowed_models filtering"""
mixin = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=["model-1", "model-2", "model-3"]
)
mixin.allowed_models = ["model-1"]
result = await mixin.list_models()
assert result is not None
assert len(result) == 1
assert result[0].identifier == "model-1"
async def test_with_empty_list(self, config):
"""Test that custom list_provider_model_ids() handles empty list correctly"""
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[])
result = await mixin.list_models()
assert result is not None
assert len(result) == 0
assert len(mixin._model_cache) == 0
async def test_wrong_type_raises_error(self, config):
"""Test that list_provider_model_ids() returning unhashable items results in an error"""
mixin = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=["valid-model", ["nested", "list"]]
)
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
mixin = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=[{"key": "value"}, "valid-model"]
)
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=["valid-model", 42.0])
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[None])
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
async def test_non_iterable_raises_error(self, config):
"""Test that list_provider_model_ids() returning non-iterable type raises error"""
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=42)
with pytest.raises(
TypeError,
match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int",
):
await mixin.list_models()
async def test_accepts_various_iterables(self, config):
"""Test that list_provider_model_ids() accepts tuples, sets, generators, etc."""
tuples = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=("model-1", "model-2", "model-3")
)
result = await tuples.list_models()
assert result is not None
assert len(result) == 3
class GeneratorAdapter(OpenAIMixinImpl):
async def list_provider_model_ids(self) -> Iterable[str]:
def gen():
yield "gen-model-1"
yield "gen-model-2"
return gen()
mixin = GeneratorAdapter(config=config)
result = await mixin.list_models()
assert result is not None
assert len(result) == 2
sets = CustomListProviderModelIdsImplementation(config=config, custom_model_ids={"set-model-1", "set-model-2"})
result = await sets.list_models()
assert result is not None
assert len(result) == 2
class TestOpenAIMixinProviderDataApiKey:
"""Test cases for provider_data_api_key_field functionality"""
@pytest.fixture
def mixin_with_provider_data_field(self):
"""Mixin instance with provider_data_api_key_field set"""
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinWithProviderData(config=config)
# Mock provider_spec for provider data validation
mock_provider_spec = MagicMock()
mock_provider_spec.provider_type = "test-provider-with-data"
mock_provider_spec.provider_data_validator = (
"tests.unit.providers.utils.inference.test_openai_mixin.ProviderDataValidator"
)
mixin_instance.__provider_spec__ = mock_provider_spec
return mixin_instance
@pytest.fixture
def mixin_with_provider_data_field_and_none_api_key(self, mixin_with_provider_data_field):
mixin_with_provider_data_field.get_api_key = Mock(return_value=None)
return mixin_with_provider_data_field
def test_no_provider_data(self, mixin_with_provider_data_field):
"""Test that client uses config API key when no provider data is available"""
assert mixin_with_provider_data_field.client.api_key == "default-api-key"
def test_with_provider_data(self, mixin_with_provider_data_field):
"""Test that provider data API key overrides config API key"""
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})}
):
assert mixin_with_provider_data_field.client.api_key == "provider-data-key"
def test_with_wrong_key(self, mixin_with_provider_data_field):
"""Test fallback to config when provider data doesn't have the required key"""
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
assert mixin_with_provider_data_field.client.api_key == "default-api-key"
def test_error_when_no_config_and_provider_data_has_wrong_key(
self, mixin_with_provider_data_field_and_none_api_key
):
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
with pytest.raises(ValueError, match="API key is not set"):
_ = mixin_with_provider_data_field_and_none_api_key.client
def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):
"""Test that error message includes correct field name and header information"""
with pytest.raises(ValueError) as exc_info:
_ = mixin_with_provider_data_field_and_none_api_key.client
error_message = str(exc_info.value)
assert "test_api_key" in error_message
assert "x-llamastack-provider-data" in error_message

View file

@ -0,0 +1,381 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Unit tests for OpenAI compatibility tool conversion.
Tests convert_tooldef_to_openai_tool with new JSON Schema approach.
"""
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
class TestSimpleSchemaConversion:
"""Test basic schema conversions to OpenAI format."""
def test_simple_tool_conversion(self):
"""Test conversion of simple tool with basic input schema."""
tool = ToolDefinition(
tool_name="get_weather",
description="Get weather information",
input_schema={
"type": "object",
"properties": {"location": {"type": "string", "description": "City name"}},
"required": ["location"],
},
)
result = convert_tooldef_to_openai_tool(tool)
# Check OpenAI structure
assert result["type"] == "function"
assert "function" in result
function = result["function"]
assert function["name"] == "get_weather"
assert function["description"] == "Get weather information"
# Check parameters are passed through
assert "parameters" in function
assert function["parameters"] == tool.input_schema
assert function["parameters"]["type"] == "object"
assert "location" in function["parameters"]["properties"]
def test_tool_without_description(self):
"""Test tool conversion without description."""
tool = ToolDefinition(tool_name="test_tool", input_schema={"type": "object", "properties": {}})
result = convert_tooldef_to_openai_tool(tool)
assert result["function"]["name"] == "test_tool"
assert "description" not in result["function"]
assert "parameters" in result["function"]
def test_builtin_tool_conversion(self):
"""Test conversion of BuiltinTool enum."""
tool = ToolDefinition(
tool_name=BuiltinTool.code_interpreter,
description="Run Python code",
input_schema={"type": "object", "properties": {"code": {"type": "string"}}},
)
result = convert_tooldef_to_openai_tool(tool)
# BuiltinTool should be converted to its value
assert result["function"]["name"] == "code_interpreter"
class TestComplexSchemaConversion:
"""Test conversion of complex JSON Schema features."""
def test_schema_with_refs_and_defs(self):
"""Test that $ref and $defs are passed through to OpenAI."""
tool = ToolDefinition(
tool_name="book_flight",
description="Book a flight",
input_schema={
"type": "object",
"properties": {
"flight": {"$ref": "#/$defs/FlightInfo"},
"passengers": {"type": "array", "items": {"$ref": "#/$defs/Passenger"}},
"payment": {"$ref": "#/$defs/Payment"},
},
"required": ["flight", "passengers", "payment"],
"$defs": {
"FlightInfo": {
"type": "object",
"properties": {
"from": {"type": "string", "description": "Departure airport"},
"to": {"type": "string", "description": "Arrival airport"},
"date": {"type": "string", "format": "date"},
},
"required": ["from", "to", "date"],
},
"Passenger": {
"type": "object",
"properties": {"name": {"type": "string"}, "age": {"type": "integer", "minimum": 0}},
"required": ["name", "age"],
},
"Payment": {
"type": "object",
"properties": {
"method": {"type": "string", "enum": ["credit_card", "debit_card"]},
"amount": {"type": "number", "minimum": 0},
},
},
},
},
)
result = convert_tooldef_to_openai_tool(tool)
params = result["function"]["parameters"]
# Verify $defs are preserved
assert "$defs" in params
assert "FlightInfo" in params["$defs"]
assert "Passenger" in params["$defs"]
assert "Payment" in params["$defs"]
# Verify $ref are preserved
assert params["properties"]["flight"]["$ref"] == "#/$defs/FlightInfo"
assert params["properties"]["passengers"]["items"]["$ref"] == "#/$defs/Passenger"
assert params["properties"]["payment"]["$ref"] == "#/$defs/Payment"
# Verify nested schema details are preserved
assert params["$defs"]["FlightInfo"]["properties"]["date"]["format"] == "date"
assert params["$defs"]["Passenger"]["properties"]["age"]["minimum"] == 0
assert params["$defs"]["Payment"]["properties"]["method"]["enum"] == ["credit_card", "debit_card"]
def test_anyof_schema_conversion(self):
"""Test conversion of anyOf schemas."""
tool = ToolDefinition(
tool_name="flexible_input",
input_schema={
"type": "object",
"properties": {
"contact": {
"anyOf": [
{"type": "string", "format": "email"},
{"type": "string", "pattern": "^\\+?[0-9]{10,15}$"},
],
"description": "Email or phone number",
}
},
},
)
result = convert_tooldef_to_openai_tool(tool)
contact_schema = result["function"]["parameters"]["properties"]["contact"]
assert "anyOf" in contact_schema
assert len(contact_schema["anyOf"]) == 2
assert contact_schema["anyOf"][0]["format"] == "email"
assert "pattern" in contact_schema["anyOf"][1]
def test_nested_objects_conversion(self):
"""Test conversion of deeply nested objects."""
tool = ToolDefinition(
tool_name="nested_data",
input_schema={
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"profile": {
"type": "object",
"properties": {
"name": {"type": "string"},
"settings": {
"type": "object",
"properties": {"theme": {"type": "string", "enum": ["light", "dark"]}},
},
},
}
},
}
},
},
)
result = convert_tooldef_to_openai_tool(tool)
# Navigate deep structure
user_schema = result["function"]["parameters"]["properties"]["user"]
profile_schema = user_schema["properties"]["profile"]
settings_schema = profile_schema["properties"]["settings"]
theme_schema = settings_schema["properties"]["theme"]
assert theme_schema["enum"] == ["light", "dark"]
def test_array_schemas_with_constraints(self):
"""Test conversion of array schemas with constraints."""
tool = ToolDefinition(
tool_name="list_processor",
input_schema={
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {"id": {"type": "integer"}, "name": {"type": "string"}},
"required": ["id"],
},
"minItems": 1,
"maxItems": 100,
"uniqueItems": True,
}
},
},
)
result = convert_tooldef_to_openai_tool(tool)
items_schema = result["function"]["parameters"]["properties"]["items"]
assert items_schema["type"] == "array"
assert items_schema["minItems"] == 1
assert items_schema["maxItems"] == 100
assert items_schema["uniqueItems"] is True
assert items_schema["items"]["type"] == "object"
class TestOutputSchemaHandling:
"""Test that output_schema is correctly handled (or dropped) for OpenAI."""
def test_output_schema_is_dropped(self):
"""Test that output_schema is NOT included in OpenAI format (API limitation)."""
tool = ToolDefinition(
tool_name="calculator",
description="Perform calculation",
input_schema={"type": "object", "properties": {"x": {"type": "number"}, "y": {"type": "number"}}},
output_schema={"type": "object", "properties": {"result": {"type": "number"}}, "required": ["result"]},
)
result = convert_tooldef_to_openai_tool(tool)
# OpenAI doesn't support output schema
assert "outputSchema" not in result["function"]
assert "responseSchema" not in result["function"]
assert "output_schema" not in result["function"]
# But input schema should be present
assert "parameters" in result["function"]
assert result["function"]["parameters"] == tool.input_schema
def test_only_output_schema_no_input(self):
"""Test tool with only output_schema (unusual but valid)."""
tool = ToolDefinition(
tool_name="no_input_tool",
description="Tool with no inputs",
output_schema={"type": "object", "properties": {"timestamp": {"type": "string"}}},
)
result = convert_tooldef_to_openai_tool(tool)
# No parameters should be set if input_schema is None
# (or we might set an empty object schema - implementation detail)
assert "outputSchema" not in result["function"]
class TestEdgeCases:
"""Test edge cases and error conditions."""
def test_tool_with_no_schemas(self):
"""Test tool with neither input nor output schema."""
tool = ToolDefinition(tool_name="schemaless_tool", description="Tool without schemas")
result = convert_tooldef_to_openai_tool(tool)
assert result["function"]["name"] == "schemaless_tool"
assert result["function"]["description"] == "Tool without schemas"
# Implementation detail: might have no parameters or empty object
def test_empty_input_schema(self):
"""Test tool with empty object schema."""
tool = ToolDefinition(tool_name="no_params", input_schema={"type": "object", "properties": {}})
result = convert_tooldef_to_openai_tool(tool)
assert result["function"]["parameters"]["type"] == "object"
assert result["function"]["parameters"]["properties"] == {}
def test_schema_with_additional_properties(self):
"""Test that additionalProperties is preserved."""
tool = ToolDefinition(
tool_name="flexible_tool",
input_schema={
"type": "object",
"properties": {"known_field": {"type": "string"}},
"additionalProperties": True,
},
)
result = convert_tooldef_to_openai_tool(tool)
assert result["function"]["parameters"]["additionalProperties"] is True
def test_schema_with_pattern_properties(self):
"""Test that patternProperties is preserved."""
tool = ToolDefinition(
tool_name="pattern_tool",
input_schema={"type": "object", "patternProperties": {"^[a-z]+$": {"type": "string"}}},
)
result = convert_tooldef_to_openai_tool(tool)
assert "patternProperties" in result["function"]["parameters"]
def test_schema_identity(self):
"""Test that converted schema is identical to input (no lossy conversion)."""
original_schema = {
"type": "object",
"properties": {"complex": {"$ref": "#/$defs/Complex"}},
"$defs": {
"Complex": {
"type": "object",
"properties": {"nested": {"anyOf": [{"type": "string"}, {"type": "number"}]}},
}
},
"required": ["complex"],
"additionalProperties": False,
}
tool = ToolDefinition(tool_name="test", input_schema=original_schema)
result = convert_tooldef_to_openai_tool(tool)
# Converted parameters should be EXACTLY the same as input
assert result["function"]["parameters"] == original_schema
class TestConversionConsistency:
"""Test consistency across multiple conversions."""
def test_multiple_tools_with_shared_defs(self):
"""Test converting multiple tools that could share definitions."""
tool1 = ToolDefinition(
tool_name="tool1",
input_schema={
"type": "object",
"properties": {"data": {"$ref": "#/$defs/Data"}},
"$defs": {"Data": {"type": "object", "properties": {"x": {"type": "number"}}}},
},
)
tool2 = ToolDefinition(
tool_name="tool2",
input_schema={
"type": "object",
"properties": {"info": {"$ref": "#/$defs/Data"}},
"$defs": {"Data": {"type": "object", "properties": {"y": {"type": "string"}}}},
},
)
result1 = convert_tooldef_to_openai_tool(tool1)
result2 = convert_tooldef_to_openai_tool(tool2)
# Each tool maintains its own $defs independently
assert result1["function"]["parameters"]["$defs"]["Data"]["properties"]["x"]["type"] == "number"
assert result2["function"]["parameters"]["$defs"]["Data"]["properties"]["y"]["type"] == "string"
def test_conversion_is_pure(self):
"""Test that conversion doesn't modify the original tool."""
original_schema = {
"type": "object",
"properties": {"x": {"type": "string"}},
"$defs": {"T": {"type": "number"}},
}
tool = ToolDefinition(tool_name="test", input_schema=original_schema.copy())
# Convert
convert_tooldef_to_openai_tool(tool)
# Original tool should be unchanged
assert tool.input_schema == original_schema
assert "$defs" in tool.input_schema

View file

@ -26,13 +26,15 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, Mi
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
EMBEDDING_DIMENSION = 384
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
def vector_provider(request):
return request.param
@ -448,6 +450,71 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
await adapter.shutdown()
@pytest.fixture(scope="session")
def weaviate_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
return db_path
@pytest.fixture
async def weaviate_vec_index(weaviate_vec_db_path):
import pytest_socket
import weaviate
pytest_socket.enable_socket()
client = weaviate.connect_to_embedded(
hostname="localhost",
port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
)
index = WeaviateIndex(client=client, collection_name="Testcollection")
await index.initialize()
yield index
await index.delete()
client.close()
@pytest.fixture
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
import pytest_socket
import weaviate
pytest_socket.enable_socket()
client = weaviate.connect_to_embedded(
hostname="localhost",
port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
)
config = WeaviateVectorIOConfig(
weaviate_cluster_url="localhost:8080",
weaviate_api_key=None,
kvstore=SqliteKVStoreConfig(),
)
adapter = WeaviateVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
client.close()
@pytest.fixture
def vector_io_adapter(vector_provider, request):
vector_provider_dict = {
@ -457,6 +524,7 @@ def vector_io_adapter(vector_provider, request):
"chroma": "chroma_vec_adapter",
"qdrant": "qdrant_vec_adapter",
"pgvector": "pgvector_vec_adapter",
"weaviate": "weaviate_vec_adapter",
}
return request.getfixturevalue(vector_provider_dict[vector_provider])

View file

@ -0,0 +1,297 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Unit tests for JSON Schema-based tool definitions.
Tests the new input_schema and output_schema fields.
"""
from pydantic import ValidationError
from llama_stack.apis.tools import ToolDef
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
class TestToolDefValidation:
"""Test ToolDef validation with JSON Schema."""
def test_simple_input_schema(self):
"""Test ToolDef with simple input schema."""
tool = ToolDef(
name="get_weather",
description="Get weather information",
input_schema={
"type": "object",
"properties": {"location": {"type": "string", "description": "City name"}},
"required": ["location"],
},
)
assert tool.name == "get_weather"
assert tool.input_schema["type"] == "object"
assert "location" in tool.input_schema["properties"]
assert tool.output_schema is None
def test_input_and_output_schema(self):
"""Test ToolDef with both input and output schemas."""
tool = ToolDef(
name="calculate",
description="Perform calculation",
input_schema={
"type": "object",
"properties": {"x": {"type": "number"}, "y": {"type": "number"}},
"required": ["x", "y"],
},
output_schema={"type": "object", "properties": {"result": {"type": "number"}}, "required": ["result"]},
)
assert tool.input_schema is not None
assert tool.output_schema is not None
assert "result" in tool.output_schema["properties"]
def test_schema_with_refs_and_defs(self):
"""Test that $ref and $defs are preserved in schemas."""
tool = ToolDef(
name="book_flight",
description="Book a flight",
input_schema={
"type": "object",
"properties": {
"flight": {"$ref": "#/$defs/FlightInfo"},
"passengers": {"type": "array", "items": {"$ref": "#/$defs/Passenger"}},
},
"$defs": {
"FlightInfo": {
"type": "object",
"properties": {"from": {"type": "string"}, "to": {"type": "string"}},
},
"Passenger": {
"type": "object",
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
},
},
},
)
# Verify $defs are preserved
assert "$defs" in tool.input_schema
assert "FlightInfo" in tool.input_schema["$defs"]
assert "Passenger" in tool.input_schema["$defs"]
# Verify $ref are preserved
assert tool.input_schema["properties"]["flight"]["$ref"] == "#/$defs/FlightInfo"
assert tool.input_schema["properties"]["passengers"]["items"]["$ref"] == "#/$defs/Passenger"
def test_output_schema_with_refs(self):
"""Test that output_schema also supports $ref and $defs."""
tool = ToolDef(
name="search",
description="Search for items",
input_schema={"type": "object", "properties": {"query": {"type": "string"}}},
output_schema={
"type": "object",
"properties": {"results": {"type": "array", "items": {"$ref": "#/$defs/SearchResult"}}},
"$defs": {
"SearchResult": {
"type": "object",
"properties": {"title": {"type": "string"}, "score": {"type": "number"}},
}
},
},
)
assert "$defs" in tool.output_schema
assert "SearchResult" in tool.output_schema["$defs"]
def test_complex_json_schema_features(self):
"""Test various JSON Schema features are preserved."""
tool = ToolDef(
name="complex_tool",
description="Tool with complex schema",
input_schema={
"type": "object",
"properties": {
# anyOf
"contact": {
"anyOf": [
{"type": "string", "format": "email"},
{"type": "string", "pattern": "^\\+?[0-9]{10,15}$"},
]
},
# enum
"status": {"type": "string", "enum": ["pending", "approved", "rejected"]},
# nested objects
"address": {
"type": "object",
"properties": {
"street": {"type": "string"},
"city": {"type": "string"},
"zipcode": {"type": "string", "pattern": "^[0-9]{5}$"},
},
"required": ["street", "city"],
},
# array with constraints
"tags": {
"type": "array",
"items": {"type": "string"},
"minItems": 1,
"maxItems": 10,
"uniqueItems": True,
},
},
},
)
# Verify anyOf
assert "anyOf" in tool.input_schema["properties"]["contact"]
# Verify enum
assert tool.input_schema["properties"]["status"]["enum"] == ["pending", "approved", "rejected"]
# Verify nested object
assert tool.input_schema["properties"]["address"]["type"] == "object"
assert "zipcode" in tool.input_schema["properties"]["address"]["properties"]
# Verify array constraints
tags_schema = tool.input_schema["properties"]["tags"]
assert tags_schema["minItems"] == 1
assert tags_schema["maxItems"] == 10
assert tags_schema["uniqueItems"] is True
def test_invalid_json_schema_raises_error(self):
"""Test that invalid JSON Schema raises validation error."""
# TODO: This test will pass once we add schema validation
# For now, Pydantic accepts any dict, so this is a placeholder
# This should eventually raise an error due to invalid schema
try:
ToolDef(
name="bad_tool",
input_schema={
"type": "invalid_type", # Not a valid JSON Schema type
"properties": "not_an_object", # Should be an object
},
)
# For now this passes, but shouldn't after we add validation
except ValidationError:
pass # Expected once validation is added
class TestToolDefinitionValidation:
"""Test ToolDefinition (internal) validation with JSON Schema."""
def test_simple_tool_definition(self):
"""Test ToolDefinition with simple schema."""
tool = ToolDefinition(
tool_name="get_time",
description="Get current time",
input_schema={"type": "object", "properties": {"timezone": {"type": "string"}}},
)
assert tool.tool_name == "get_time"
assert tool.input_schema is not None
def test_builtin_tool_with_schema(self):
"""Test ToolDefinition with BuiltinTool enum."""
tool = ToolDefinition(
tool_name=BuiltinTool.code_interpreter,
description="Run Python code",
input_schema={"type": "object", "properties": {"code": {"type": "string"}}, "required": ["code"]},
output_schema={"type": "object", "properties": {"output": {"type": "string"}, "error": {"type": "string"}}},
)
assert isinstance(tool.tool_name, BuiltinTool)
assert tool.input_schema is not None
assert tool.output_schema is not None
def test_tool_definition_with_refs(self):
"""Test ToolDefinition preserves $ref/$defs."""
tool = ToolDefinition(
tool_name="process_data",
input_schema={
"type": "object",
"properties": {"data": {"$ref": "#/$defs/DataObject"}},
"$defs": {
"DataObject": {
"type": "object",
"properties": {
"id": {"type": "integer"},
"values": {"type": "array", "items": {"type": "number"}},
},
}
},
},
)
assert "$defs" in tool.input_schema
assert tool.input_schema["properties"]["data"]["$ref"] == "#/$defs/DataObject"
class TestSchemaEquivalence:
"""Test that schemas remain unchanged through serialization."""
def test_schema_roundtrip(self):
"""Test that schemas survive model_dump/model_validate roundtrip."""
original = ToolDef(
name="test",
input_schema={
"type": "object",
"properties": {"x": {"$ref": "#/$defs/X"}},
"$defs": {"X": {"type": "string"}},
},
)
# Serialize and deserialize
dumped = original.model_dump()
restored = ToolDef(**dumped)
# Schemas should be identical
assert restored.input_schema == original.input_schema
assert "$defs" in restored.input_schema
assert restored.input_schema["properties"]["x"]["$ref"] == "#/$defs/X"
def test_json_serialization(self):
"""Test JSON serialization preserves schema."""
import json
tool = ToolDef(
name="test",
input_schema={
"type": "object",
"properties": {"a": {"type": "string"}},
"$defs": {"T": {"type": "number"}},
},
output_schema={"type": "object", "properties": {"b": {"$ref": "#/$defs/T"}}},
)
# Serialize to JSON and back
json_str = tool.model_dump_json()
parsed = json.loads(json_str)
restored = ToolDef(**parsed)
assert restored.input_schema == tool.input_schema
assert restored.output_schema == tool.output_schema
assert "$defs" in restored.input_schema
class TestBackwardsCompatibility:
"""Test handling of legacy code patterns."""
def test_none_schemas(self):
"""Test tools with no schemas (legacy case)."""
tool = ToolDef(name="legacy_tool", description="Tool without schemas", input_schema=None, output_schema=None)
assert tool.input_schema is None
assert tool.output_schema is None
def test_metadata_preserved(self):
"""Test that metadata field still works."""
tool = ToolDef(
name="test", input_schema={"type": "object"}, metadata={"endpoint": "http://example.com", "version": "1.0"}
)
assert tool.metadata["endpoint"] == "http://example.com"
assert tool.metadata["version"] == "1.0"

View file

@ -14,6 +14,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseObject,
)
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@ -44,6 +45,11 @@ def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInp
)
def create_test_messages(content: str) -> list[OpenAIMessageParam]:
"""Helper to create test messages for chat completion."""
return [OpenAIUserMessageParam(content=content)]
async def test_responses_store_pagination_basic():
"""Test basic pagination functionality for responses store."""
with TemporaryDirectory() as tmp_dir:
@ -65,7 +71,8 @@ async def test_responses_store_pagination_basic():
for response_id, timestamp in test_data:
response = create_test_response_object(response_id, timestamp)
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
await store.store_response_object(response, input_list)
messages = create_test_messages(f"Input for {response_id}")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -111,7 +118,8 @@ async def test_responses_store_pagination_ascending():
for response_id, timestamp in test_data:
response = create_test_response_object(response_id, timestamp)
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
await store.store_response_object(response, input_list)
messages = create_test_messages(f"Input for {response_id}")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -149,7 +157,8 @@ async def test_responses_store_pagination_with_model_filter():
for response_id, timestamp, model in test_data:
response = create_test_response_object(response_id, timestamp, model)
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
await store.store_response_object(response, input_list)
messages = create_test_messages(f"Input for {response_id}")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -199,7 +208,8 @@ async def test_responses_store_pagination_no_limit():
for response_id, timestamp in test_data:
response = create_test_response_object(response_id, timestamp)
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
await store.store_response_object(response, input_list)
messages = create_test_messages(f"Input for {response_id}")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -222,7 +232,8 @@ async def test_responses_store_get_response_object():
# Store a test response
response = create_test_response_object("test-resp", int(time.time()))
input_list = [create_test_response_input("Test input content", "input-test-resp")]
await store.store_response_object(response, input_list)
messages = create_test_messages("Test input content")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -255,7 +266,8 @@ async def test_responses_store_input_items_pagination():
create_test_response_input("Fourth input", "input-4"),
create_test_response_input("Fifth input", "input-5"),
]
await store.store_response_object(response, input_list)
messages = create_test_messages("First input")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -335,7 +347,8 @@ async def test_responses_store_input_items_before_pagination():
create_test_response_input("Fourth input", "before-4"),
create_test_response_input("Fifth input", "before-5"),
]
await store.store_response_object(response, input_list)
messages = create_test_messages("First input")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()

View file

@ -368,6 +368,32 @@ async def test_where_operator_gt_and_update_delete():
assert {r["id"] for r in rows_after} == {1, 3}
async def test_batch_insert():
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
await store.create_table(
"batch_test",
{
"id": ColumnType.INTEGER,
"name": ColumnType.STRING,
"value": ColumnType.INTEGER,
},
)
batch_data = [
{"id": 1, "name": "first", "value": 10},
{"id": 2, "name": "second", "value": 20},
{"id": 3, "name": "third", "value": 30},
]
await store.insert("batch_test", batch_data)
result = await store.fetch_all("batch_test", order_by=[("id", "asc")])
assert result.data == batch_data
async def test_where_operator_edge_cases():
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"