Merge branch 'main' into feat/gunicorn-production-server

Roy Belio 2025-11-24 12:08:57 +02:00, committed by GitHub
commit 893d49c59e
2086 changed files with 133277 additions and 643859 deletions


@@ -5,18 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationCreateRequest,
ConversationItem,
ConversationItemList,
)
def test_conversation_create_request_defaults():
request = ConversationCreateRequest()
assert request.items == []
assert request.metadata == {}
from llama_stack_api import Conversation, ConversationItem, ConversationItemList
def test_conversation_model_defaults():


@@ -12,10 +12,6 @@ from openai.types.conversations.conversation import Conversation as OpenAIConver
from openai.types.conversations.conversation_item import ConversationItem as OpenAIConversationItem
from pydantic import TypeAdapter
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentText,
OpenAIResponseMessage,
)
from llama_stack.core.conversations.conversations import (
ConversationServiceConfig,
ConversationServiceImpl,
@@ -27,7 +23,8 @@ from llama_stack.core.storage.datatypes import (
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack_api import OpenAIResponseInputMessageContentText, OpenAIResponseMessage
@pytest.fixture
@@ -41,6 +38,9 @@ async def service():
},
stores=ServerStoresConfig(
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
metadata=None,
inference=None,
prompts=None,
),
)
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
@@ -145,6 +145,9 @@ async def test_policy_configuration():
},
stores=ServerStoresConfig(
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
metadata=None,
inference=None,
prompts=None,
),
)
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})


@@ -6,10 +6,9 @@
from unittest.mock import AsyncMock
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import ListShieldsResponse, Shield
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.core.routers.safety import SafetyRouter
from llama_stack_api import ListShieldsResponse, ModerationObject, ModerationObjectResults, Shield
async def test_run_moderation_uses_default_shield_when_model_missing():


@@ -8,8 +8,13 @@ from unittest.mock import AsyncMock, Mock
import pytest
from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody
from llama_stack.core.routers.vector_io import VectorIORouter
from llama_stack_api import (
ModelNotFoundError,
ModelType,
ModelTypeError,
OpenAICreateVectorStoreRequestWithExtraBody,
)
async def test_single_provider_auto_selection():
@@ -21,6 +26,7 @@ async def test_single_provider_auto_selection():
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
]
)
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=Mock(model_type=ModelType.embedding))
mock_routing_table.register_vector_store = AsyncMock(
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
)
@@ -48,6 +54,7 @@ async def test_create_vector_stores_multiple_providers_missing_provider_id_error
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
]
)
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=Mock(model_type=ModelType.embedding))
router = VectorIORouter(mock_routing_table)
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
@@ -55,3 +62,94 @@ async def test_create_vector_stores_multiple_providers_missing_provider_id_error
with pytest.raises(ValueError, match="Multiple vector_io providers available"):
await router.openai_create_vector_store(request)
async def test_update_vector_store_provider_id_change_fails():
"""Test that updating a vector store with a different provider_id fails with clear error."""
mock_routing_table = Mock()
# Mock an existing vector store with provider_id "faiss"
mock_existing_store = Mock()
mock_existing_store.provider_id = "inline::faiss"
mock_existing_store.identifier = "vs_123"
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
mock_routing_table.get_provider_impl = AsyncMock(
return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
)
router = VectorIORouter(mock_routing_table)
# Try to update with different provider_id in metadata - this should fail
with pytest.raises(ValueError, match="provider_id cannot be changed after vector store creation"):
await router.openai_update_vector_store(
vector_store_id="vs_123",
name="updated_name",
metadata={"provider_id": "inline::sqlite"}, # Different provider_id
)
# Verify the existing store was looked up to check provider_id
mock_routing_table.get_object_by_identifier.assert_called_once_with("vector_store", "vs_123")
# Provider should not be called since validation failed
mock_routing_table.get_provider_impl.assert_not_called()
async def test_update_vector_store_same_provider_id_succeeds():
"""Test that updating a vector store with the same provider_id succeeds."""
mock_routing_table = Mock()
# Mock an existing vector store with provider_id "faiss"
mock_existing_store = Mock()
mock_existing_store.provider_id = "inline::faiss"
mock_existing_store.identifier = "vs_123"
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
mock_routing_table.get_provider_impl = AsyncMock(
return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
)
router = VectorIORouter(mock_routing_table)
# Update with same provider_id should succeed
await router.openai_update_vector_store(
vector_store_id="vs_123",
name="updated_name",
metadata={"provider_id": "inline::faiss"}, # Same provider_id
)
# Verify the provider update method was called
mock_routing_table.get_provider_impl.assert_called_once_with("vs_123")
provider = await mock_routing_table.get_provider_impl("vs_123")
provider.openai_update_vector_store.assert_called_once_with(
vector_store_id="vs_123", name="updated_name", expires_after=None, metadata={"provider_id": "inline::faiss"}
)
async def test_create_vector_store_with_unknown_embedding_model_raises_error():
"""Test that creating a vector store with an unknown embedding model raises
FoundError."""
mock_routing_table = Mock(impls_by_provider_id={"provider": "mock"})
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=None)
router = VectorIORouter(mock_routing_table)
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
{"embedding_model": "unknown-model", "embedding_dimension": 384}
)
with pytest.raises(ModelNotFoundError, match="Model 'unknown-model' not found"):
await router.openai_create_vector_store(request)
async def test_create_vector_store_with_wrong_model_type_raises_error():
"""Test that creating a vector store with a non-embedding model raises ModelTypeError."""
mock_routing_table = Mock(impls_by_provider_id={"provider": "mock"})
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=Mock(model_type=ModelType.llm))
router = VectorIORouter(mock_routing_table)
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
{"embedding_model": "text-model", "embedding_dimension": 384}
)
with pytest.raises(ModelTypeError, match="Model 'text-model' is of type"):
await router.openai_create_vector_store(request)


@@ -10,11 +10,10 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.models import ListModelsResponse, Model, ModelType
from llama_stack.apis.shields import ListShieldsResponse, Shield
from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, StorageConfig, VectorStoresConfig
from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.stack import validate_safety_config, validate_vector_stores_config
from llama_stack.providers.datatypes import Api
from llama_stack.core.storage.datatypes import ServerStoresConfig, StorageConfig
from llama_stack_api import Api, ListModelsResponse, ListShieldsResponse, Model, ModelType, Shield
class TestVectorStoresValidation:
@@ -23,7 +22,15 @@ class TestVectorStoresValidation:
run_config = StackRunConfig(
image_name="test",
providers={},
storage=StorageConfig(backends={}, stores={}),
storage=StorageConfig(
backends={},
stores=ServerStoresConfig(
metadata=None,
inference=None,
conversations=None,
prompts=None,
),
),
vector_stores=VectorStoresConfig(
default_provider_id="faiss",
default_embedding_model=QualifiedModel(
@@ -43,7 +50,15 @@ class TestVectorStoresValidation:
run_config = StackRunConfig(
image_name="test",
providers={},
storage=StorageConfig(backends={}, stores={}),
storage=StorageConfig(
backends={},
stores=ServerStoresConfig(
metadata=None,
inference=None,
conversations=None,
prompts=None,
),
),
vector_stores=VectorStoresConfig(
default_provider_id="faiss",
default_embedding_model=QualifiedModel(


@@ -10,14 +10,6 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup
from llama_stack.core.datatypes import RegistryEntrySource
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable
@@ -25,6 +17,21 @@ from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.core.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.core.routing_tables.shields import ShieldsRoutingTable
from llama_stack.core.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack_api import (
URL,
Api,
Dataset,
DatasetPurpose,
ListToolDefsResponse,
Model,
ModelNotFoundError,
ModelType,
NumberType,
Shield,
ToolDef,
ToolGroup,
URIDataSource,
)
class Impl:
@@ -130,7 +137,7 @@ class ToolGroupsImpl(Impl):
async def unregister_toolgroup(self, toolgroup_id: str):
return toolgroup_id
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint, authorization=None):
return ListToolDefsResponse(
data=[
ToolDef(


@@ -11,8 +11,15 @@ from unittest.mock import patch
import pytest
from openai import AsyncOpenAI
from llama_stack.testing.api_recorder import (
APIRecordingMode,
ResponseStorage,
api_recording,
normalize_inference_request,
)
# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
from llama_stack_api import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
@@ -20,12 +27,6 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.testing.api_recorder import (
APIRecordingMode,
ResponseStorage,
api_recording,
normalize_inference_request,
)
@pytest.fixture


@@ -22,7 +22,7 @@ from llama_stack.core.storage.datatypes import (
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.datatypes import ProviderSpec
from llama_stack_api import ProviderSpec
class SampleConfig(BaseModel):
@@ -312,7 +312,7 @@ pip_packages:
"""Test loading an external provider from a module (success path)."""
from types import SimpleNamespace
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack_api import Api, ProviderSpec
# Simulate a provider module with get_provider_spec
fake_spec = ProviderSpec(
@@ -396,7 +396,7 @@ pip_packages:
def test_external_provider_from_module_building(self, mock_providers):
"""Test loading an external provider from a module during build (building=True, partial spec)."""
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api
# No importlib patch needed; the module should not be imported when `config` is a BuildConfig or DistributionSpec
build_config = BuildConfig(
@@ -457,7 +457,7 @@ class TestGetExternalProvidersFromModule:
from types import SimpleNamespace
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
from llama_stack_api import ProviderSpec
fake_spec = ProviderSpec(
api=Api.inference,
@@ -594,7 +594,7 @@ class TestGetExternalProvidersFromModule:
from types import SimpleNamespace
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
from llama_stack_api import ProviderSpec
spec1 = ProviderSpec(
api=Api.inference,
@@ -642,7 +642,7 @@ class TestGetExternalProvidersFromModule:
from types import SimpleNamespace
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
from llama_stack_api import ProviderSpec
spec1 = ProviderSpec(
api=Api.inference,
@@ -690,7 +690,7 @@ class TestGetExternalProvidersFromModule:
from types import SimpleNamespace
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
from llama_stack_api import ProviderSpec
# Module returns both inline and remote variants
spec1 = ProviderSpec(
@@ -829,7 +829,7 @@ class TestGetExternalProvidersFromModule:
from types import SimpleNamespace
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
from llama_stack_api import ProviderSpec
inference_spec = ProviderSpec(
api=Api.inference,


@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Tests for the llama stack list command."""
import argparse
from unittest.mock import MagicMock, patch
import pytest
from llama_stack.cli.stack.list_stacks import StackListBuilds
@pytest.fixture
def list_stacks_command():
"""Create a StackListBuilds instance for testing."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
return StackListBuilds(subparsers)
@pytest.fixture
def mock_distribs_base_dir(tmp_path):
"""Create a mock DISTRIBS_BASE_DIR with some custom distributions."""
custom_dir = tmp_path / "distributions"
custom_dir.mkdir(parents=True, exist_ok=True)
# Create a custom distribution
starter_custom = custom_dir / "starter"
starter_custom.mkdir()
(starter_custom / "starter-build.yaml").write_text("# build config")
(starter_custom / "starter-run.yaml").write_text("# run config")
return custom_dir
@pytest.fixture
def mock_distro_dir(tmp_path):
"""Create a mock distributions directory with built-in distributions."""
distro_dir = tmp_path / "src" / "llama_stack" / "distributions"
distro_dir.mkdir(parents=True, exist_ok=True)
# Create some built-in distributions
for distro_name in ["starter", "nvidia", "dell"]:
distro_path = distro_dir / distro_name
distro_path.mkdir()
(distro_path / "build.yaml").write_text("# build config")
(distro_path / "run.yaml").write_text("# run config")
return distro_dir
def create_path_mock(builtin_dist_dir):
"""Create a properly mocked Path object that returns builtin_dist_dir for the distributions path."""
mock_parent_parent_parent = MagicMock()
mock_parent_parent_parent.__truediv__ = (
lambda self, other: builtin_dist_dir if other == "distributions" else MagicMock()
)
mock_path = MagicMock()
mock_path.parent.parent.parent = mock_parent_parent_parent
return mock_path
class TestStackList:
"""Test suite for llama stack list command."""
def test_builtin_distros_shown_without_running(self, list_stacks_command, mock_distro_dir, tmp_path):
"""Test that built-in distributions are shown even before running them."""
mock_path = create_path_mock(mock_distro_dir)
# Mock DISTRIBS_BASE_DIR to be a non-existent directory (no custom distributions)
with patch("llama_stack.cli.stack.list_stacks.DISTRIBS_BASE_DIR", tmp_path / "nonexistent"):
with patch("llama_stack.cli.stack.list_stacks.Path") as mock_path_class:
mock_path_class.return_value = mock_path
distributions = list_stacks_command._get_distribution_dirs()
# Verify built-in distributions are found
assert len(distributions) > 0, "Should find built-in distributions"
assert all(source_type == "built-in" for _, source_type in distributions.values()), (
"All should be built-in"
)
# Check specific distributions we created
assert "starter" in distributions
assert "nvidia" in distributions
assert "dell" in distributions
def test_custom_distribution_overrides_builtin(self, list_stacks_command, mock_distro_dir, mock_distribs_base_dir):
"""Test that custom distributions override built-in ones with the same name."""
mock_path = create_path_mock(mock_distro_dir)
with patch("llama_stack.cli.stack.list_stacks.DISTRIBS_BASE_DIR", mock_distribs_base_dir):
with patch("llama_stack.cli.stack.list_stacks.Path") as mock_path_class:
mock_path_class.return_value = mock_path
distributions = list_stacks_command._get_distribution_dirs()
# "starter" should exist and be marked as "custom" (not "built-in")
# because the custom version overrides the built-in one
assert "starter" in distributions
_, source_type = distributions["starter"]
assert source_type == "custom", "Custom distribution should override built-in"
def test_hidden_directories_ignored(self, list_stacks_command, mock_distro_dir, tmp_path):
"""Test that hidden directories (starting with .) are ignored."""
# Add a hidden directory
hidden_dir = mock_distro_dir / ".hidden"
hidden_dir.mkdir()
(hidden_dir / "build.yaml").write_text("# build")
# Add a __pycache__ directory
pycache_dir = mock_distro_dir / "__pycache__"
pycache_dir.mkdir()
mock_path = create_path_mock(mock_distro_dir)
with patch("llama_stack.cli.stack.list_stacks.DISTRIBS_BASE_DIR", tmp_path / "nonexistent"):
with patch("llama_stack.cli.stack.list_stacks.Path") as mock_path_class:
mock_path_class.return_value = mock_path
distributions = list_stacks_command._get_distribution_dirs()
assert ".hidden" not in distributions
assert "__pycache__" not in distributions


@@ -7,16 +7,14 @@
import pytest
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.providers.inline.files.localfs import (
LocalfsFilesImpl,
LocalfsFilesImplConfig,
)
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack_api import OpenAIFilePurpose, Order, ResourceNotFoundError
class MockUploadFile:


@@ -6,9 +6,9 @@
import pytest
from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
from llama_stack.core.storage.kvstore.sqlite import SqliteKVStoreImpl
from llama_stack.core.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture(scope="function")


@@ -1,303 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionMessage,
StopReason,
SystemMessage,
SystemMessageBehavior,
ToolCall,
ToolConfig,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
chat_completion_request_to_prompt,
interleaved_content_as_str,
)
MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"
async def test_system_default():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
async def test_system_builtin_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
async def test_system_custom_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
)
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 3
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
async def test_system_custom_and_builtin():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 3
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
async def test_completion_message_encoding():
request = ChatCompletionRequest(
model=MODEL3_2,
messages=[
UserMessage(content="hello"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
tool_name="custom1",
arguments='{"param1": "value1"}', # arguments must be a JSON string
call_id="123",
)
],
),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
)
prompt = await chat_completion_request_to_prompt(request, request.model)
assert '[custom1(param1="value1")]' in prompt
request.model = MODEL
request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
prompt = await chat_completion_request_to_prompt(request, request.model)
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
async def test_user_provided_system_message():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert messages[-1].content == content
async def test_replace_system_message_behavior_builtin_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools_with_template():
content = "Hello !"
system_prompt = "You are a pirate {{ function_description }}"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
# function description is present in the system prompt
assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content


@@ -18,7 +18,7 @@ from llama_stack.core.storage.datatypes import (
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.utils.kvstore import register_kvstore_backends
from llama_stack.core.storage.kvstore import register_kvstore_backends
@pytest.fixture


@@ -1,347 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import pytest
from llama_stack.apis.agents import Session
from llama_stack.core.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import (
AgentPersistence,
AgentSessionInfo,
)
from llama_stack.providers.utils.kvstore import KVStore
@pytest.fixture
def mock_kvstore():
return AsyncMock(spec=KVStore)
@pytest.fixture
def mock_policy():
return []
@pytest.fixture
def agent_persistence(mock_kvstore, mock_policy):
return AgentPersistence(agent_id="test-agent-123", kvstore=mock_kvstore, policy=mock_policy)
@pytest.fixture
def sample_session():
return AgentSessionInfo(
session_id="session-123",
session_name="Test Session",
started_at=datetime.now(UTC),
owner=User(principal="user-123", attributes=None),
turns=[],
identifier="test-session",
type="session",
)
@pytest.fixture
def sample_session_json(sample_session):
return sample_session.model_dump_json()
class TestAgentPersistenceListSessions:
def setup_mock_kvstore(self, mock_kvstore, session_keys=None, turn_keys=None, invalid_keys=None, custom_data=None):
"""Helper to setup mock kvstore with sessions, turns, and custom/invalid data
Args:
mock_kvstore: The mock KVStore object
session_keys: List of session keys or dict mapping keys to custom session data
turn_keys: List of turn keys or dict mapping keys to custom turn data
invalid_keys: Dict mapping keys to invalid/corrupt data
custom_data: Additional custom data to add to the mock responses
"""
all_keys = []
mock_data = {}
# session keys
if session_keys:
if isinstance(session_keys, dict):
all_keys.extend(session_keys.keys())
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in session_keys.items()})
else:
all_keys.extend(session_keys)
for key in session_keys:
session_id = key.split(":")[-1]
mock_data[key] = json.dumps(
{
"session_id": session_id,
"session_name": f"Session {session_id}",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
)
# turn keys
if turn_keys:
if isinstance(turn_keys, dict):
all_keys.extend(turn_keys.keys())
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in turn_keys.items()})
else:
all_keys.extend(turn_keys)
for key in turn_keys:
parts = key.split(":")
session_id = parts[-2]
turn_id = parts[-1]
mock_data[key] = json.dumps(
{
"turn_id": turn_id,
"session_id": session_id,
"input_messages": [],
"started_at": datetime.now(UTC).isoformat(),
}
)
if invalid_keys:
all_keys.extend(invalid_keys.keys())
mock_data.update(invalid_keys)
if custom_data:
mock_data.update(custom_data)
values_list = list(mock_data.values())
mock_kvstore.values_in_range.return_value = values_list
async def mock_get(key):
return mock_data.get(key)
mock_kvstore.get.side_effect = mock_get
return mock_data
@pytest.mark.parametrize(
"scenario",
[
{
# from this issue: https://github.com/meta-llama/llama-stack/issues/3048
"name": "reported_bug",
"session_keys": ["session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
"turn_keys": [
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad"
],
"expected_sessions": ["1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
},
{
"name": "basic_filtering",
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
"turn_keys": ["session:test-agent-123:session-1:turn-1", "session:test-agent-123:session-1:turn-2"],
"expected_sessions": ["session-1", "session-2"],
},
{
"name": "multiple_turns_per_session",
"session_keys": ["session:test-agent-123:session-456"],
"turn_keys": [
"session:test-agent-123:session-456:turn-789",
"session:test-agent-123:session-456:turn-790",
],
"expected_sessions": ["session-456"],
},
{
"name": "multiple_sessions_with_turns",
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
"turn_keys": [
"session:test-agent-123:session-1:turn-1",
"session:test-agent-123:session-1:turn-2",
"session:test-agent-123:session-2:turn-3",
],
"expected_sessions": ["session-1", "session-2"],
},
],
)
async def test_list_sessions_key_filtering(self, agent_persistence, mock_kvstore, scenario):
self.setup_mock_kvstore(mock_kvstore, session_keys=scenario["session_keys"], turn_keys=scenario["turn_keys"])
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
assert len(result) == len(scenario["expected_sessions"])
session_ids = {s.session_id for s in result}
for expected_id in scenario["expected_sessions"]:
assert expected_id in session_ids
# no errors should be logged
mock_log.error.assert_not_called()
@pytest.mark.parametrize(
"error_scenario",
[
{
"name": "invalid_json",
"valid_keys": ["session:test-agent-123:valid-session"],
"invalid_data": {"session:test-agent-123:invalid-json": "corrupted-json-data{"},
"expected_valid_sessions": ["valid-session"],
"expected_error_count": 1,
},
{
"name": "missing_fields",
"valid_keys": ["session:test-agent-123:valid-session"],
"invalid_data": {
"session:test-agent-123:invalid-schema": json.dumps(
{
"session_id": "invalid-schema",
"session_name": "Missing Fields",
# missing `started_at` and `turns`
}
)
},
"expected_valid_sessions": ["valid-session"],
"expected_error_count": 1,
},
{
"name": "multiple_invalid",
"valid_keys": ["session:test-agent-123:valid-session-1", "session:test-agent-123:valid-session-2"],
"invalid_data": {
"session:test-agent-123:corrupted-json": "not-valid-json{",
"session:test-agent-123:incomplete-data": json.dumps({"incomplete": "data"}),
},
"expected_valid_sessions": ["valid-session-1", "valid-session-2"],
"expected_error_count": 2,
},
],
)
async def test_list_sessions_error_handling(self, agent_persistence, mock_kvstore, error_scenario):
session_keys = {}
for key in error_scenario["valid_keys"]:
session_id = key.split(":")[-1]
session_keys[key] = {
"session_id": session_id,
"session_name": f"Valid {session_id}",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
self.setup_mock_kvstore(mock_kvstore, session_keys=session_keys, invalid_keys=error_scenario["invalid_data"])
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
# only valid sessions should be returned
assert len(result) == len(error_scenario["expected_valid_sessions"])
session_ids = {s.session_id for s in result}
for expected_id in error_scenario["expected_valid_sessions"]:
assert expected_id in session_ids
# error should be logged
assert mock_log.error.call_count > 0
assert mock_log.error.call_count == error_scenario["expected_error_count"]
async def test_list_sessions_empty(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.return_value = []
result = await agent_persistence.list_sessions()
assert result == []
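# The \xff padding builds an end key that sorts after every key sharing the session prefix.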
mock_kvstore.values_in_range.assert_called_once_with(
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
)
async def test_list_sessions_properties(self, agent_persistence, mock_kvstore):
session_data = {
"session_id": "session-123",
"session_name": "Test Session",
"started_at": datetime.now(UTC).isoformat(),
"owner": {"principal": "user-123", "attributes": None},
"turns": [],
}
self.setup_mock_kvstore(mock_kvstore, session_keys={"session:test-agent-123:session-123": session_data})
result = await agent_persistence.list_sessions()
assert len(result) == 1
assert isinstance(result[0], Session)
assert result[0].session_id == "session-123"
assert result[0].session_name == "Test Session"
assert result[0].turns == []
assert hasattr(result[0], "started_at")
async def test_list_sessions_kvstore_exception(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.side_effect = Exception("KVStore error")
with pytest.raises(Exception, match="KVStore error"):
await agent_persistence.list_sessions()
async def test_bug_data_loss_with_real_data(self, agent_persistence, mock_kvstore):
# tests the handling of the issue reported in: https://github.com/meta-llama/llama-stack/issues/3048
session_data = {
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
"session_name": "Test Session",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
turn_data = {
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
"input_messages": [
{"role": "user", "content": "if i had a cluster i would want to call it persistence01", "context": None}
],
"steps": [
{
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
"step_id": "c0f797dd-3d34-4bc5-a8f4-db6af9455132",
"started_at": "2025-08-05T14:31:50.000484Z",
"completed_at": "2025-08-05T14:31:51.303691Z",
"step_type": "inference",
"model_response": {
"role": "assistant",
"content": "OK, I can create a cluster named 'persistence01' for you.",
"stop_reason": "end_of_turn",
"tool_calls": [],
},
}
],
"output_message": {
"role": "assistant",
"content": "OK, I can create a cluster named 'persistence01' for you.",
"stop_reason": "end_of_turn",
"tool_calls": [],
},
"output_attachments": [],
"started_at": "2025-08-05T14:31:49.999950Z",
"completed_at": "2025-08-05T14:31:51.305384Z",
}
mock_data = {
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d": json.dumps(session_data),
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad": json.dumps(
turn_data
),
}
mock_kvstore.values_in_range.return_value = list(mock_data.values())
async def mock_get(key):
return mock_data.get(key)
mock_kvstore.get.side_effect = mock_get
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
assert len(result) == 1
assert result[0].session_id == "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"
# confirm no errors logged
mock_log.error.assert_not_called()
async def test_list_sessions_key_range_construction(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.return_value = []
await agent_persistence.list_sessions()
mock_kvstore.values_in_range.assert_called_once_with(
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
)


@@ -1,196 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.agents import Document
from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text
async def test_get_raw_document_text_supports_text_mime_types():
"""Test that the function accepts text/* mime types."""
document = Document(content="Sample text content", mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Sample text content"
async def test_get_raw_document_text_supports_yaml_mime_type():
"""Test that the function accepts application/yaml mime type."""
yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""
document = Document(content=yaml_content, mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning():
"""Test that the function accepts text/yaml but emits a deprecation warning."""
yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""
document = Document(content=yaml_content, mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
# Check that exactly one warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
assert "application/yaml" in str(w[0].message)
assert "deprecated" in str(w[0].message).lower()
async def test_get_raw_document_text_deprecated_text_yaml_with_url():
"""Test that text/yaml works with URL content and emits warning."""
yaml_content = "name: test\nversion: 1.0"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = yaml_content
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
mock_load.assert_called_once_with("https://example.com/config.yaml")
# Check that deprecation warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item():
"""Test that text/yaml works with TextContentItem and emits warning."""
yaml_content = "key: value\nlist:\n - item1\n - item2"
document = Document(content=TextContentItem(text=yaml_content), mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
# Check that deprecation warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
async def test_get_raw_document_text_supports_json_mime_type():
"""Test that the function accepts application/json mime type."""
json_content = '{"name": "test", "version": "1.0", "items": ["item1", "item2"]}'
document = Document(content=json_content, mime_type="application/json")
result = await get_raw_document_text(document)
assert result == json_content
async def test_get_raw_document_text_with_json_text_content_item():
"""Test that the function handles JSON TextContentItem correctly."""
json_content = '{"key": "value", "nested": {"array": [1, 2, 3]}}'
document = Document(content=TextContentItem(text=json_content), mime_type="application/json")
result = await get_raw_document_text(document)
assert result == json_content
async def test_get_raw_document_text_rejects_unsupported_mime_types():
"""Test that the function rejects unsupported mime types."""
document = Document(
content="Some content",
mime_type="application/pdf", # Not supported
)
with pytest.raises(ValueError, match="Unexpected document mime type: application/pdf"):
await get_raw_document_text(document)
async def test_get_raw_document_text_with_url_content():
"""Test that the function handles URL content correctly."""
mock_response = AsyncMock()
mock_response.text = "Content from URL"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = "Content from URL"
document = Document(content=URL(uri="https://example.com/test.txt"), mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Content from URL"
mock_load.assert_called_once_with("https://example.com/test.txt")
async def test_get_raw_document_text_with_yaml_url():
"""Test that the function handles YAML URLs correctly."""
yaml_content = "name: test\nversion: 1.0"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = yaml_content
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
mock_load.assert_called_once_with("https://example.com/config.yaml")
async def test_get_raw_document_text_with_text_content_item():
"""Test that the function handles TextContentItem correctly."""
document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Text content item"
async def test_get_raw_document_text_with_yaml_text_content_item():
"""Test that the function handles YAML TextContentItem correctly."""
yaml_content = "key: value\nlist:\n - item1\n - item2"
document = Document(content=TextContentItem(text=yaml_content), mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
async def test_get_raw_document_text_rejects_unexpected_content_type():
"""Test that the function rejects unexpected document content types."""
# Create a mock document that bypasses Pydantic validation
mock_document = MagicMock(spec=Document)
mock_document.mime_type = "text/plain"
mock_document.content = 123 # Unexpected content type (not str, URL, or TextContentItem)
with pytest.raises(ValueError, match="Unexpected document content type: <class 'int'>"):
await get_raw_document_text(mock_document)


@@ -1,325 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@pytest.fixture(autouse=True)
def setup_backends(tmp_path):
"""Register KV and SQL store backends for testing."""
from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
kv_path = str(tmp_path / "test_kv.db")
sql_path = str(tmp_path / "test_sql.db")
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=kv_path)})
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=sql_path)})
@pytest.fixture
def mock_apis():
return {
"inference_api": AsyncMock(spec=Inference),
"vector_io_api": AsyncMock(spec=VectorIO),
"safety_api": AsyncMock(spec=Safety),
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
"tool_groups_api": AsyncMock(spec=ToolGroups),
"conversations_api": AsyncMock(spec=Conversations),
}
@pytest.fixture
def config(tmp_path):
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
from llama_stack.providers.inline.agents.meta_reference.config import AgentPersistenceConfig
return MetaReferenceAgentsImplConfig(
persistence=AgentPersistenceConfig(
agent_state=KVStoreReference(
backend="kv_default",
namespace="agents",
),
responses=ResponsesStoreReference(
backend="sql_default",
table_name="responses",
),
)
)
@pytest.fixture
async def agents_impl(config, mock_apis):
impl = MetaReferenceAgentsImpl(
config,
mock_apis["inference_api"],
mock_apis["vector_io_api"],
mock_apis["safety_api"],
mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"],
mock_apis["conversations_api"],
[],
)
await impl.initialize()
yield impl
await impl.shutdown()
@pytest.fixture
def sample_agent_config():
return AgentConfig(
sampling_params={
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 1.0,
},
input_shields=["string"],
output_shields=["string"],
toolgroups=["mcp::my_mcp_server"],
client_tools=[
{
"name": "client_tool",
"description": "Client Tool",
"parameters": [
{
"name": "string",
"parameter_type": "string",
"description": "string",
"required": True,
"default": None,
}
],
"metadata": {
"property1": None,
"property2": None,
},
}
],
tool_choice="auto",
tool_prompt_format="json",
tool_config={
"tool_choice": "auto",
"tool_prompt_format": "json",
"system_message_behavior": "append",
},
max_infer_iters=10,
model="string",
instructions="string",
enable_session_persistence=False,
response_format={
"type": "json_schema",
"json_schema": {
"property1": None,
"property2": None,
},
},
)
async def test_create_agent(agents_impl, sample_agent_config):
response = await agents_impl.create_agent(sample_agent_config)
assert isinstance(response, AgentCreateResponse)
assert response.agent_id is not None
stored_agent = await agents_impl.persistence_store.get(f"agent:{response.agent_id}")
assert stored_agent is not None
agent_info = AgentInfo.model_validate_json(stored_agent)
assert agent_info.model == sample_agent_config.model
assert agent_info.created_at is not None
assert isinstance(agent_info.created_at, datetime)
async def test_get_agent(agents_impl, sample_agent_config):
create_response = await agents_impl.create_agent(sample_agent_config)
agent_id = create_response.agent_id
agent = await agents_impl.get_agent(agent_id)
assert isinstance(agent, Agent)
assert agent.agent_id == agent_id
assert agent.agent_config.model == sample_agent_config.model
assert agent.created_at is not None
assert isinstance(agent.created_at, datetime)
async def test_list_agents(agents_impl, sample_agent_config):
agent1_response = await agents_impl.create_agent(sample_agent_config)
agent2_response = await agents_impl.create_agent(sample_agent_config)
response = await agents_impl.list_agents()
assert isinstance(response, PaginatedResponse)
assert len(response.data) == 2
agent_ids = {agent["agent_id"] for agent in response.data}
assert agent1_response.agent_id in agent_ids
assert agent2_response.agent_id in agent_ids
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting
config = sample_agent_config.model_copy()
config.enable_session_persistence = enable_session_persistence
response = await agents_impl.create_agent(config)
agent_id = response.agent_id
# Create a session
session_response = await agents_impl.create_agent_session(agent_id, "test_session")
assert session_response.session_id is not None
# Verify the session was stored
session = await agents_impl.get_agents_session(session_response.session_id, agent_id)
assert session.session_name == "test_session"
assert session.session_id == session_response.session_id
assert session.started_at is not None
assert session.turns == []
# Delete the session
await agents_impl.delete_agents_session(session_response.session_id, agent_id)
# Verify the session was deleted
with pytest.raises(ValueError):
await agents_impl.get_agents_session(session_response.session_id, agent_id)
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting
config = sample_agent_config.model_copy()
config.enable_session_persistence = enable_session_persistence
response = await agents_impl.create_agent(config)
agent_id = response.agent_id
# Create multiple sessions
session1 = await agents_impl.create_agent_session(agent_id, "session1")
session2 = await agents_impl.create_agent_session(agent_id, "session2")
# List sessions
sessions = await agents_impl.list_agent_sessions(agent_id)
assert len(sessions.data) == 2
session_ids = {s["session_id"] for s in sessions.data}
assert session1.session_id in session_ids
assert session2.session_id in session_ids
# Delete one session
await agents_impl.delete_agents_session(session1.session_id, agent_id)
# Verify the session was deleted
with pytest.raises(ValueError):
await agents_impl.get_agents_session(session1.session_id, agent_id)
# List sessions again
sessions = await agents_impl.list_agent_sessions(agent_id)
assert len(sessions.data) == 1
assert session2.session_id in {s["session_id"] for s in sessions.data}
async def test_delete_agent(agents_impl, sample_agent_config):
# Create an agent
response = await agents_impl.create_agent(sample_agent_config)
agent_id = response.agent_id
# Delete the agent
await agents_impl.delete_agent(agent_id)
# Verify the agent was deleted
with pytest.raises(ValueError):
await agents_impl.get_agent(agent_id)
async def test__initialize_tools(agents_impl, sample_agent_config):
# Mock tool_groups_api.list_tools()
agents_impl.tool_groups_api.list_tools.return_value = ListToolDefsResponse(
data=[
ToolDef(
name="story_maker",
toolgroup_id="mcp::my_mcp_server",
description="Make a story",
input_schema={
"type": "object",
"properties": {
"story_title": {"type": "string", "description": "Title of the story", "title": "Story Title"},
"input_words": {
"type": "array",
"description": "Input words",
"items": {"type": "string"},
"title": "Input Words",
"default": [],
},
},
"required": ["story_title"],
},
)
]
)
create_response = await agents_impl.create_agent(sample_agent_config)
agent_id = create_response.agent_id
# Get an instance of ChatAgent
chat_agent = await agents_impl._get_agent_impl(agent_id)
assert chat_agent is not None
assert isinstance(chat_agent, ChatAgent)
# Initialize tool definitions
await chat_agent._initialize_tools()
assert len(chat_agent.tool_defs) == 2
# Verify the first tool, which is a client tool
first_tool = chat_agent.tool_defs[0]
assert first_tool.tool_name == "client_tool"
assert first_tool.description == "Client Tool"
# Verify the second tool, which is an MCP tool that has an array-type property
second_tool = chat_agent.tool_defs[1]
assert second_tool.tool_name == "story_maker"
assert second_tool.description == "Make a story"
# Verify the input schema
input_schema = second_tool.input_schema
assert input_schema is not None
assert input_schema["type"] == "object"
properties = input_schema["properties"]
assert len(properties) == 2
# Verify a string property
story_title = properties["story_title"]
assert story_title["type"] == "string"
assert story_title["description"] == "Title of the story"
assert story_title["title"] == "Story Title"
# Verify an array property
input_words = properties["input_words"]
assert input_words["type"] == "array"
assert input_words["description"] == "Input words"
assert input_words["items"]["type"] == "string"
assert input_words["title"] == "Input Words"
assert input_words["default"] == []
# Verify required fields
assert input_schema["required"] == ["story_title"]


@@ -8,7 +8,7 @@ import os
import yaml
from llama_stack.apis.inference import (
from llama_stack_api.inference import (
OpenAIChatCompletion,
)


@@ -15,23 +15,25 @@ from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction,
)
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
WebSearchToolTypes,
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.apis.inference import (
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from llama_stack_api import (
OpenAIChatCompletionContentPartImageParam,
OpenAIFile,
OpenAIFileObject,
OpenAISystemMessageParam,
Prompt,
)
from llama_stack_api.agents import Order
from llama_stack_api.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionRequestWithExtraBody,
@ -41,17 +43,25 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatJSONSchema,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
from llama_stack_api.openai_responses import (
ListOpenAIResponseInputItem,
OpenAIResponseInputMessageContentFile,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponsePrompt,
OpenAIResponseText,
OpenAIResponseTextFormat,
WebSearchToolTypes,
)
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack_api.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
@ -98,6 +108,19 @@ def mock_safety_api():
return safety_api
@pytest.fixture
def mock_prompts_api():
prompts_api = AsyncMock()
return prompts_api
@pytest.fixture
def mock_files_api():
"""Mock files API for testing."""
files_api = AsyncMock()
return files_api
@pytest.fixture
def openai_responses_impl(
mock_inference_api,
@ -107,6 +130,8 @@ def openai_responses_impl(
mock_vector_io_api,
mock_safety_api,
mock_conversations_api,
mock_prompts_api,
mock_files_api,
):
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
@ -116,6 +141,8 @@ def openai_responses_impl(
vector_io_api=mock_vector_io_api,
safety_api=mock_safety_api,
conversations_api=mock_conversations_api,
prompts_api=mock_prompts_api,
files_api=mock_files_api,
)
@ -499,7 +526,7 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api, mock_files_api):
"""Test creating an OpenAI response with multiple messages."""
# Setup
input_messages = [
@ -710,7 +737,7 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
async def test_create_openai_response_with_instructions_and_multiple_messages(
openai_responses_impl, mock_inference_api
openai_responses_impl, mock_inference_api, mock_files_api
):
# Setup
input_messages = [
@ -1242,3 +1269,489 @@ async def test_create_openai_response_with_output_types_as_input(
assert stored_with_outputs.input == input_with_output_types
assert len(stored_with_outputs.input) == 3
async def test_create_openai_response_with_prompt(openai_responses_impl, mock_inference_api, mock_prompts_api):
"""Test creating an OpenAI response with a prompt."""
input_text = "What is the capital of Ireland?"
model = "meta-llama/Llama-3.1-8B-Instruct"
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="You are a helpful {{ area_name }} assistant at {{ company_name }}. Always provide accurate information.",
prompt_id=prompt_id,
version=1,
variables=["area_name", "company_name"],
is_default=True,
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"area_name": OpenAIResponseInputMessageContentText(text="geography"),
"company_name": OpenAIResponseInputMessageContentText(text="Dummy Company"),
},
)
mock_prompts_api.get_prompt.return_value = prompt
mock_inference_api.openai_chat_completion.return_value = fake_stream()
result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
prompt=openai_response_prompt,
)
mock_prompts_api.get_prompt.assert_called_with(prompt_id, 1)
mock_inference_api.openai_chat_completion.assert_called()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.args[0].messages
assert len(sent_messages) == 2
system_messages = [msg for msg in sent_messages if msg.role == "system"]
assert len(system_messages) == 1
assert (
system_messages[0].content
== "You are a helpful geography assistant at Dummy Company. Always provide accurate information."
)
user_messages = [msg for msg in sent_messages if msg.role == "user"]
assert len(user_messages) == 1
assert user_messages[0].content == input_text
assert result.model == model
assert result.status == "completed"
assert isinstance(result.prompt, OpenAIResponsePrompt)
assert result.prompt.id == prompt_id
assert result.prompt.variables == openai_response_prompt.variables
assert result.prompt.version == "1"
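# Flow sketch implied by the assertions above (an assumption, not the
# implementation): the stored template is rendered with the supplied variables
# and prepended as a system message ahead of the caller's user input.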
async def test_prepend_prompt_successful_without_variables(openai_responses_impl, mock_prompts_api, mock_inference_api):
"""Test prepend_prompt function without variables."""
input_text = "What is the capital of Ireland?"
model = "meta-llama/Llama-3.1-8B-Instruct"
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="You are a helpful assistant. Always provide accurate information.",
prompt_id=prompt_id,
version=1,
variables=[],
is_default=True,
)
openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1")
mock_prompts_api.get_prompt.return_value = prompt
mock_inference_api.openai_chat_completion.return_value = fake_stream()
await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
prompt=openai_response_prompt,
)
mock_prompts_api.get_prompt.assert_called_with(prompt_id, 1)
mock_inference_api.openai_chat_completion.assert_called()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.args[0].messages
assert len(sent_messages) == 2
system_messages = [msg for msg in sent_messages if msg.role == "system"]
assert system_messages[0].content == "You are a helpful assistant. Always provide accurate information."
async def test_prepend_prompt_invalid_variable(openai_responses_impl, mock_prompts_api):
"""Test error handling in prepend_prompt function when prompt parameters contain invalid variables."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="You are a {{ role }} assistant.",
prompt_id=prompt_id,
version=1,
variables=["role"], # Only "role" is valid
is_default=True,
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"role": OpenAIResponseInputMessageContentText(text="helpful"),
"company": OpenAIResponseInputMessageContentText(
text="Dummy Company"
), # company is not in prompt.variables
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="Test prompt")]
# Execute - should raise ValueError for invalid variable
with pytest.raises(ValueError, match="Variable company not found in prompt"):
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
# Verify
mock_prompts_api.get_prompt.assert_called_once_with(prompt_id, 1)
async def test_prepend_prompt_not_found(openai_responses_impl, mock_prompts_api):
"""Test prepend_prompt function when prompt is not found."""
prompt_id = "pmpt_nonexistent"
openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1")
mock_prompts_api.get_prompt.return_value = None # Prompt not found
# Initial messages
messages = [OpenAIUserMessageParam(content="Test prompt")]
initial_length = len(messages)
# Execute
result = await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
# Verify
mock_prompts_api.get_prompt.assert_called_once_with(prompt_id, 1)
# Should return None when prompt not found
assert result is None
# Messages should not be modified
assert len(messages) == initial_length
assert messages[0].content == "Test prompt"
async def test_prepend_prompt_variable_substitution(openai_responses_impl, mock_prompts_api):
"""Test complex variable substitution with multiple occurrences and special characters in prepend_prompt function."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
# Support all whitespace variations: {{name}}, {{ name }}, {{ name}}, {{name }}, etc.
prompt = Prompt(
prompt="Hello {{name}}! You are working at {{ company}}. Your role is {{role}} at {{company}}. Remember, {{ name }}, to be {{ tone }}.",
prompt_id=prompt_id,
version=1,
variables=["name", "company", "role", "tone"],
is_default=True,
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"name": OpenAIResponseInputMessageContentText(text="Alice"),
"company": OpenAIResponseInputMessageContentText(text="Dummy Company"),
"role": OpenAIResponseInputMessageContentText(text="AI Assistant"),
"tone": OpenAIResponseInputMessageContentText(text="professional"),
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="Test")]
# Execute
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
# Verify
assert len(messages) == 2
assert isinstance(messages[0], OpenAISystemMessageParam)
expected_content = "Hello Alice! You are working at Dummy Company. Your role is AI Assistant at Dummy Company. Remember, Alice, to be professional."
assert messages[0].content == expected_content
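# A minimal, self-contained sketch of the whitespace-tolerant substitution the
# assertions above rely on (an illustration, not the provider's actual code):
# each {{ name }} occurrence, with or without inner padding, is replaced by the
# variable's text value.
def _render_template_sketch(template: str, variables: dict[str, str]) -> str:
    import re

    for name, value in variables.items():
        template = re.sub(r"\{\{\s*" + re.escape(name) + r"\s*\}\}", value, template)
    return template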
async def test_prepend_prompt_with_image_variable(openai_responses_impl, mock_prompts_api, mock_files_api):
"""Test prepend_prompt with image variable - should create placeholder in system message and append image as separate user message."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="Analyze this {{product_image}} and describe what you see.",
prompt_id=prompt_id,
version=1,
variables=["product_image"],
is_default=True,
)
# Mock file content and file metadata
mock_file_content = b"fake_image_data"
mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})()
mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject(
object="file",
id="file-abc123",
bytes=len(mock_file_content),
created_at=1234567890,
expires_at=1234567890,
filename="product.jpg",
purpose="assistants",
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"product_image": OpenAIResponseInputMessageContentImage(
file_id="file-abc123",
detail="high",
)
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="What do you think?")]
# Execute
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
assert len(messages) == 3
# Check system message has placeholder
assert isinstance(messages[0], OpenAISystemMessageParam)
assert messages[0].content == "Analyze this [Image: product_image] and describe what you see."
# Check original user message is still there
assert isinstance(messages[1], OpenAIUserMessageParam)
assert messages[1].content == "What do you think?"
# Check new user message with image is appended
assert isinstance(messages[2], OpenAIUserMessageParam)
assert isinstance(messages[2].content, list)
assert len(messages[2].content) == 1
# Should be image with data URL
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
assert messages[2].content[0].image_url.url.startswith("data:image/")
assert messages[2].content[0].image_url.detail == "high"
async def test_prepend_prompt_with_file_variable(openai_responses_impl, mock_prompts_api, mock_files_api):
"""Test prepend_prompt with file variable - should create placeholder in system message and append file as separate user message."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="Review the document {{contract_file}} and summarize key points.",
prompt_id=prompt_id,
version=1,
variables=["contract_file"],
is_default=True,
)
# Mock file retrieval
mock_file_content = b"fake_pdf_content"
mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})()
mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject(
object="file",
id="file-contract-789",
bytes=len(mock_file_content),
created_at=1234567890,
expires_at=1234567890,
filename="contract.pdf",
purpose="assistants",
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"contract_file": OpenAIResponseInputMessageContentFile(
file_id="file-contract-789",
filename="contract.pdf",
)
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="Please review this.")]
# Execute
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
assert len(messages) == 3
# Check system message has placeholder
assert isinstance(messages[0], OpenAISystemMessageParam)
assert messages[0].content == "Review the document [File: contract_file] and summarize key points."
# Check original user message is still there
assert isinstance(messages[1], OpenAIUserMessageParam)
assert messages[1].content == "Please review this."
# Check new user message with file is appended
assert isinstance(messages[2], OpenAIUserMessageParam)
assert isinstance(messages[2].content, list)
assert len(messages[2].content) == 1
# The single content part should be a file carrying a data URL
assert isinstance(messages[2].content[0], OpenAIFile)
assert messages[2].content[0].file.file_data.startswith("data:application/pdf;base64,")
assert messages[2].content[0].file.filename == "contract.pdf"
assert messages[2].content[0].file.file_id is None
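# Sketch of how a retrieved file body could become the data URL asserted above
# (an assumption using only the stdlib; the real helper may differ):
def _to_data_url_sketch(body: bytes, filename: str) -> str:
    import base64
    import mimetypes

    mime = mimetypes.guess_type(filename)[0] or "application/octet-stream"
    return f"data:{mime};base64,{base64.b64encode(body).decode()}"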
async def test_prepend_prompt_with_mixed_variables(openai_responses_impl, mock_prompts_api, mock_files_api):
"""Test prepend_prompt with text, image, and file variables mixed together."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="Hello {{name}}! Analyze {{photo}} and review {{document}}. Provide insights for {{company}}.",
prompt_id=prompt_id,
version=1,
variables=["name", "photo", "document", "company"],
is_default=True,
)
# Mock file retrieval for image and file
mock_image_content = b"fake_image_data"
mock_file_content = b"fake_doc_content"
async def mock_retrieve_file_content(file_id):
if file_id == "file-photo-123":
return type("obj", (object,), {"body": mock_image_content})()
elif file_id == "file-doc-456":
return type("obj", (object,), {"body": mock_file_content})()
mock_files_api.openai_retrieve_file_content.side_effect = mock_retrieve_file_content
def mock_retrieve_file(file_id):
if file_id == "file-photo-123":
return OpenAIFileObject(
object="file",
id="file-photo-123",
bytes=len(mock_image_content),
created_at=1234567890,
expires_at=1234567890,
filename="photo.jpg",
purpose="assistants",
)
elif file_id == "file-doc-456":
return OpenAIFileObject(
object="file",
id="file-doc-456",
bytes=len(mock_file_content),
created_at=1234567890,
expires_at=1234567890,
filename="doc.pdf",
purpose="assistants",
)
mock_files_api.openai_retrieve_file.side_effect = mock_retrieve_file
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"name": OpenAIResponseInputMessageContentText(text="Alice"),
"photo": OpenAIResponseInputMessageContentImage(file_id="file-photo-123", detail="auto"),
"document": OpenAIResponseInputMessageContentFile(file_id="file-doc-456", filename="doc.pdf"),
"company": OpenAIResponseInputMessageContentText(text="Acme Corp"),
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="Here's my question.")]
# Execute
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
assert len(messages) == 3
# Check system message has text and placeholders
assert isinstance(messages[0], OpenAISystemMessageParam)
expected_system = "Hello Alice! Analyze [Image: photo] and review [File: document]. Provide insights for Acme Corp."
assert messages[0].content == expected_system
# Check original user message is still there
assert isinstance(messages[1], OpenAIUserMessageParam)
assert messages[1].content == "Here's my question."
# Check new user message with media is appended (2 media items)
assert isinstance(messages[2], OpenAIUserMessageParam)
assert isinstance(messages[2].content, list)
assert len(messages[2].content) == 2
# First part should be image with data URL
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
assert messages[2].content[0].image_url.url.startswith("data:image/")
# Second part should be file with data URL
assert isinstance(messages[2].content[1], OpenAIFile)
assert messages[2].content[1].file.file_data.startswith("data:application/pdf;base64,")
assert messages[2].content[1].file.filename == "doc.pdf"
assert messages[2].content[1].file.file_id is None
async def test_prepend_prompt_with_image_using_image_url(openai_responses_impl, mock_prompts_api):
"""Test prepend_prompt with image variable using image_url instead of file_id."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="Describe {{screenshot}}.",
prompt_id=prompt_id,
version=1,
variables=["screenshot"],
is_default=True,
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"screenshot": OpenAIResponseInputMessageContentImage(
image_url="https://example.com/screenshot.png",
detail="low",
)
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="What is this?")]
# Execute
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
assert len(messages) == 3
# Check system message has placeholder
assert isinstance(messages[0], OpenAISystemMessageParam)
assert messages[0].content == "Describe [Image: screenshot]."
# Check original user message is still there
assert isinstance(messages[1], OpenAIUserMessageParam)
assert messages[1].content == "What is this?"
# Check new user message with image is appended
assert isinstance(messages[2], OpenAIUserMessageParam)
assert isinstance(messages[2].content, list)
# Image should use the provided URL
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
assert messages[2].content[0].image_url.url == "https://example.com/screenshot.png"
assert messages[2].content[0].image_url.detail == "low"
async def test_prepend_prompt_image_variable_missing_required_fields(openai_responses_impl, mock_prompts_api):
"""Test prepend_prompt with image variable that has neither file_id nor image_url - should raise error."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="Analyze {{bad_image}}.",
prompt_id=prompt_id,
version=1,
variables=["bad_image"],
is_default=True,
)
# Create image content with neither file_id nor image_url
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={"bad_image": OpenAIResponseInputMessageContentImage()}, # No file_id or image_url
)
mock_prompts_api.get_prompt.return_value = prompt
messages = [OpenAIUserMessageParam(content="Test")]
# Execute - should raise ValueError
with pytest.raises(ValueError, match="Image content must have either 'image_url' or 'file_id'"):
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)

View file

@ -7,20 +7,20 @@
import pytest
from llama_stack.apis.agents.openai_responses import (
from llama_stack_api.common.errors import (
ConversationNotFoundError,
InvalidConversationIdError,
)
from llama_stack_api.conversations import (
ConversationItemList,
)
from llama_stack_api.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseOutputMessageContentOutputText,
)
from llama_stack.apis.common.errors import (
ConversationNotFoundError,
InvalidConversationIdError,
)
from llama_stack.apis.conversations.conversations import (
ConversationItemList,
)
# Import existing fixtures from the main responses test file
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]
@ -39,6 +39,8 @@ def responses_impl_with_conversations(
mock_vector_io_api,
mock_conversations_api,
mock_safety_api,
mock_prompts_api,
mock_files_api,
):
"""Create OpenAIResponsesImpl instance with conversations API."""
return OpenAIResponsesImpl(
@ -49,6 +51,8 @@ def responses_impl_with_conversations(
vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api,
safety_api=mock_safety_api,
prompts_api=mock_prompts_api,
files_api=mock_files_api,
)

View file

@ -5,22 +5,20 @@
# the root directory of this source tree.
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
_extract_citations_from_text,
convert_chat_choice_to_response_message,
convert_response_content_to_chat_content,
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
get_message_type_by_role,
is_function_tool_call,
)
from llama_stack.apis.inference import (
from llama_stack_api.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
@ -35,17 +33,27 @@ from llama_stack.apis.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
_extract_citations_from_text,
convert_chat_choice_to_response_message,
convert_response_content_to_chat_content,
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
get_message_type_by_role,
is_function_tool_call,
from llama_stack_api.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
@pytest.fixture
def mock_files_api():
"""Mock files API for testing."""
return AsyncMock()
class TestConvertChatChoiceToResponseMessage:
async def test_convert_string_content(self):
choice = OpenAIChoice(
@ -78,17 +86,17 @@ class TestConvertChatChoiceToResponseMessage:
class TestConvertResponseContentToChatContent:
async def test_convert_string_content(self):
result = await convert_response_content_to_chat_content("Simple string")
async def test_convert_string_content(self, mock_files_api):
result = await convert_response_content_to_chat_content("Simple string", mock_files_api)
assert result == "Simple string"
async def test_convert_text_content_parts(self):
async def test_convert_text_content_parts(self, mock_files_api):
content = [
OpenAIResponseInputMessageContentText(text="First part"),
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
]
result = await convert_response_content_to_chat_content(content)
result = await convert_response_content_to_chat_content(content, mock_files_api)
assert len(result) == 2
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
@ -96,10 +104,10 @@ class TestConvertResponseContentToChatContent:
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
assert result[1].text == "Second part"
async def test_convert_image_content(self):
async def test_convert_image_content(self, mock_files_api):
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
result = await convert_response_content_to_chat_content(content)
result = await convert_response_content_to_chat_content(content, mock_files_api)
assert len(result) == 1
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)

View file

@ -5,7 +5,8 @@
# the root directory of this source tree.
from llama_stack.apis.agents.openai_responses import (
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
from llama_stack_api.openai_responses import (
MCPListToolsTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
@ -15,7 +16,6 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseToolMCP,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
class TestToolContext:

View file

@ -8,8 +8,6 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
@ -17,6 +15,8 @@ from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
extract_guardrail_ids,
run_guardrails,
)
from llama_stack_api.agents import ResponseGuardrailSpec
from llama_stack_api.safety import ModerationObject, ModerationObjectResults
@pytest.fixture
@ -30,6 +30,8 @@ def mock_apis():
"vector_io_api": AsyncMock(),
"conversations_api": AsyncMock(),
"safety_api": AsyncMock(),
"prompts_api": AsyncMock(),
"files_api": AsyncMock(),
}

View file

@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Tests for making Safety API optional in meta-reference agents provider.
This test suite validates the changes introduced to fix issue #4165, which
allows running the meta-reference agents provider without the Safety API.
The Safety API is now an optional dependency; errors are raised at request time
when guardrails are explicitly requested but no Safety API is configured.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.core.datatypes import Api
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
from llama_stack.providers.inline.agents.meta_reference import get_provider_impl
from llama_stack.providers.inline.agents.meta_reference.config import (
AgentPersistenceConfig,
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
run_guardrails,
)
@pytest.fixture
def mock_persistence_config():
"""Create a mock persistence configuration."""
return AgentPersistenceConfig(
agent_state=KVStoreReference(
backend="kv_default",
namespace="agents",
),
responses=ResponsesStoreReference(
backend="sql_default",
table_name="responses",
),
)
@pytest.fixture
def mock_deps():
"""Create mock dependencies for the agents provider."""
# Create mock APIs
inference_api = AsyncMock()
vector_io_api = AsyncMock()
tool_runtime_api = AsyncMock()
tool_groups_api = AsyncMock()
conversations_api = AsyncMock()
prompts_api = AsyncMock()
files_api = AsyncMock()
return {
Api.inference: inference_api,
Api.vector_io: vector_io_api,
Api.tool_runtime: tool_runtime_api,
Api.tool_groups: tool_groups_api,
Api.conversations: conversations_api,
Api.prompts: prompts_api,
Api.files: files_api,
}
class TestProviderInitialization:
"""Test provider initialization with different safety API configurations."""
async def test_initialization_with_safety_api_present(self, mock_persistence_config, mock_deps):
"""Test successful initialization when Safety API is configured."""
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
# Add safety API to deps
safety_api = AsyncMock()
mock_deps[Api.safety] = safety_api
# Mock the initialize method to avoid actual initialization
with patch(
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
new_callable=AsyncMock,
):
# Should not raise any exception
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
assert provider is not None
async def test_initialization_without_safety_api(self, mock_persistence_config, mock_deps):
"""Test successful initialization when Safety API is not configured."""
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
# Safety API is NOT in mock_deps - provider should still start
# Mock the initialize method to avoid actual initialization
with patch(
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
new_callable=AsyncMock,
):
# Should not raise any exception
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
assert provider is not None
assert provider.safety_api is None
class TestGuardrailsFunctionality:
"""Test run_guardrails function with optional safety API."""
async def test_run_guardrails_with_none_safety_api(self):
"""Test that run_guardrails returns None when safety_api is None."""
result = await run_guardrails(safety_api=None, messages="test message", guardrail_ids=["llama-guard"])
assert result is None
async def test_run_guardrails_with_empty_messages(self):
"""Test that run_guardrails returns None for empty messages."""
# Test with None safety API
result = await run_guardrails(safety_api=None, messages="", guardrail_ids=["llama-guard"])
assert result is None
# Test with mock safety API
mock_safety_api = AsyncMock()
result = await run_guardrails(safety_api=mock_safety_api, messages="", guardrail_ids=["llama-guard"])
assert result is None
async def test_run_guardrails_with_none_safety_api_ignores_guardrails(self):
"""Test that guardrails are skipped when safety_api is None, even if guardrail_ids are provided."""
# Should not raise exception, just return None
result = await run_guardrails(
safety_api=None,
messages="potentially harmful content",
guardrail_ids=["llama-guard", "content-filter"],
)
assert result is None
async def test_create_response_rejects_guardrails_without_safety_api(self, mock_persistence_config, mock_deps):
"""Test that create_openai_response raises error when guardrails requested but Safety API unavailable."""
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack_api import ResponseGuardrailSpec
# Create OpenAIResponsesImpl with no safety API
with patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"):
impl = OpenAIResponsesImpl(
inference_api=mock_deps[Api.inference],
tool_groups_api=mock_deps[Api.tool_groups],
tool_runtime_api=mock_deps[Api.tool_runtime],
responses_store=MagicMock(),
vector_io_api=mock_deps[Api.vector_io],
safety_api=None, # No Safety API
conversations_api=mock_deps[Api.conversations],
prompts_api=mock_deps[Api.prompts],
files_api=mock_deps[Api.files],
)
# Test with string guardrail
with pytest.raises(ValueError) as exc_info:
await impl.create_openai_response(
input="test input",
model="test-model",
guardrails=["llama-guard"],
)
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
# Test with ResponseGuardrailSpec
with pytest.raises(ValueError) as exc_info:
await impl.create_openai_response(
input="test input",
model="test-model",
guardrails=[ResponseGuardrailSpec(type="llama-guard")],
)
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
async def test_create_response_succeeds_without_guardrails_and_no_safety_api(
self, mock_persistence_config, mock_deps
):
"""Test that create_openai_response works when no guardrails requested and Safety API unavailable."""
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
# Create OpenAIResponsesImpl with no safety API
with (
patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"),
patch.object(OpenAIResponsesImpl, "_create_streaming_response", new_callable=AsyncMock) as mock_stream,
):
# Mock the streaming response to return a simple async generator
async def mock_generator():
yield MagicMock()
mock_stream.return_value = mock_generator()
impl = OpenAIResponsesImpl(
inference_api=mock_deps[Api.inference],
tool_groups_api=mock_deps[Api.tool_groups],
tool_runtime_api=mock_deps[Api.tool_runtime],
responses_store=MagicMock(),
vector_io_api=mock_deps[Api.vector_io],
safety_api=None, # No Safety API
conversations_api=mock_deps[Api.conversations],
prompts_api=mock_deps[Api.prompts],
files_api=mock_deps[Api.files],
)
# Should not raise when no guardrails requested
# Note: This will still fail later in execution due to mocking, but should pass the validation
try:
await impl.create_openai_response(
input="test input",
model="test-model",
guardrails=None, # No guardrails
)
except Exception as e:
# Ensure the error is NOT about missing Safety API
assert "Cannot process guardrails: Safety API is not configured" not in str(e)

View file

@ -1,169 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from datetime import datetime
from unittest.mock import patch
import pytest
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.core.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest.fixture
async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Set creator's attributes for the session
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
mock_get_authenticated_user.return_value = User("test_user", creator_attributes)
# Create a session
session_id = await agent_persistence.create_session("Test Session")
# Get the session and verify access attributes were set
session_info = await agent_persistence.get_session_info(session_id)
assert session_info is not None
assert session_info.owner is not None
assert session_info.owner.attributes is not None
assert session_info.owner.attributes["roles"] == ["researcher"]
assert session_info.owner.attributes["teams"] == ["ai-team"]
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with specific access attributes
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
key=f"session:{agent_persistence.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
# User with matching attributes can access
mock_get_authenticated_user.return_value = User(
"testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
)
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is not None
assert retrieved_session.session_id == session_id
# User without matching attributes cannot access
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is None
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
owner=User("someone", {"roles": ["admin"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
key=f"session:{agent_persistence.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
# Create a turn for this session
turn_id = str(uuid.uuid4())
turn = Turn(
session_id=session_id,
turn_id=turn_id,
steps=[],
started_at=datetime.now(),
input_messages=[],
output_message=CompletionMessage(
content="Hello",
stop_reason=StopReason.end_of_turn,
),
)
# Admin can add turn
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.add_turn_to_session(session_id, turn)
# Admin can get turn
retrieved_turn = await agent_persistence.get_session_turn(session_id, turn_id)
assert retrieved_turn is not None
assert retrieved_turn.turn_id == turn_id
# Regular user cannot get turn
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
with pytest.raises(ValueError):
await agent_persistence.get_session_turn(session_id, turn_id)
# Regular user cannot get turns for session
with pytest.raises(ValueError):
await agent_persistence.get_session_turns(session_id)
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
owner=User("someone", {"roles": ["admin"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
key=f"session:{agent_persistence.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
turn_id = str(uuid.uuid4())
# Admin user can set inference iterations
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
# Admin user can get inference iterations
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
assert infer_iters == 5
# Regular user cannot get inference iterations
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
assert infer_iters is None
# Regular user cannot set inference iterations (should raise ValueError)
with pytest.raises(ValueError):
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 10)

View file

@ -13,9 +13,9 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
@pytest.fixture

View file

@ -59,8 +59,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.batches import BatchObject
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack_api import BatchObject, ConflictError, ResourceNotFoundError
class TestReferenceBatchesImpl:

View file

@ -44,7 +44,7 @@ import asyncio
import pytest
from llama_stack.apis.common.errors import ConflictError
from llama_stack_api import ConflictError
class TestReferenceBatchesIdempotency:

View file

@ -9,8 +9,8 @@ import pytest
from moto import mock_aws
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
class MockUploadFile:

View file

@ -9,8 +9,7 @@ from unittest.mock import patch
import pytest
from botocore.exceptions import ClientError
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack_api import OpenAIFilePurpose, ResourceNotFoundError
class TestS3FilesImpl:
@ -228,7 +227,7 @@ class TestS3FilesImpl:
mock_now.return_value = 0
from llama_stack.apis.files import ExpiresAfter
from llama_stack_api import ExpiresAfter
sample_text_file.filename = "test_expired_file"
uploaded = await s3_provider.openai_upload_file(
@ -260,7 +259,7 @@ class TestS3FilesImpl:
async def test_unsupported_expires_after_anchor(self, s3_provider, sample_text_file):
"""Unsupported anchor value should raise ValueError."""
from llama_stack.apis.files import ExpiresAfter
from llama_stack_api import ExpiresAfter
sample_text_file.filename = "test_unsupported_expires_after_anchor"
@ -273,7 +272,7 @@ class TestS3FilesImpl:
async def test_nonint_expires_after_seconds(self, s3_provider, sample_text_file):
"""Non-integer seconds in expires_after should raise ValueError."""
from llama_stack.apis.files import ExpiresAfter
from llama_stack_api import ExpiresAfter
sample_text_file.filename = "test_nonint_expires_after_seconds"
@ -286,7 +285,7 @@ class TestS3FilesImpl:
async def test_expires_after_seconds_out_of_bounds(self, s3_provider, sample_text_file):
"""Seconds outside allowed range should raise ValueError."""
from llama_stack.apis.files import ExpiresAfter
from llama_stack_api import ExpiresAfter
with pytest.raises(ValueError, match="greater than or equal to 3600"):
await s3_provider.openai_upload_file(

View file

@ -8,10 +8,9 @@ from unittest.mock import patch
import pytest
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.core.datatypes import User
from llama_stack.providers.remote.files.s3.files import S3FilesImpl
from llama_stack_api import OpenAIFilePurpose, ResourceNotFoundError
async def test_listing_hides_other_users_file(s3_provider, sample_text_file):
@ -19,11 +18,11 @@ async def test_listing_hides_other_users_file(s3_provider, sample_text_file):
user_a = User("user-a", {"roles": ["team-a"]})
user_b = User("user-b", {"roles": ["team-b"]})
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_a
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_b
listed = await s3_provider.openai_list_files()
assert all(f.id != uploaded.id for f in listed.data)
@ -42,11 +41,11 @@ async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op):
user_a = User("user-a", {"roles": ["team-a"]})
user_b = User("user-b", {"roles": ["team-b"]})
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_a
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_b
with pytest.raises(ResourceNotFoundError):
await op(s3_provider, uploaded.id)
@ -57,11 +56,11 @@ async def test_shared_role_allows_listing(s3_provider, sample_text_file):
user_a = User("user-a", {"roles": ["shared-role"]})
user_b = User("user-b", {"roles": ["shared-role"]})
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_a
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_b
listed = await s3_provider.openai_list_files()
assert any(f.id == uploaded.id for f in listed.data)
@ -80,10 +79,10 @@ async def test_shared_role_allows_access(s3_provider, sample_text_file, op):
user_x = User("user-x", {"roles": ["shared-role"]})
user_y = User("user-y", {"roles": ["shared-role"]})
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_x
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_y
await op(s3_provider, uploaded.id)

View file

@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from openai import AuthenticationError
from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack_api import OpenAIChatCompletionRequestWithExtraBody
def test_adapter_initialization():
config = BedrockConfig(api_key="test-key", region_name="us-east-1")
adapter = BedrockInferenceAdapter(config=config)
assert adapter.config.auth_credential.get_secret_value() == "test-key"
assert adapter.config.region_name == "us-east-1"
def test_client_url_construction():
config = BedrockConfig(api_key="test-key", region_name="us-west-2")
adapter = BedrockInferenceAdapter(config=config)
assert adapter.get_base_url() == "https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1"
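# Endpoint construction sketch implied by the assertion above (assumed format):
def _bedrock_base_url_sketch(region_name: str) -> str:
    return f"https://bedrock-runtime.{region_name}.amazonaws.com/openai/v1"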
def test_api_key_from_config():
config = BedrockConfig(api_key="config-key", region_name="us-east-1")
adapter = BedrockInferenceAdapter(config=config)
assert adapter.config.auth_credential.get_secret_value() == "config-key"
def test_api_key_from_header_overrides_config():
"""Test API key from request header overrides config via client property"""
config = BedrockConfig(api_key="config-key", region_name="us-east-1")
adapter = BedrockInferenceAdapter(config=config)
adapter.provider_data_api_key_field = "aws_bearer_token_bedrock"
adapter.get_request_provider_data = MagicMock(return_value=SimpleNamespace(aws_bearer_token_bedrock="header-key"))
# The client property is where header override happens (in OpenAIMixin)
assert adapter.client.api_key == "header-key"
async def test_authentication_error_handling():
"""Test that AuthenticationError from OpenAI client is converted to ValueError with helpful message"""
config = BedrockConfig(api_key="invalid-key", region_name="us-east-1")
adapter = BedrockInferenceAdapter(config=config)
# Mock the parent class method to raise AuthenticationError
mock_response = MagicMock()
mock_response.message = "Invalid authentication credentials"
auth_error = AuthenticationError(message="Invalid authentication credentials", response=mock_response, body=None)
# Create a mock that raises the error
mock_super = AsyncMock(side_effect=auth_error)
# Patch the parent class method
original_method = BedrockInferenceAdapter.__bases__[0].openai_chat_completion
BedrockInferenceAdapter.__bases__[0].openai_chat_completion = mock_super
try:
with pytest.raises(ValueError) as exc_info:
params = OpenAIChatCompletionRequestWithExtraBody(
model="test-model", messages=[{"role": "user", "content": "test"}]
)
await adapter.openai_chat_completion(params=params)
assert "AWS Bedrock authentication failed" in str(exc_info.value)
assert "Please verify your API key" in str(exc_info.value)
finally:
# Restore original method
BedrockInferenceAdapter.__bases__[0].openai_chat_completion = original_method
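# Error-translation sketch consistent with the assertions above (an assumed
# wrapper, not the adapter's actual code): upstream auth failures are re-raised
# as ValueError with a Bedrock-specific hint.
async def _auth_hint_sketch(call, params):
    from openai import AuthenticationError

    try:
        return await call(params)
    except AuthenticationError as e:
        raise ValueError(f"AWS Bedrock authentication failed: {e}. Please verify your API key.") from e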

View file

@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
def test_bedrock_config_defaults_no_env(monkeypatch):
"""Test BedrockConfig defaults when env vars are not set"""
monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False)
monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False)
config = BedrockConfig()
assert config.auth_credential is None
assert config.region_name == "us-east-2"
def test_bedrock_config_reads_from_env(monkeypatch):
"""Test BedrockConfig field initialization reads from environment variables"""
monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1")
config = BedrockConfig()
assert config.region_name == "eu-west-1"
def test_bedrock_config_with_values():
"""Test BedrockConfig accepts explicit values via alias"""
config = BedrockConfig(api_key="test-key", region_name="us-west-2")
assert config.auth_credential.get_secret_value() == "test-key"
assert config.region_name == "us-west-2"
def test_bedrock_config_sample():
"""Test BedrockConfig sample_run_config returns correct format"""
sample = BedrockConfig.sample_run_config()
assert "api_key" in sample
assert "region_name" in sample
assert sample["api_key"] == "${env.AWS_BEARER_TOKEN_BEDROCK:=}"
assert sample["region_name"] == "${env.AWS_DEFAULT_REGION:=us-east-2}"

View file

@ -120,7 +120,7 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere
VLLMInferenceAdapter,
"llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
{
"url": "http://fake",
"base_url": "http://fake",
},
),
],
@ -153,7 +153,7 @@ def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_valid
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
assumption that there is an OpenAI-compatible client object."""
inference_adapter = adapter_cls(config=config_cls())
inference_adapter = adapter_cls(config=config_cls(base_url="http://fake"))
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator

View file

@ -10,7 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from llama_stack.apis.inference import (
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
from llama_stack_api import (
HealthStatus,
Model,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionRequestWithExtraBody,
@ -20,12 +26,6 @@ from llama_stack.apis.inference import (
OpenAICompletionRequestWithExtraBody,
ToolChoice,
)
from llama_stack.apis.models import Model
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
# These are unit test for the remote vllm provider
# implementation. This should only contain tests which are specific to
@ -40,7 +40,7 @@ from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapte
@pytest.fixture(scope="function")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config=config)
inference_adapter.model_store = AsyncMock()
await inference_adapter.initialize()
@ -204,7 +204,7 @@ async def test_vllm_completion_extra_body():
via extra_body to the underlying OpenAI client through the InferenceRouter.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()
@ -277,7 +277,7 @@ async def test_vllm_chat_completion_extra_body():
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()

View file

@ -8,11 +8,11 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
convert_tooldef_to_chat_tool,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ChatCompletionContext
from llama_stack_api import ToolDef
@pytest.fixture

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import Mock
import pytest
from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
ModelRunner,
)
class TestModelRunner:
"""Test ModelRunner task dispatching for model-parallel inference."""
def test_chat_completion_task_dispatch(self):
"""Verify ModelRunner correctly dispatches chat_completion tasks."""
# Create a mock generator
mock_generator = Mock()
mock_generator.chat_completion = Mock(return_value=iter([]))
runner = ModelRunner(mock_generator)
# Create a chat_completion task
fake_params = {"model": "test"}
fake_messages = [{"role": "user", "content": "test"}]
task = ("chat_completion", [fake_params, fake_messages])
# Execute task
runner(task)
# Verify chat_completion was called with correct arguments
mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)
def test_invalid_task_type_raises_error(self):
"""Verify ModelRunner rejects invalid task types."""
mock_generator = Mock()
runner = ModelRunner(mock_generator)
with pytest.raises(ValueError, match="Unexpected task type"):
runner(("invalid_task", []))

View file

@@ -9,10 +9,9 @@ from unittest.mock import patch
import pytest
from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.resource import ResourceType
from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
from llama_stack_api import Dataset, DatasetPurpose, ResourceType, URIDataSource
@pytest.fixture

View file

@@ -9,14 +9,20 @@ from unittest.mock import MagicMock, patch
import pytest
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
from llama_stack.apis.inference.inference import TopPSamplingStrategy
from llama_stack.apis.resource import ResourceType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl
from llama_stack_api import (
Benchmark,
BenchmarkConfig,
EvaluateResponse,
Job,
JobStatus,
ModelCandidate,
ResourceType,
SamplingParams,
TopPSamplingStrategy,
)
MOCK_DATASET_ID = "default/test-dataset"
MOCK_BENCHMARK_ID = "test-benchmark"

View file

@@ -10,7 +10,12 @@ from unittest.mock import patch
import pytest
from llama_stack.apis.post_training.post_training import (
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.remote.post_training.nvidia.post_training import (
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
)
from llama_stack_api import (
DataConfig,
DatasetFormat,
EfficiencyConfig,
@@ -19,11 +24,6 @@ from llama_stack.apis.post_training.post_training import (
OptimizerType,
TrainingConfig,
)
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.remote.post_training.nvidia.post_training import (
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
)
class TestNvidiaParameters:

View file

@@ -9,10 +9,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import ModelType
class MockResponse:
@@ -146,7 +146,7 @@ async def test_hosted_model_not_in_endpoint_mapping():
async def test_self_hosted_ignores_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
config=NVIDIAConfig(base_url="http://localhost:8000", api_key=None),
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
)
mock_session = MockSession(MockResponse())

View file

@@ -10,13 +10,16 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.inference import CompletionMessage, UserMessage
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
from llama_stack_api import (
OpenAIAssistantMessageParam,
OpenAIUserMessageParam,
ResourceType,
RunShieldResponse,
Shield,
ViolationLevel,
)
class FakeNVIDIASafetyAdapter(NVIDIASafetyAdapter):
@@ -136,11 +139,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
# Run the shield
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
OpenAIUserMessageParam(content="Hello, how are you?"),
OpenAIAssistantMessageParam(
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]
@@ -191,13 +192,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
# Mock Guardrails API response
mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
# Run the shield
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
OpenAIUserMessageParam(content="Hello, how are you?"),
OpenAIAssistantMessageParam(
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]
@@ -243,7 +241,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
adapter.shield_store.get_shield.return_value = None
messages = [
UserMessage(role="user", content="Hello, how are you?"),
OpenAIUserMessageParam(content="Hello, how are you?"),
]
with pytest.raises(ValueError):
@@ -274,11 +272,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
# Running the shield should raise an exception
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
OpenAIUserMessageParam(content="Hello, how are you?"),
OpenAIAssistantMessageParam(
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]

View file

@@ -10,15 +10,6 @@ from unittest.mock import patch
import pytest
from llama_stack.apis.post_training.post_training import (
DataConfig,
DatasetFormat,
LoraFinetuningConfig,
OptimizerConfig,
OptimizerType,
QATFinetuningConfig,
TrainingConfig,
)
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.remote.post_training.nvidia.post_training import (
ListNvidiaPostTrainingJobs,
@@ -27,6 +18,15 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
NvidiaPostTrainingJob,
NvidiaPostTrainingJobStatusResponse,
)
from llama_stack_api import (
DataConfig,
DatasetFormat,
LoraFinetuningConfig,
OptimizerConfig,
OptimizerType,
QATFinetuningConfig,
TrainingConfig,
)
@pytest.fixture

View file

@@ -4,50 +4,66 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.bedrock.bedrock import (
_get_region_prefix,
_to_inference_profile_id,
)
from types import SimpleNamespace
from unittest.mock import AsyncMock, PropertyMock, patch
from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack_api import OpenAIChatCompletionRequestWithExtraBody
def test_region_prefixes():
assert _get_region_prefix("us-east-1") == "us."
assert _get_region_prefix("eu-west-1") == "eu."
assert _get_region_prefix("ap-south-1") == "ap."
assert _get_region_prefix("ca-central-1") == "us."
def test_can_create_adapter():
config = BedrockConfig(api_key="test-key", region_name="us-east-1")
adapter = BedrockInferenceAdapter(config=config)
# Test case insensitive
assert _get_region_prefix("US-EAST-1") == "us."
assert _get_region_prefix("EU-WEST-1") == "eu."
assert _get_region_prefix("Ap-South-1") == "ap."
# Test None region
assert _get_region_prefix(None) == "us."
assert adapter is not None
assert adapter.config.region_name == "us-east-1"
assert adapter.get_api_key() == "test-key"
def test_model_id_conversion():
# Basic conversion
assert (
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0"
def test_different_aws_regions():
# Just check a couple of regions to verify URL construction works
config = BedrockConfig(api_key="key", region_name="us-east-1")
adapter = BedrockInferenceAdapter(config=config)
assert adapter.get_base_url() == "https://bedrock-runtime.us-east-1.amazonaws.com/openai/v1"
config = BedrockConfig(api_key="key", region_name="eu-west-1")
adapter = BedrockInferenceAdapter(config=config)
assert adapter.get_base_url() == "https://bedrock-runtime.eu-west-1.amazonaws.com/openai/v1"
async def test_basic_chat_completion():
"""Test basic chat completion works with OpenAIMixin"""
config = BedrockConfig(api_key="k", region_name="us-east-1")
adapter = BedrockInferenceAdapter(config=config)
class FakeModelStore:
async def has_model(self, model_id):
return True
async def get_model(self, model_id):
return SimpleNamespace(provider_resource_id="meta.llama3-1-8b-instruct-v1:0")
adapter.model_store = FakeModelStore()
fake_response = SimpleNamespace(
id="chatcmpl-123",
choices=[SimpleNamespace(message=SimpleNamespace(content="Hello!", role="assistant"), finish_reason="stop")],
)
# Already has prefix
assert (
_to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1")
== "us.meta.llama3-1-70b-instruct-v1:0"
)
mock_create = AsyncMock(return_value=fake_response)
# ARN should be returned unchanged
arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0"
assert _to_inference_profile_id(arn, "us-east-1") == arn
class FakeClient:
def __init__(self):
self.chat = SimpleNamespace(completions=SimpleNamespace(create=mock_create))
# ARN should be returned unchanged even without region
assert _to_inference_profile_id(arn) == arn
with patch.object(type(adapter), "client", new_callable=PropertyMock, return_value=FakeClient()):
params = OpenAIChatCompletionRequestWithExtraBody(
model="llama3-1-8b",
messages=[{"role": "user", "content": "hello"}],
stream=False,
)
response = await adapter.openai_chat_completion(params=params)
# Optional region parameter defaults to us-east-1
assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0") == "us.meta.llama3-1-70b-instruct-v1:0"
# Different regions work with optional parameter
assert (
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0"
)
assert response.id == "chatcmpl-123"
assert mock_create.await_count == 1
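The two base-URL assertions above imply a straightforward region interpolation; a sketch of that construction (assumed shape, not the adapter's actual code):
def bedrock_openai_base_url(region_name: str) -> str:
    # Bedrock's OpenAI-compatible runtime endpoint varies only by region
    return f"https://bedrock-runtime.{region_name}.amazonaws.com/openai/v1"

assert bedrock_openai_base_url("eu-west-1") == "https://bedrock-runtime.eu-west-1.amazonaws.com/openai/v1"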

View file

@@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import get_args, get_origin
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, HttpUrl
from llama_stack.core.distribution import get_provider_registry, providable_apis
from llama_stack.core.utils.dynamic import instantiate_class_type
@@ -41,3 +43,55 @@ class TestProviderConfigurations:
sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz")
assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict"
def test_remote_inference_url_standardization(self):
"""Verify all remote inference providers use standardized base_url configuration."""
provider_registry = get_provider_registry()
inference_providers = provider_registry.get("inference", {})
# Filter for remote providers only
remote_providers = {k: v for k, v in inference_providers.items() if k.startswith("remote::")}
failures = []
for provider_type, provider_spec in remote_providers.items():
try:
config_class_name = provider_spec.config_class
config_type = instantiate_class_type(config_class_name)
# Check that config has base_url field (not url)
if hasattr(config_type, "model_fields"):
fields = config_type.model_fields
# Should NOT have 'url' field (old pattern)
if "url" in fields:
failures.append(
f"{provider_type}: Uses deprecated 'url' field instead of 'base_url'. "
f"Please rename to 'base_url' for consistency."
)
# Should have 'base_url' field with HttpUrl | None type
if "base_url" in fields:
field_info = fields["base_url"]
annotation = field_info.annotation
# Check if it's HttpUrl or HttpUrl | None
# get_origin() returns Union for (X | Y), None for plain types
# get_args() returns the types inside Union, e.g. (HttpUrl, NoneType)
is_valid = False
if get_origin(annotation) is not None: # It's a Union/Optional
if HttpUrl in get_args(annotation):
is_valid = True
elif annotation == HttpUrl: # Plain HttpUrl without | None
is_valid = True
if not is_valid:
failures.append(
f"{provider_type}: base_url field has incorrect type annotation. "
f"Expected 'HttpUrl | None', got '{annotation}'"
)
except Exception as e:
failures.append(f"{provider_type}: Error checking URL standardization: {str(e)}")
if failures:
pytest.fail("URL standardization violations found:\n" + "\n".join(f" - {f}" for f in failures))
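For illustration, a config that satisfies this check would look roughly like the following (hypothetical provider config, not part of the registry):
from pydantic import BaseModel, HttpUrl

class ExampleRemoteInferenceConfig(BaseModel):
    # Standardized: a 'base_url' field typed HttpUrl | None, never a bare 'url' field
    base_url: HttpUrl | None = None
    api_key: str | None = None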

View file

@@ -1,220 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from pydantic import ValidationError
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
CompletionMessage,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
SystemMessage,
UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
convert_message_to_openai_dict_new,
openai_messages_to_messages,
)
async def test_convert_message_to_openai_dict():
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
assert await convert_message_to_openai_dict(message) == {
"role": "user",
"content": [{"type": "text", "text": "Hello, world!"}],
}
# Test convert_message_to_openai_dict with a tool call
async def test_convert_message_to_openai_dict_with_tool_call():
message = CompletionMessage(
content="",
tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
stop_reason=StopReason.end_of_turn,
)
openai_dict = await convert_message_to_openai_dict(message)
assert openai_dict == {
"role": "assistant",
"content": [{"type": "text", "text": ""}],
"tool_calls": [
{"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
],
}
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
message = CompletionMessage(
content="",
tool_calls=[
ToolCall(
call_id="123",
tool_name=BuiltinTool.brave_search,
arguments='{"foo": "bar"}',
)
],
stop_reason=StopReason.end_of_turn,
)
openai_dict = await convert_message_to_openai_dict(message)
assert openai_dict == {
"role": "assistant",
"content": [{"type": "text", "text": ""}],
"tool_calls": [
{"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
],
}
async def test_openai_messages_to_messages_with_content_str():
openai_messages = [
OpenAISystemMessageParam(content="system message"),
OpenAIUserMessageParam(content="user message"),
OpenAIAssistantMessageParam(content="assistant message"),
]
llama_messages = openai_messages_to_messages(openai_messages)
assert len(llama_messages) == 3
assert isinstance(llama_messages[0], SystemMessage)
assert isinstance(llama_messages[1], UserMessage)
assert isinstance(llama_messages[2], CompletionMessage)
assert llama_messages[0].content == "system message"
assert llama_messages[1].content == "user message"
assert llama_messages[2].content == "assistant message"
async def test_openai_messages_to_messages_with_content_list():
openai_messages = [
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
]
llama_messages = openai_messages_to_messages(openai_messages)
assert len(llama_messages) == 3
assert isinstance(llama_messages[0], SystemMessage)
assert isinstance(llama_messages[1], UserMessage)
assert isinstance(llama_messages[2], CompletionMessage)
assert llama_messages[0].content[0].text == "system message"
assert llama_messages[1].content[0].text == "user message"
assert llama_messages[2].content[0].text == "assistant message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_string(message_class, kwargs):
"""Test that messages accept string text content."""
msg = message_class(content="Test message", **kwargs)
assert msg.content == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_list(message_class, kwargs):
"""Test that messages accept list of text content parts."""
content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
msg = message_class(content=content_list, **kwargs)
assert len(msg.content) == 1
assert msg.content[0].text == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_rejects_images(message_class, kwargs):
"""Test that system, assistant, developer, and tool messages reject image content."""
with pytest.raises(ValidationError):
message_class(
content=[
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
],
**kwargs,
)
def test_user_message_accepts_images():
"""Test that user messages accept image content (unlike other message types)."""
# List with images should work
msg = OpenAIUserMessageParam(
content=[
OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
]
)
assert len(msg.content) == 2
assert msg.content[0].text == "Describe this image:"
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
async def test_convert_message_to_openai_dict_new_user_message():
"""Test convert_message_to_openai_dict_new with UserMessage."""
message = UserMessage(content="Hello, world!", role="user")
result = await convert_message_to_openai_dict_new(message)
assert result["role"] == "user"
assert result["content"] == "Hello, world!"
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
message = CompletionMessage(
content="I'll help you find the weather.",
tool_calls=[
ToolCall(
call_id="call_123",
tool_name="get_weather",
arguments='{"city": "Sligo"}',
)
],
stop_reason=StopReason.end_of_turn,
)
result = await convert_message_to_openai_dict_new(message)
# This would have failed with "Cannot instantiate typing.Union" before the fix
assert result["role"] == "assistant"
assert result["content"] == "I'll help you find the weather."
assert "tool_calls" in result
assert result["tool_calls"] is not None
assert len(result["tool_calls"]) == 1
tool_call = result["tool_calls"][0]
assert tool_call.id == "call_123"
assert tool_call.type == "function"
assert tool_call.function.name == "get_weather"
assert tool_call.function.arguments == '{"city": "Sligo"}'

View file

@@ -12,11 +12,17 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import (
Model,
ModelType,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIUserMessageParam,
)
class OpenAIMixinImpl(OpenAIMixin):
@@ -835,3 +841,96 @@ class TestOpenAIMixinProviderDataApiKey:
error_message = str(exc_info.value)
assert "test_api_key" in error_message
assert "x-llamastack-provider-data" in error_message
class TestOpenAIMixinAllowedModelsInference:
"""Test cases for allowed_models enforcement during inference requests"""
async def test_inference_with_allowed_models(self, mixin, mock_client_context):
"""Test that all inference methods succeed with allowed models"""
mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
mock_client.completions.create = AsyncMock(return_value=MagicMock())
mock_embedding_response = MagicMock()
mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
with mock_client_context(mixin, mock_client):
# Test chat completion
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
mock_client.chat.completions.create.assert_called_once()
# Test completion
await mixin.openai_completion(
OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
)
mock_client.completions.create.assert_called_once()
# Test embeddings
await mixin.openai_embeddings(
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
)
mock_client.embeddings.create.assert_called_once()
async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
"""Test that all inference methods fail with disallowed models"""
mixin.config.allowed_models = ["gpt-4"]
mock_client = MagicMock()
with mock_client_context(mixin, mock_client):
# Test chat completion with disallowed model
with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
# Test completion with disallowed model
with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
await mixin.openai_completion(
OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
)
# Test embeddings with disallowed model
with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
await mixin.openai_embeddings(
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
)
mock_client.chat.completions.create.assert_not_called()
mock_client.completions.create.assert_not_called()
mock_client.embeddings.create.assert_not_called()
async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
"""Test that inference succeeds when allowed_models is None or empty list blocks all"""
# Test with None (no restrictions)
assert mixin.config.allowed_models is None
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
with mock_client_context(mixin, mock_client):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
mock_client.chat.completions.create.assert_called_once()
# Test with empty list (blocks all models)
mixin.config.allowed_models = []
with mock_client_context(mixin, mock_client):
with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
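Taken together, these cases imply a guard of roughly this shape (a sketch; the helper name is hypothetical, only the semantics and the error message come from the tests above):
def ensure_model_allowed(model: str, allowed_models: list[str] | None) -> None:
    # None means unrestricted; an empty list blocks every model
    if allowed_models is not None and model not in allowed_models:
        raise ValueError(f"Model '{model}' is not in the allowed models list")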

View file

@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.datatypes import RawTextItem
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_openai_message_to_raw_message,
)
from llama_stack_api import OpenAIAssistantMessageParam, OpenAIUserMessageParam
class TestConvertOpenAIMessageToRawMessage:
"""Test conversion of OpenAI message types to RawMessage format."""
async def test_user_message_conversion(self):
msg = OpenAIUserMessageParam(role="user", content="Hello world")
raw_msg = await convert_openai_message_to_raw_message(msg)
assert raw_msg.role == "user"
assert isinstance(raw_msg.content, RawTextItem)
assert raw_msg.content.text == "Hello world"
async def test_assistant_message_conversion(self):
msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
raw_msg = await convert_openai_message_to_raw_message(msg)
assert raw_msg.role == "assistant"
assert isinstance(raw_msg.content, RawTextItem)
assert raw_msg.content.text == "Hi there!"
assert raw_msg.tool_calls == []

View file

@@ -8,9 +8,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
from llama_stack_api import URL, RAGDocument, TextContentItem
async def test_content_from_doc_with_url():

View file

@@ -35,8 +35,8 @@
import pytest
from llama_stack.apis.models import Model
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack_api import Model
@pytest.fixture

View file

@@ -10,16 +10,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.storage.kvstore import register_kvstore_backends
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.utils.kvstore import register_kvstore_backends
from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, VectorStore
EMBEDDING_DIMENSION = 768
COLLECTION_PREFIX = "test_collection"
@@ -280,7 +279,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
) as mock_check_version:
mock_check_version.return_value = "0.5.1"
with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl:
with patch("llama_stack.core.storage.kvstore.kvstore_impl") as mock_kvstore_impl:
mock_kvstore = AsyncMock()
mock_kvstore_impl.return_value = mock_kvstore

View file

@@ -10,15 +10,12 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.files import Files
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import (
FaissIndex,
FaissVectorIOAdapter,
)
from llama_stack_api import Chunk, Files, HealthStatus, QueryChunksResponse, VectorStore
# This test is a unit test for the FaissVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in

View file

@@ -9,12 +9,12 @@ import asyncio
import numpy as np
import pytest
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
SQLiteVecIndex,
SQLiteVecVectorIOAdapter,
_create_sqlite_connection,
)
from llama_stack_api import Chunk, QueryChunksResponse
# This test is a unit test for the SQLiteVecVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in

View file

@@ -11,17 +11,17 @@ from unittest.mock import AsyncMock, patch
import numpy as np
import pytest
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.vector_io import (
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
from llama_stack_api import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
OpenAICreateVectorStoreRequestWithExtraBody,
QueryChunksResponse,
VectorStore,
VectorStoreChunkingStrategyAuto,
VectorStoreFileObject,
VectorStoreNotFoundError,
)
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
# This test is a unit test for the inline VectorIO providers. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
@@ -92,6 +92,99 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.shutdown()
async def test_vector_store_lazy_loading_from_kvstore(vector_io_adapter):
"""
Test that vector stores can be lazy-loaded from KV store when not in cache.
Verifies that clearing the cache doesn't break vector store access - they
can be loaded on-demand from persistent storage.
"""
await vector_io_adapter.initialize()
vector_store_id = f"lazy_load_test_{np.random.randint(1e6)}"
vector_store = VectorStore(
identifier=vector_store_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
await vector_io_adapter.register_vector_store(vector_store)
assert vector_store_id in vector_io_adapter.cache
vector_io_adapter.cache.clear()
assert vector_store_id not in vector_io_adapter.cache
loaded_index = await vector_io_adapter._get_and_cache_vector_store_index(vector_store_id)
assert loaded_index is not None
assert loaded_index.vector_store.identifier == vector_store_id
assert vector_store_id in vector_io_adapter.cache
cached_index = await vector_io_adapter._get_and_cache_vector_store_index(vector_store_id)
assert cached_index is loaded_index
await vector_io_adapter.shutdown()
async def test_vector_store_preloading_on_initialization(vector_io_adapter):
"""
Test that vector stores are preloaded from KV store during initialization.
Verifies that after restart, all vector stores are automatically loaded into
cache and immediately accessible without requiring lazy loading.
"""
await vector_io_adapter.initialize()
vector_store_ids = [f"preload_test_{i}_{np.random.randint(1e6)}" for i in range(3)]
for vs_id in vector_store_ids:
vector_store = VectorStore(
identifier=vs_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
await vector_io_adapter.register_vector_store(vector_store)
for vs_id in vector_store_ids:
assert vs_id in vector_io_adapter.cache
await vector_io_adapter.shutdown()
await vector_io_adapter.initialize()
for vs_id in vector_store_ids:
assert vs_id in vector_io_adapter.cache
for vs_id in vector_store_ids:
loaded_index = await vector_io_adapter._get_and_cache_vector_store_index(vs_id)
assert loaded_index is not None
assert loaded_index.vector_store.identifier == vs_id
await vector_io_adapter.shutdown()
async def test_kvstore_none_raises_runtime_error(vector_io_adapter):
"""
Test that accessing vector stores with uninitialized kvstore raises RuntimeError.
Verifies that a RuntimeError is raised instead of an assertion failure when kvstore is None.
"""
await vector_io_adapter.initialize()
vector_store_id = f"kvstore_none_test_{np.random.randint(1e6)}"
vector_store = VectorStore(
identifier=vector_store_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
await vector_io_adapter.register_vector_store(vector_store)
vector_io_adapter.cache.clear()
vector_io_adapter.kvstore = None
with pytest.raises(RuntimeError, match="KVStore not initialized"):
await vector_io_adapter._get_and_cache_vector_store_index(vector_store_id)
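A self-contained sketch of the cache-miss fallback the three tests above describe (a plain dict stands in for the real KV store; class and method names are illustrative):
class LazyIndexCache:
    def __init__(self, kvstore: dict | None):
        self.kvstore = kvstore
        self.cache: dict = {}

    def get_index(self, vector_store_id: str):
        # Serve from the in-memory cache when possible
        if vector_store_id in self.cache:
            return self.cache[vector_store_id]
        # Raise a proper error instead of asserting when persistence is unavailable
        if self.kvstore is None:
            raise RuntimeError("KVStore not initialized")
        # Load on demand from persistent storage, then cache for next time
        index = self.kvstore[vector_store_id]
        self.cache[vector_store_id] = index
        return index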
async def test_register_and_unregister_vector_store(vector_io_adapter):
unique_id = f"foo_db_{np.random.randint(1e6)}"
dummy = VectorStore(
@@ -129,7 +222,7 @@ async def test_insert_chunks_missing_db_raises(vector_io_adapter):
async def test_insert_chunks_with_missing_document_id(vector_io_adapter):
"""Ensure no KeyError when document_id is missing or in different places."""
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack_api import Chunk, ChunkMetadata
fake_index = AsyncMock()
vector_io_adapter.cache["db1"] = fake_index
@@ -162,10 +255,9 @@ async def test_insert_chunks_with_missing_document_id(vector_io_adapter):
async def test_document_id_with_invalid_type_raises_error():
"""Ensure TypeError is raised when document_id is not a string."""
from llama_stack.apis.vector_io import Chunk
# Integer document_id should raise TypeError
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack_api import Chunk
chunk = Chunk(content="test", chunk_id=generate_chunk_id("test", "test"), metadata={"document_id": 12345})
with pytest.raises(TypeError) as exc_info:

View file

@@ -4,8 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack_api import Chunk, ChunkMetadata, VectorStoreFileObject
# This test is a unit test for the chunk_utils.py helpers. This should only contain
# tests which are specific to this file. More general (API-level) tests should be placed in
@@ -78,3 +78,77 @@ def test_chunk_serialization():
serialized_chunk = chunk.model_dump()
assert serialized_chunk["chunk_id"] == "test-chunk-id"
assert "chunk_id" in serialized_chunk
def test_vector_store_file_object_attributes_validation():
"""Test VectorStoreFileObject validates and sanitizes attributes at input boundary."""
# Test with metadata containing lists, nested dicts, and primitives
from llama_stack_api.vector_io import VectorStoreChunkingStrategyAuto
file_obj = VectorStoreFileObject(
id="file-123",
attributes={
"tags": ["transformers", "h100-compatible", "region:us"], # List -> string
"model_name": "granite-3.3-8b", # String preserved
"score": 0.95, # Float preserved
"active": True, # Bool preserved
"count": 42, # Int -> float
"nested": {"key": "value"}, # Dict filtered out
},
chunking_strategy=VectorStoreChunkingStrategyAuto(),
created_at=1234567890,
status="completed",
vector_store_id="vs-123",
)
# Lists converted to comma-separated strings
assert file_obj.attributes["tags"] == "transformers, h100-compatible, region:us"
# Primitives preserved
assert file_obj.attributes["model_name"] == "granite-3.3-8b"
assert file_obj.attributes["score"] == 0.95
assert file_obj.attributes["active"] is True
assert file_obj.attributes["count"] == 42.0 # int -> float
# Complex types filtered out
assert "nested" not in file_obj.attributes
def test_vector_store_file_object_attributes_constraints():
"""Test VectorStoreFileObject enforces OpenAPI constraints on attributes."""
from llama_stack_api.vector_io import VectorStoreChunkingStrategyAuto
# Test max 16 properties
many_attrs = {f"key{i}": f"value{i}" for i in range(20)}
file_obj = VectorStoreFileObject(
id="file-123",
attributes=many_attrs,
chunking_strategy=VectorStoreChunkingStrategyAuto(),
created_at=1234567890,
status="completed",
vector_store_id="vs-123",
)
assert len(file_obj.attributes) == 16 # Max 16 properties
# Test max 64 char keys are filtered
long_key_attrs = {"a" * 65: "value", "valid_key": "value"}
file_obj = VectorStoreFileObject(
id="file-124",
attributes=long_key_attrs,
chunking_strategy=VectorStoreChunkingStrategyAuto(),
created_at=1234567890,
status="completed",
vector_store_id="vs-123",
)
assert "a" * 65 not in file_obj.attributes
assert "valid_key" in file_obj.attributes
# Test max 512 char string values are truncated
long_value_attrs = {"key": "x" * 600}
file_obj = VectorStoreFileObject(
id="file-125",
attributes=long_value_attrs,
chunking_strategy=VectorStoreChunkingStrategyAuto(),
created_at=1234567890,
status="completed",
vector_store_id="vs-123",
)
assert len(file_obj.attributes["key"]) == 512
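The assertions above pin down a sanitization pass roughly like the following (a sketch of the rules only; the real logic lives in VectorStoreFileObject's validation):
def sanitize_attributes(attrs: dict) -> dict:
    out: dict = {}
    for key, value in attrs.items():
        # OpenAPI constraints: at most 16 properties, keys no longer than 64 chars
        if len(out) >= 16 or len(key) > 64:
            continue
        if isinstance(value, bool):  # check bool before int: bool is an int subclass
            out[key] = value
        elif isinstance(value, int):
            out[key] = float(value)  # ints normalized to float
        elif isinstance(value, float):
            out[key] = value
        elif isinstance(value, str):
            out[key] = value[:512]  # string values truncated to 512 chars
        elif isinstance(value, list):
            out[key] = ", ".join(str(v) for v in value)  # lists joined into one string
        # dicts and other complex types are dropped
    return out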

View file

@@ -8,13 +8,8 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.tools.rag_tool import RAGQueryConfig
from llama_stack.apis.vector_io import (
Chunk,
ChunkMetadata,
QueryChunksResponse,
)
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, RAGQueryConfig
class TestRagQuery:

View file

@@ -13,12 +13,6 @@ from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from llama_stack.apis.inference.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import (
URL,
VectorStoreWithIndex,
@@ -27,6 +21,7 @@ from llama_stack.providers.utils.memory.vector_store import (
make_overlapped_chunks,
)
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack_api import Chunk, OpenAIEmbeddingData, OpenAIEmbeddingsRequestWithExtraBody, RAGDocument
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
# Depending on the machine, this can get parsed a couple of ways

View file

@@ -7,16 +7,15 @@
import pytest
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.datatypes import VectorStoreWithOwner
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends
from llama_stack.core.store.registry import (
KEY_FORMAT,
CachedDiskDistributionRegistry,
DiskDistributionRegistry,
)
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
from llama_stack_api import Model, VectorStore
@pytest.fixture
@@ -304,8 +303,8 @@ async def test_double_registration_different_objects(disk_dist_registry):
async def test_double_registration_with_cache(cached_disk_dist_registry):
"""Test double registration behavior with caching enabled."""
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import ModelWithOwner
from llama_stack_api import ModelType
model1 = ModelWithOwner(
identifier="test_model",

View file

@@ -5,9 +5,9 @@
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import ModelWithOwner, User
from llama_stack.core.store.registry import CachedDiskDistributionRegistry
from llama_stack_api import ModelType
async def test_registry_cache_with_acl(cached_disk_dist_registry):

View file

@@ -10,11 +10,10 @@ import pytest
import yaml
from pydantic import TypeAdapter, ValidationError
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.datatypes import AccessRule, ModelWithOwner, User
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack_api import Api, ModelType
class AsyncMock(MagicMock):

View file

@@ -144,7 +144,7 @@ def middleware_with_mocks(mock_auth_endpoint):
middleware = AuthenticationMiddleware(mock_app, auth_config, {})
# Mock the route_impls to simulate finding routes with required scopes
from llama_stack.schema_utils import WebMethod
from llama_stack_api import WebMethod
routes = {
("POST", "/test/scoped"): WebMethod(route="/test/scoped", method="POST", required_scope="test.read"),

View file

@@ -15,7 +15,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod
from llama_stack.core.server.quota import QuotaMiddleware
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import register_kvstore_backends
from llama_stack.core.storage.kvstore import register_kvstore_backends
@pytest.fixture

View file

@@ -11,7 +11,6 @@ from unittest.mock import AsyncMock, MagicMock
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Inference
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
from llama_stack.core.resolver import resolve_impls
from llama_stack.core.routers.inference import InferenceRouter
@@ -25,9 +24,9 @@ from llama_stack.core.storage.datatypes import (
SqlStoreReference,
StorageConfig,
)
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
from llama_stack.providers.utils.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.core.storage.kvstore import register_kvstore_backends
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack_api import Inference, InlineProviderSpec, ProviderSpec
def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:

View file

@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack_api import Conversation, SamplingStrategy
from llama_stack_api.schema_utils import (
clear_dynamic_schema_types,
get_registered_schema_info,
iter_dynamic_schema_types,
iter_json_schema_types,
iter_registered_schema_types,
register_dynamic_schema_type,
)
def test_json_schema_registry_contains_known_model() -> None:
assert Conversation in iter_json_schema_types()
def test_registered_schema_registry_contains_sampling_strategy() -> None:
registered_names = {info.name for info in iter_registered_schema_types()}
assert "SamplingStrategy" in registered_names
schema_info = get_registered_schema_info(SamplingStrategy)
assert schema_info is not None
assert schema_info.name == "SamplingStrategy"
def test_dynamic_schema_registration_round_trip() -> None:
existing_models = tuple(iter_dynamic_schema_types())
clear_dynamic_schema_types()
try:
class TemporaryModel(BaseModel):
foo: str
register_dynamic_schema_type(TemporaryModel)
assert TemporaryModel in iter_dynamic_schema_types()
clear_dynamic_schema_types()
assert TemporaryModel not in iter_dynamic_schema_types()
finally:
for model in existing_models:
register_dynamic_schema_type(model)

View file

@@ -12,7 +12,7 @@ from pydantic import ValidationError
from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import AuthenticationRequiredError
from llama_stack.core.server.server import translate_exception
from llama_stack.core.server.server import remove_disabled_providers, translate_exception
class TestTranslateException:
@@ -194,3 +194,70 @@ class TestTranslateException:
assert isinstance(result3, HTTPException)
assert result3.status_code == 403
assert result3.detail == "Permission denied: Access denied"
class TestRemoveDisabledProviders:
"""Test cases for the remove_disabled_providers function."""
def test_remove_explicitly_disabled_provider(self):
"""Test that providers with provider_id='__disabled__' are removed."""
config = {
"providers": {
"inference": [
{"provider_id": "openai", "provider_type": "remote::openai", "config": {}},
{"provider_id": "__disabled__", "provider_type": "remote::vllm", "config": {}},
]
}
}
result = remove_disabled_providers(config)
assert len(result["providers"]["inference"]) == 1
assert result["providers"]["inference"][0]["provider_id"] == "openai"
def test_remove_empty_provider_id(self):
"""Test that providers with empty provider_id are removed."""
config = {
"providers": {
"inference": [
{"provider_id": "openai", "provider_type": "remote::openai", "config": {}},
{"provider_id": "", "provider_type": "remote::vllm", "config": {}},
]
}
}
result = remove_disabled_providers(config)
assert len(result["providers"]["inference"]) == 1
assert result["providers"]["inference"][0]["provider_id"] == "openai"
def test_keep_models_with_none_provider_model_id(self):
"""Test that models with None provider_model_id are NOT removed."""
config = {
"registered_resources": {
"models": [
{
"model_id": "llama-3-2-3b",
"provider_id": "vllm-inference",
"model_type": "llm",
"provider_model_id": None,
"metadata": {},
},
{
"model_id": "gpt-4o-mini",
"provider_id": "openai",
"model_type": "llm",
"provider_model_id": None,
"metadata": {},
},
{
"model_id": "granite-embedding-125m",
"provider_id": "sentence-transformers",
"model_type": "embedding",
"provider_model_id": "ibm-granite/granite-embedding-125m-english",
"metadata": {"embedding_dimension": 768},
},
]
}
}
result = remove_disabled_providers(config)
assert len(result["registered_resources"]["models"]) == 3
assert result["registered_resources"]["models"][0]["model_id"] == "llama-3-2-3b"
assert result["registered_resources"]["models"][1]["model_id"] == "gpt-4o-mini"
assert result["registered_resources"]["models"][2]["model_id"] == "granite-embedding-125m"

View file

@@ -10,8 +10,8 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.server.server import create_dynamic_typed_route, create_sse_event, sse_generator
from llama_stack_api import PaginatedResponse
@pytest.fixture
@@ -104,12 +104,18 @@ async def test_paginated_response_url_setting():
route_handler = create_dynamic_typed_route(mock_api_method, "get", "/test/route")
# Mock minimal request
# Mock minimal request with proper state object
request = MagicMock()
request.scope = {"user_attributes": {}, "principal": ""}
request.headers = {}
request.body = AsyncMock(return_value=b"")
# Create a simple state object without auto-generating attributes
class MockState:
pass
request.state = MockState()
result = await route_handler(request)
assert isinstance(result, PaginatedResponse)

View file

@@ -11,8 +11,8 @@ Tests the new input_schema and output_schema fields.
from pydantic import ValidationError
from llama_stack.apis.tools import ToolDef
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
from llama_stack_api import ToolDef
class TestToolDefValidation:

View file

@@ -8,16 +8,16 @@ import time
import pytest
from llama_stack.apis.inference import (
from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAIUserMessageParam,
Order,
)
from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
@pytest.fixture(autouse=True)

View file

@@ -5,8 +5,8 @@
# the root directory of this source tree.
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
from llama_stack.core.storage.kvstore.sqlite.sqlite import SqliteKVStoreImpl
async def test_memory_kvstore_persistence_behavior():

View file

@@ -10,15 +10,10 @@ from uuid import uuid4
import pytest
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseObject,
)
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack_api import OpenAIMessageParam, OpenAIResponseInput, OpenAIResponseObject, OpenAIUserMessageParam, Order
def build_store(db_path: str, policy: list | None = None) -> ResponsesStore:
@@ -46,7 +41,7 @@ def create_test_response_object(
def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInput:
"""Helper to create a test response input."""
from llama_stack.apis.agents.openai_responses import OpenAIResponseMessage
from llama_stack_api import OpenAIResponseMessage
return OpenAIResponseMessage(
id=input_id,

View file

@@ -9,9 +9,9 @@ from tempfile import TemporaryDirectory
import pytest
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
async def test_sqlite_sqlstore():
@@ -65,6 +65,38 @@ async def test_sqlite_sqlstore():
assert result.has_more is False
async def test_sqlstore_upsert_support():
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/upsert.db"
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
await store.create_table(
"items",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"value": ColumnType.STRING,
"updated_at": ColumnType.INTEGER,
},
)
await store.upsert(
table="items",
data={"id": "item_1", "value": "first", "updated_at": 1},
conflict_columns=["id"],
)
row = await store.fetch_one("items", {"id": "item_1"})
assert row == {"id": "item_1", "value": "first", "updated_at": 1}
await store.upsert(
table="items",
data={"id": "item_1", "value": "second", "updated_at": 2},
conflict_columns=["id"],
update_columns=["value", "updated_at"],
)
row = await store.fetch_one("items", {"id": "item_1"})
assert row == {"id": "item_1", "value": "second", "updated_at": 2}
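The upsert semantics exercised above map onto standard SQL conflict handling; a stdlib-only sketch of the equivalent statement (sqlite3 stands in for the SQLAlchemy-backed store):
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE items (id TEXT PRIMARY KEY, value TEXT, updated_at INTEGER)")
upsert_sql = (
    "INSERT INTO items (id, value, updated_at) VALUES (?, ?, ?) "
    "ON CONFLICT(id) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at"
)
conn.execute(upsert_sql, ("item_1", "first", 1))
conn.execute(upsert_sql, ("item_1", "second", 2))
# The second execute updates the existing row rather than inserting a duplicate
assert conn.execute("SELECT value, updated_at FROM items WHERE id = 'item_1'").fetchone() == ("second", 2)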
async def test_sqlstore_pagination_basic():
"""Test basic pagination functionality at the SQL store level."""
with TemporaryDirectory() as tmp_dir:

View file

@@ -10,13 +10,13 @@ from unittest.mock import patch
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import User
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord
from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack_api.internal.sqlstore import ColumnType
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
"""Test that fetch_all works correctly with where_sql for access control"""
with TemporaryDirectory() as tmp_dir:
@@ -78,7 +78,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
assert row["title"] == "User Document"
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_sql_policy_consistency(mock_get_authenticated_user):
"""Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic"""
with TemporaryDirectory() as tmp_dir:
@@ -164,7 +164,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
)
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
"""Test that user attributes are properly captured during insert"""
with TemporaryDirectory() as tmp_dir: