Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)

commit 893d49c59e
Merge branch 'main' into feat/gunicorn-production-server

2086 changed files with 133277 additions and 643859 deletions
@@ -5,18 +5,7 @@
# the root directory of this source tree.

from llama_stack.apis.conversations.conversations import (
    Conversation,
    ConversationCreateRequest,
    ConversationItem,
    ConversationItemList,
)


def test_conversation_create_request_defaults():
    request = ConversationCreateRequest()
    assert request.items == []
    assert request.metadata == {}

from llama_stack_api import Conversation, ConversationItem, ConversationItemList


def test_conversation_model_defaults():
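The hunk above is representative of the import migration that runs through these test diffs: symbols previously imported from nested llama_stack.apis.* modules are now imported from the flat llama_stack_api package. A minimal before/after sketch, illustrative only and limited to names that appear in the hunk:

# Before (removed by this diff): deep module path
from llama_stack.apis.conversations.conversations import Conversation, ConversationItem

# After: the flat package, assumed here to re-export the same Pydantic models
from llama_stack_api import Conversation, ConversationItem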
@@ -12,10 +12,6 @@ from openai.types.conversations.conversation import Conversation as OpenAIConver
from openai.types.conversations.conversation_item import ConversationItem as OpenAIConversationItem
from pydantic import TypeAdapter

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInputMessageContentText,
    OpenAIResponseMessage,
)
from llama_stack.core.conversations.conversations import (
    ConversationServiceConfig,
    ConversationServiceImpl,

@@ -27,7 +23,8 @@ from llama_stack.core.storage.datatypes import (
    SqlStoreReference,
    StorageConfig,
)
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack_api import OpenAIResponseInputMessageContentText, OpenAIResponseMessage


@pytest.fixture

@@ -41,6 +38,9 @@ async def service():
        },
        stores=ServerStoresConfig(
            conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
            metadata=None,
            inference=None,
            prompts=None,
        ),
    )
    register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})

@@ -145,6 +145,9 @@ async def test_policy_configuration():
        },
        stores=ServerStoresConfig(
            conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
            metadata=None,
            inference=None,
            prompts=None,
        ),
    )
    register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
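The fixtures above register a test SQL backend and point the conversations store at it. The sketch below is an editorial illustration of that wiring, not part of the diff; the SQLite db_path is a placeholder, while the class and function names are taken from the hunks above.

# Illustrative sketch: register a SQLite-backed "sql_test" backend for the conversations store.
from llama_stack.core.storage.datatypes import (
    ServerStoresConfig,
    SqliteSqlStoreConfig,
    SqlStoreReference,
    StorageConfig,
)
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends

storage = StorageConfig(
    backends={"sql_test": SqliteSqlStoreConfig(db_path="/tmp/conversations_test.db")},  # placeholder path
    stores=ServerStoresConfig(
        conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
        metadata=None,
        inference=None,
        prompts=None,
    ),
)
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})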
@@ -6,10 +6,9 @@

from unittest.mock import AsyncMock

from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import ListShieldsResponse, Shield
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.core.routers.safety import SafetyRouter
from llama_stack_api import ListShieldsResponse, ModerationObject, ModerationObjectResults, Shield


async def test_run_moderation_uses_default_shield_when_model_missing():
@@ -8,8 +8,13 @@ from unittest.mock import AsyncMock, Mock

import pytest

from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody
from llama_stack.core.routers.vector_io import VectorIORouter
from llama_stack_api import (
    ModelNotFoundError,
    ModelType,
    ModelTypeError,
    OpenAICreateVectorStoreRequestWithExtraBody,
)


async def test_single_provider_auto_selection():

@@ -21,6 +26,7 @@ async def test_single_provider_auto_selection():
            Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
        ]
    )
    mock_routing_table.get_object_by_identifier = AsyncMock(return_value=Mock(model_type=ModelType.embedding))
    mock_routing_table.register_vector_store = AsyncMock(
        return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
    )

@@ -48,6 +54,7 @@ async def test_create_vector_stores_multiple_providers_missing_provider_id_error
            Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
        ]
    )
    mock_routing_table.get_object_by_identifier = AsyncMock(return_value=Mock(model_type=ModelType.embedding))
    router = VectorIORouter(mock_routing_table)
    request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
        {"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}

@@ -55,3 +62,94 @@ async def test_create_vector_stores_multiple_providers_missing_provider_id_error

    with pytest.raises(ValueError, match="Multiple vector_io providers available"):
        await router.openai_create_vector_store(request)


async def test_update_vector_store_provider_id_change_fails():
    """Test that updating a vector store with a different provider_id fails with clear error."""
    mock_routing_table = Mock()

    # Mock an existing vector store with provider_id "faiss"
    mock_existing_store = Mock()
    mock_existing_store.provider_id = "inline::faiss"
    mock_existing_store.identifier = "vs_123"

    mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
    mock_routing_table.get_provider_impl = AsyncMock(
        return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
    )

    router = VectorIORouter(mock_routing_table)

    # Try to update with different provider_id in metadata - this should fail
    with pytest.raises(ValueError, match="provider_id cannot be changed after vector store creation"):
        await router.openai_update_vector_store(
            vector_store_id="vs_123",
            name="updated_name",
            metadata={"provider_id": "inline::sqlite"},  # Different provider_id
        )

    # Verify the existing store was looked up to check provider_id
    mock_routing_table.get_object_by_identifier.assert_called_once_with("vector_store", "vs_123")

    # Provider should not be called since validation failed
    mock_routing_table.get_provider_impl.assert_not_called()


async def test_update_vector_store_same_provider_id_succeeds():
    """Test that updating a vector store with the same provider_id succeeds."""
    mock_routing_table = Mock()

    # Mock an existing vector store with provider_id "faiss"
    mock_existing_store = Mock()
    mock_existing_store.provider_id = "inline::faiss"
    mock_existing_store.identifier = "vs_123"

    mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
    mock_routing_table.get_provider_impl = AsyncMock(
        return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
    )

    router = VectorIORouter(mock_routing_table)

    # Update with same provider_id should succeed
    await router.openai_update_vector_store(
        vector_store_id="vs_123",
        name="updated_name",
        metadata={"provider_id": "inline::faiss"},  # Same provider_id
    )

    # Verify the provider update method was called
    mock_routing_table.get_provider_impl.assert_called_once_with("vs_123")
    provider = await mock_routing_table.get_provider_impl("vs_123")
    provider.openai_update_vector_store.assert_called_once_with(
        vector_store_id="vs_123", name="updated_name", expires_after=None, metadata={"provider_id": "inline::faiss"}
    )


async def test_create_vector_store_with_unknown_embedding_model_raises_error():
    """Test that creating a vector store with an unknown embedding model raises ModelNotFoundError."""
    mock_routing_table = Mock(impls_by_provider_id={"provider": "mock"})
    mock_routing_table.get_object_by_identifier = AsyncMock(return_value=None)

    router = VectorIORouter(mock_routing_table)
    request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
        {"embedding_model": "unknown-model", "embedding_dimension": 384}
    )

    with pytest.raises(ModelNotFoundError, match="Model 'unknown-model' not found"):
        await router.openai_create_vector_store(request)


async def test_create_vector_store_with_wrong_model_type_raises_error():
    """Test that creating a vector store with a non-embedding model raises ModelTypeError."""
    mock_routing_table = Mock(impls_by_provider_id={"provider": "mock"})
    mock_routing_table.get_object_by_identifier = AsyncMock(return_value=Mock(model_type=ModelType.llm))

    router = VectorIORouter(mock_routing_table)
    request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
        {"embedding_model": "text-model", "embedding_dimension": 384}
    )

    with pytest.raises(ModelTypeError, match="Model 'text-model' is of type"):
        await router.openai_create_vector_store(request)
@@ -10,11 +10,10 @@ from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.models import ListModelsResponse, Model, ModelType
from llama_stack.apis.shields import ListShieldsResponse, Shield
from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, StorageConfig, VectorStoresConfig
from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.stack import validate_safety_config, validate_vector_stores_config
from llama_stack.providers.datatypes import Api
from llama_stack.core.storage.datatypes import ServerStoresConfig, StorageConfig
from llama_stack_api import Api, ListModelsResponse, ListShieldsResponse, Model, ModelType, Shield


class TestVectorStoresValidation:

@@ -23,7 +22,15 @@ class TestVectorStoresValidation:
        run_config = StackRunConfig(
            image_name="test",
            providers={},
            storage=StorageConfig(backends={}, stores={}),
            storage=StorageConfig(
                backends={},
                stores=ServerStoresConfig(
                    metadata=None,
                    inference=None,
                    conversations=None,
                    prompts=None,
                ),
            ),
            vector_stores=VectorStoresConfig(
                default_provider_id="faiss",
                default_embedding_model=QualifiedModel(

@@ -43,7 +50,15 @@ class TestVectorStoresValidation:
        run_config = StackRunConfig(
            image_name="test",
            providers={},
            storage=StorageConfig(backends={}, stores={}),
            storage=StorageConfig(
                backends={},
                stores=ServerStoresConfig(
                    metadata=None,
                    inference=None,
                    conversations=None,
                    prompts=None,
                ),
            ),
            vector_stores=VectorStoresConfig(
                default_provider_id="faiss",
                default_embedding_model=QualifiedModel(
@@ -10,14 +10,6 @@ from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup
from llama_stack.core.datatypes import RegistryEntrySource
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable

@@ -25,6 +17,21 @@ from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.core.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.core.routing_tables.shields import ShieldsRoutingTable
from llama_stack.core.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack_api import (
    URL,
    Api,
    Dataset,
    DatasetPurpose,
    ListToolDefsResponse,
    Model,
    ModelNotFoundError,
    ModelType,
    NumberType,
    Shield,
    ToolDef,
    ToolGroup,
    URIDataSource,
)


class Impl:

@@ -130,7 +137,7 @@ class ToolGroupsImpl(Impl):
    async def unregister_toolgroup(self, toolgroup_id: str):
        return toolgroup_id

    async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
    async def list_runtime_tools(self, toolgroup_id, mcp_endpoint, authorization=None):
        return ListToolDefsResponse(
            data=[
                ToolDef(
@@ -11,8 +11,15 @@ from unittest.mock import patch

import pytest
from openai import AsyncOpenAI

from llama_stack.testing.api_recorder import (
    APIRecordingMode,
    ResponseStorage,
    api_recording,
    normalize_inference_request,
)

# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
from llama_stack_api import (
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChoice,

@@ -20,12 +27,6 @@ from llama_stack.apis.inference import (
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
)
from llama_stack.testing.api_recorder import (
    APIRecordingMode,
    ResponseStorage,
    api_recording,
    normalize_inference_request,
)


@pytest.fixture
@@ -22,7 +22,7 @@ from llama_stack.core.storage.datatypes import (
    SqlStoreReference,
    StorageConfig,
)
from llama_stack.providers.datatypes import ProviderSpec
from llama_stack_api import ProviderSpec


class SampleConfig(BaseModel):

@@ -312,7 +312,7 @@ pip_packages:
        """Test loading an external provider from a module (success path)."""
        from types import SimpleNamespace

        from llama_stack.providers.datatypes import Api, ProviderSpec
        from llama_stack_api import Api, ProviderSpec

        # Simulate a provider module with get_provider_spec
        fake_spec = ProviderSpec(

@@ -396,7 +396,7 @@ pip_packages:
    def test_external_provider_from_module_building(self, mock_providers):
        """Test loading an external provider from a module during build (building=True, partial spec)."""
        from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
        from llama_stack.providers.datatypes import Api
        from llama_stack_api import Api

        # No importlib patch needed, should not import module when type of `config` is BuildConfig or DistributionSpec
        build_config = BuildConfig(

@@ -457,7 +457,7 @@ class TestGetExternalProvidersFromModule:
        from types import SimpleNamespace

        from llama_stack.core.distribution import get_external_providers_from_module
        from llama_stack.providers.datatypes import ProviderSpec
        from llama_stack_api import ProviderSpec

        fake_spec = ProviderSpec(
            api=Api.inference,

@@ -594,7 +594,7 @@ class TestGetExternalProvidersFromModule:
        from types import SimpleNamespace

        from llama_stack.core.distribution import get_external_providers_from_module
        from llama_stack.providers.datatypes import ProviderSpec
        from llama_stack_api import ProviderSpec

        spec1 = ProviderSpec(
            api=Api.inference,

@@ -642,7 +642,7 @@ class TestGetExternalProvidersFromModule:
        from types import SimpleNamespace

        from llama_stack.core.distribution import get_external_providers_from_module
        from llama_stack.providers.datatypes import ProviderSpec
        from llama_stack_api import ProviderSpec

        spec1 = ProviderSpec(
            api=Api.inference,

@@ -690,7 +690,7 @@ class TestGetExternalProvidersFromModule:
        from types import SimpleNamespace

        from llama_stack.core.distribution import get_external_providers_from_module
        from llama_stack.providers.datatypes import ProviderSpec
        from llama_stack_api import ProviderSpec

        # Module returns both inline and remote variants
        spec1 = ProviderSpec(

@@ -829,7 +829,7 @@ class TestGetExternalProvidersFromModule:
        from types import SimpleNamespace

        from llama_stack.core.distribution import get_external_providers_from_module
        from llama_stack.providers.datatypes import ProviderSpec
        from llama_stack_api import ProviderSpec

        inference_spec = ProviderSpec(
            api=Api.inference,
tests/unit/distribution/test_stack_list.py (new file, 130 lines)

@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Tests for the llama stack list command."""

import argparse
from unittest.mock import MagicMock, patch

import pytest

from llama_stack.cli.stack.list_stacks import StackListBuilds


@pytest.fixture
def list_stacks_command():
    """Create a StackListBuilds instance for testing."""
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()
    return StackListBuilds(subparsers)


@pytest.fixture
def mock_distribs_base_dir(tmp_path):
    """Create a mock DISTRIBS_BASE_DIR with some custom distributions."""
    custom_dir = tmp_path / "distributions"
    custom_dir.mkdir(parents=True, exist_ok=True)

    # Create a custom distribution
    starter_custom = custom_dir / "starter"
    starter_custom.mkdir()
    (starter_custom / "starter-build.yaml").write_text("# build config")
    (starter_custom / "starter-run.yaml").write_text("# run config")

    return custom_dir


@pytest.fixture
def mock_distro_dir(tmp_path):
    """Create a mock distributions directory with built-in distributions."""
    distro_dir = tmp_path / "src" / "llama_stack" / "distributions"
    distro_dir.mkdir(parents=True, exist_ok=True)

    # Create some built-in distributions
    for distro_name in ["starter", "nvidia", "dell"]:
        distro_path = distro_dir / distro_name
        distro_path.mkdir()
        (distro_path / "build.yaml").write_text("# build config")
        (distro_path / "run.yaml").write_text("# run config")

    return distro_dir


def create_path_mock(builtin_dist_dir):
    """Create a properly mocked Path object that returns builtin_dist_dir for the distributions path."""
    mock_parent_parent_parent = MagicMock()
    mock_parent_parent_parent.__truediv__ = (
        lambda self, other: builtin_dist_dir if other == "distributions" else MagicMock()
    )

    mock_path = MagicMock()
    mock_path.parent.parent.parent = mock_parent_parent_parent

    return mock_path


class TestStackList:
    """Test suite for llama stack list command."""

    def test_builtin_distros_shown_without_running(self, list_stacks_command, mock_distro_dir, tmp_path):
        """Test that built-in distributions are shown even before running them."""
        mock_path = create_path_mock(mock_distro_dir)

        # Mock DISTRIBS_BASE_DIR to be a non-existent directory (no custom distributions)
        with patch("llama_stack.cli.stack.list_stacks.DISTRIBS_BASE_DIR", tmp_path / "nonexistent"):
            with patch("llama_stack.cli.stack.list_stacks.Path") as mock_path_class:
                mock_path_class.return_value = mock_path

                distributions = list_stacks_command._get_distribution_dirs()

                # Verify built-in distributions are found
                assert len(distributions) > 0, "Should find built-in distributions"
                assert all(source_type == "built-in" for _, source_type in distributions.values()), (
                    "All should be built-in"
                )

                # Check specific distributions we created
                assert "starter" in distributions
                assert "nvidia" in distributions
                assert "dell" in distributions

    def test_custom_distribution_overrides_builtin(self, list_stacks_command, mock_distro_dir, mock_distribs_base_dir):
        """Test that custom distributions override built-in ones with the same name."""
        mock_path = create_path_mock(mock_distro_dir)

        with patch("llama_stack.cli.stack.list_stacks.DISTRIBS_BASE_DIR", mock_distribs_base_dir):
            with patch("llama_stack.cli.stack.list_stacks.Path") as mock_path_class:
                mock_path_class.return_value = mock_path

                distributions = list_stacks_command._get_distribution_dirs()

                # "starter" should exist and be marked as "custom" (not "built-in")
                # because the custom version overrides the built-in one
                assert "starter" in distributions
                _, source_type = distributions["starter"]
                assert source_type == "custom", "Custom distribution should override built-in"

    def test_hidden_directories_ignored(self, list_stacks_command, mock_distro_dir, tmp_path):
        """Test that hidden directories (starting with .) are ignored."""
        # Add a hidden directory
        hidden_dir = mock_distro_dir / ".hidden"
        hidden_dir.mkdir()
        (hidden_dir / "build.yaml").write_text("# build")

        # Add a __pycache__ directory
        pycache_dir = mock_distro_dir / "__pycache__"
        pycache_dir.mkdir()

        mock_path = create_path_mock(mock_distro_dir)

        with patch("llama_stack.cli.stack.list_stacks.DISTRIBS_BASE_DIR", tmp_path / "nonexistent"):
            with patch("llama_stack.cli.stack.list_stacks.Path") as mock_path_class:
                mock_path_class.return_value = mock_path

                distributions = list_stacks_command._get_distribution_dirs()

                assert ".hidden" not in distributions
                assert "__pycache__" not in distributions
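The assertions in the new file above unpack the return value of _get_distribution_dirs() as (path, source_type) pairs keyed by distribution name. A sketch of the implied shape, with made-up values and no claim about the actual implementation:

# Hypothetical illustration of the mapping the tests above rely on:
# distribution name -> (directory path, source type), source type being "built-in" or "custom".
from pathlib import Path

distributions: dict[str, tuple[Path, str]] = {
    "starter": (Path("/home/user/.llama/distributions/starter"), "custom"),      # example only
    "nvidia": (Path("/repo/src/llama_stack/distributions/nvidia"), "built-in"),  # example only
}

path, source_type = distributions["starter"]
assert source_type == "custom"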
@@ -7,16 +7,14 @@

import pytest

from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.providers.inline.files.localfs import (
    LocalfsFilesImpl,
    LocalfsFilesImplConfig,
)
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack_api import OpenAIFilePurpose, Order, ResourceNotFoundError


class MockUploadFile:
@@ -6,9 +6,9 @@

import pytest

from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
from llama_stack.core.storage.kvstore.sqlite import SqliteKVStoreImpl
from llama_stack.core.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl


@pytest.fixture(scope="function")
@ -1,303 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionMessage,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
SystemMessageBehavior,
|
||||
ToolCall,
|
||||
ToolConfig,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
chat_completion_request_to_prompt,
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
MODEL = "Llama3.1-8B-Instruct"
|
||||
MODEL3_2 = "Llama3.2-3B-Instruct"
|
||||
|
||||
|
||||
async def test_system_default():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert messages[-1].content == content
|
||||
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
|
||||
async def test_system_builtin_only():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert messages[-1].content == content
|
||||
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
|
||||
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
|
||||
async def test_system_custom_only():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "str",
|
||||
"description": "param1 description",
|
||||
},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 3
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_system_custom_and_builtin():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "str",
|
||||
"description": "param1 description",
|
||||
},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 3
|
||||
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_completion_message_encoding():
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL3_2,
|
||||
messages=[
|
||||
UserMessage(content="hello"),
|
||||
CompletionMessage(
|
||||
content="",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_name="custom1",
|
||||
arguments='{"param1": "value1"}', # arguments must be a JSON string
|
||||
call_id="123",
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "str",
|
||||
"description": "param1 description",
|
||||
},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
|
||||
)
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
assert '[custom1(param1="value1")]' in prompt
|
||||
|
||||
request.model = MODEL
|
||||
request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
|
||||
|
||||
|
||||
async def test_user_provided_system_message():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
|
||||
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_replace_system_message_behavior_builtin_tools():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format=ToolPromptFormat.python_list,
|
||||
system_message_behavior=SystemMessageBehavior.replace,
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
assert len(messages) == 2
|
||||
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_replace_system_message_behavior_custom_tools():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "str",
|
||||
"description": "param1 description",
|
||||
},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format=ToolPromptFormat.python_list,
|
||||
system_message_behavior=SystemMessageBehavior.replace,
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_replace_system_message_behavior_custom_tools_with_template():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate {{ function_description }}"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "str",
|
||||
"description": "param1 description",
|
||||
},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format=ToolPromptFormat.python_list,
|
||||
system_message_behavior=SystemMessageBehavior.replace,
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
|
||||
# function description is present in the system prompt
|
||||
assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
|
||||
assert messages[-1].content == content
|
||||
|
|
@@ -18,7 +18,7 @@ from llama_stack.core.storage.datatypes import (
    SqlStoreReference,
    StorageConfig,
)
from llama_stack.providers.utils.kvstore import register_kvstore_backends
from llama_stack.core.storage.kvstore import register_kvstore_backends


@pytest.fixture
@ -1,347 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import Session
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.providers.inline.agents.meta_reference.persistence import (
|
||||
AgentPersistence,
|
||||
AgentSessionInfo,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kvstore():
|
||||
return AsyncMock(spec=KVStore)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_policy():
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_persistence(mock_kvstore, mock_policy):
|
||||
return AgentPersistence(agent_id="test-agent-123", kvstore=mock_kvstore, policy=mock_policy)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_session():
|
||||
return AgentSessionInfo(
|
||||
session_id="session-123",
|
||||
session_name="Test Session",
|
||||
started_at=datetime.now(UTC),
|
||||
owner=User(principal="user-123", attributes=None),
|
||||
turns=[],
|
||||
identifier="test-session",
|
||||
type="session",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_session_json(sample_session):
|
||||
return sample_session.model_dump_json()
|
||||
|
||||
|
||||
class TestAgentPersistenceListSessions:
|
||||
def setup_mock_kvstore(self, mock_kvstore, session_keys=None, turn_keys=None, invalid_keys=None, custom_data=None):
|
||||
"""Helper to setup mock kvstore with sessions, turns, and custom/invalid data
|
||||
|
||||
Args:
|
||||
mock_kvstore: The mock KVStore object
|
||||
session_keys: List of session keys or dict mapping keys to custom session data
|
||||
turn_keys: List of turn keys or dict mapping keys to custom turn data
|
||||
invalid_keys: Dict mapping keys to invalid/corrupt data
|
||||
custom_data: Additional custom data to add to the mock responses
|
||||
"""
|
||||
all_keys = []
|
||||
mock_data = {}
|
||||
|
||||
# session keys
|
||||
if session_keys:
|
||||
if isinstance(session_keys, dict):
|
||||
all_keys.extend(session_keys.keys())
|
||||
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in session_keys.items()})
|
||||
else:
|
||||
all_keys.extend(session_keys)
|
||||
for key in session_keys:
|
||||
session_id = key.split(":")[-1]
|
||||
mock_data[key] = json.dumps(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"session_name": f"Session {session_id}",
|
||||
"started_at": datetime.now(UTC).isoformat(),
|
||||
"turns": [],
|
||||
}
|
||||
)
|
||||
|
||||
# turn keys
|
||||
if turn_keys:
|
||||
if isinstance(turn_keys, dict):
|
||||
all_keys.extend(turn_keys.keys())
|
||||
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in turn_keys.items()})
|
||||
else:
|
||||
all_keys.extend(turn_keys)
|
||||
for key in turn_keys:
|
||||
parts = key.split(":")
|
||||
session_id = parts[-2]
|
||||
turn_id = parts[-1]
|
||||
mock_data[key] = json.dumps(
|
||||
{
|
||||
"turn_id": turn_id,
|
||||
"session_id": session_id,
|
||||
"input_messages": [],
|
||||
"started_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
if invalid_keys:
|
||||
all_keys.extend(invalid_keys.keys())
|
||||
mock_data.update(invalid_keys)
|
||||
|
||||
if custom_data:
|
||||
mock_data.update(custom_data)
|
||||
|
||||
values_list = list(mock_data.values())
|
||||
mock_kvstore.values_in_range.return_value = values_list
|
||||
|
||||
async def mock_get(key):
|
||||
return mock_data.get(key)
|
||||
|
||||
mock_kvstore.get.side_effect = mock_get
|
||||
|
||||
return mock_data
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scenario",
|
||||
[
|
||||
{
|
||||
# from this issue: https://github.com/meta-llama/llama-stack/issues/3048
|
||||
"name": "reported_bug",
|
||||
"session_keys": ["session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
|
||||
"turn_keys": [
|
||||
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad"
|
||||
],
|
||||
"expected_sessions": ["1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
|
||||
},
|
||||
{
|
||||
"name": "basic_filtering",
|
||||
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
|
||||
"turn_keys": ["session:test-agent-123:session-1:turn-1", "session:test-agent-123:session-1:turn-2"],
|
||||
"expected_sessions": ["session-1", "session-2"],
|
||||
},
|
||||
{
|
||||
"name": "multiple_turns_per_session",
|
||||
"session_keys": ["session:test-agent-123:session-456"],
|
||||
"turn_keys": [
|
||||
"session:test-agent-123:session-456:turn-789",
|
||||
"session:test-agent-123:session-456:turn-790",
|
||||
],
|
||||
"expected_sessions": ["session-456"],
|
||||
},
|
||||
{
|
||||
"name": "multiple_sessions_with_turns",
|
||||
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
|
||||
"turn_keys": [
|
||||
"session:test-agent-123:session-1:turn-1",
|
||||
"session:test-agent-123:session-1:turn-2",
|
||||
"session:test-agent-123:session-2:turn-3",
|
||||
],
|
||||
"expected_sessions": ["session-1", "session-2"],
|
||||
},
|
||||
],
|
||||
)
|
||||
async def test_list_sessions_key_filtering(self, agent_persistence, mock_kvstore, scenario):
|
||||
self.setup_mock_kvstore(mock_kvstore, session_keys=scenario["session_keys"], turn_keys=scenario["turn_keys"])
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
|
||||
result = await agent_persistence.list_sessions()
|
||||
|
||||
assert len(result) == len(scenario["expected_sessions"])
|
||||
session_ids = {s.session_id for s in result}
|
||||
for expected_id in scenario["expected_sessions"]:
|
||||
assert expected_id in session_ids
|
||||
|
||||
# no errors should be logged
|
||||
mock_log.error.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error_scenario",
|
||||
[
|
||||
{
|
||||
"name": "invalid_json",
|
||||
"valid_keys": ["session:test-agent-123:valid-session"],
|
||||
"invalid_data": {"session:test-agent-123:invalid-json": "corrupted-json-data{"},
|
||||
"expected_valid_sessions": ["valid-session"],
|
||||
"expected_error_count": 1,
|
||||
},
|
||||
{
|
||||
"name": "missing_fields",
|
||||
"valid_keys": ["session:test-agent-123:valid-session"],
|
||||
"invalid_data": {
|
||||
"session:test-agent-123:invalid-schema": json.dumps(
|
||||
{
|
||||
"session_id": "invalid-schema",
|
||||
"session_name": "Missing Fields",
|
||||
# missing `started_at` and `turns`
|
||||
}
|
||||
)
|
||||
},
|
||||
"expected_valid_sessions": ["valid-session"],
|
||||
"expected_error_count": 1,
|
||||
},
|
||||
{
|
||||
"name": "multiple_invalid",
|
||||
"valid_keys": ["session:test-agent-123:valid-session-1", "session:test-agent-123:valid-session-2"],
|
||||
"invalid_data": {
|
||||
"session:test-agent-123:corrupted-json": "not-valid-json{",
|
||||
"session:test-agent-123:incomplete-data": json.dumps({"incomplete": "data"}),
|
||||
},
|
||||
"expected_valid_sessions": ["valid-session-1", "valid-session-2"],
|
||||
"expected_error_count": 2,
|
||||
},
|
||||
],
|
||||
)
|
||||
async def test_list_sessions_error_handling(self, agent_persistence, mock_kvstore, error_scenario):
|
||||
session_keys = {}
|
||||
for key in error_scenario["valid_keys"]:
|
||||
session_id = key.split(":")[-1]
|
||||
session_keys[key] = {
|
||||
"session_id": session_id,
|
||||
"session_name": f"Valid {session_id}",
|
||||
"started_at": datetime.now(UTC).isoformat(),
|
||||
"turns": [],
|
||||
}
|
||||
|
||||
self.setup_mock_kvstore(mock_kvstore, session_keys=session_keys, invalid_keys=error_scenario["invalid_data"])
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
|
||||
result = await agent_persistence.list_sessions()
|
||||
|
||||
# only valid sessions should be returned
|
||||
assert len(result) == len(error_scenario["expected_valid_sessions"])
|
||||
session_ids = {s.session_id for s in result}
|
||||
for expected_id in error_scenario["expected_valid_sessions"]:
|
||||
assert expected_id in session_ids
|
||||
|
||||
# error should be logged
|
||||
assert mock_log.error.call_count > 0
|
||||
assert mock_log.error.call_count == error_scenario["expected_error_count"]
|
||||
|
||||
async def test_list_sessions_empty(self, agent_persistence, mock_kvstore):
|
||||
mock_kvstore.values_in_range.return_value = []
|
||||
|
||||
result = await agent_persistence.list_sessions()
|
||||
|
||||
assert result == []
|
||||
mock_kvstore.values_in_range.assert_called_once_with(
|
||||
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
|
||||
)
|
||||
|
||||
async def test_list_sessions_properties(self, agent_persistence, mock_kvstore):
|
||||
session_data = {
|
||||
"session_id": "session-123",
|
||||
"session_name": "Test Session",
|
||||
"started_at": datetime.now(UTC).isoformat(),
|
||||
"owner": {"principal": "user-123", "attributes": None},
|
||||
"turns": [],
|
||||
}
|
||||
|
||||
self.setup_mock_kvstore(mock_kvstore, session_keys={"session:test-agent-123:session-123": session_data})
|
||||
|
||||
result = await agent_persistence.list_sessions()
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], Session)
|
||||
assert result[0].session_id == "session-123"
|
||||
assert result[0].session_name == "Test Session"
|
||||
assert result[0].turns == []
|
||||
assert hasattr(result[0], "started_at")
|
||||
|
||||
async def test_list_sessions_kvstore_exception(self, agent_persistence, mock_kvstore):
|
||||
mock_kvstore.values_in_range.side_effect = Exception("KVStore error")
|
||||
|
||||
with pytest.raises(Exception, match="KVStore error"):
|
||||
await agent_persistence.list_sessions()
|
||||
|
||||
async def test_bug_data_loss_with_real_data(self, agent_persistence, mock_kvstore):
|
||||
# tests the handling of the issue reported in: https://github.com/meta-llama/llama-stack/issues/3048
|
||||
session_data = {
|
||||
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
|
||||
"session_name": "Test Session",
|
||||
"started_at": datetime.now(UTC).isoformat(),
|
||||
"turns": [],
|
||||
}
|
||||
|
||||
turn_data = {
|
||||
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
|
||||
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
|
||||
"input_messages": [
|
||||
{"role": "user", "content": "if i had a cluster i would want to call it persistence01", "context": None}
|
||||
],
|
||||
"steps": [
|
||||
{
|
||||
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
|
||||
"step_id": "c0f797dd-3d34-4bc5-a8f4-db6af9455132",
|
||||
"started_at": "2025-08-05T14:31:50.000484Z",
|
||||
"completed_at": "2025-08-05T14:31:51.303691Z",
|
||||
"step_type": "inference",
|
||||
"model_response": {
|
||||
"role": "assistant",
|
||||
"content": "OK, I can create a cluster named 'persistence01' for you.",
|
||||
"stop_reason": "end_of_turn",
|
||||
"tool_calls": [],
|
||||
},
|
||||
}
|
||||
],
|
||||
"output_message": {
|
||||
"role": "assistant",
|
||||
"content": "OK, I can create a cluster named 'persistence01' for you.",
|
||||
"stop_reason": "end_of_turn",
|
||||
"tool_calls": [],
|
||||
},
|
||||
"output_attachments": [],
|
||||
"started_at": "2025-08-05T14:31:49.999950Z",
|
||||
"completed_at": "2025-08-05T14:31:51.305384Z",
|
||||
}
|
||||
|
||||
mock_data = {
|
||||
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d": json.dumps(session_data),
|
||||
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad": json.dumps(
|
||||
turn_data
|
||||
),
|
||||
}
|
||||
|
||||
mock_kvstore.values_in_range.return_value = list(mock_data.values())
|
||||
|
||||
async def mock_get(key):
|
||||
return mock_data.get(key)
|
||||
|
||||
mock_kvstore.get.side_effect = mock_get
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
|
||||
result = await agent_persistence.list_sessions()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].session_id == "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"
|
||||
|
||||
# confirm no errors logged
|
||||
mock_log.error.assert_not_called()
|
||||
|
||||
async def test_list_sessions_key_range_construction(self, agent_persistence, mock_kvstore):
|
||||
mock_kvstore.values_in_range.return_value = []
|
||||
|
||||
await agent_persistence.list_sessions()
|
||||
|
||||
mock_kvstore.values_in_range.assert_called_once_with(
|
||||
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
|
||||
)
|
||||
|
|
@ -1,196 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import Document
|
||||
from llama_stack.apis.common.content_types import URL, TextContentItem
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text
|
||||
|
||||
|
||||
async def test_get_raw_document_text_supports_text_mime_types():
|
||||
"""Test that the function accepts text/* mime types."""
|
||||
document = Document(content="Sample text content", mime_type="text/plain")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == "Sample text content"
|
||||
|
||||
|
||||
async def test_get_raw_document_text_supports_yaml_mime_type():
|
||||
"""Test that the function accepts application/yaml mime type."""
|
||||
yaml_content = """
|
||||
name: test
|
||||
version: 1.0
|
||||
items:
|
||||
- item1
|
||||
- item2
|
||||
"""
|
||||
|
||||
document = Document(content=yaml_content, mime_type="application/yaml")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == yaml_content
|
||||
|
||||
|
||||
async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning():
|
||||
"""Test that the function accepts text/yaml but emits a deprecation warning."""
|
||||
yaml_content = """
|
||||
name: test
|
||||
version: 1.0
|
||||
items:
|
||||
- item1
|
||||
- item2
|
||||
"""
|
||||
|
||||
document = Document(content=yaml_content, mime_type="text/yaml")
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = await get_raw_document_text(document)
|
||||
|
||||
# Check that result is correct
|
||||
assert result == yaml_content
|
||||
|
||||
# Check that exactly one warning was issued
|
||||
assert len(w) == 1
|
||||
assert issubclass(w[0].category, DeprecationWarning)
|
||||
assert "text/yaml" in str(w[0].message)
|
||||
assert "application/yaml" in str(w[0].message)
|
||||
assert "deprecated" in str(w[0].message).lower()
|
||||
|
||||
|
||||
async def test_get_raw_document_text_deprecated_text_yaml_with_url():
|
||||
"""Test that text/yaml works with URL content and emits warning."""
|
||||
yaml_content = "name: test\nversion: 1.0"
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
|
||||
mock_load.return_value = yaml_content
|
||||
|
||||
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="text/yaml")
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = await get_raw_document_text(document)
|
||||
|
||||
# Check that result is correct
|
||||
assert result == yaml_content
|
||||
mock_load.assert_called_once_with("https://example.com/config.yaml")
|
||||
|
||||
# Check that deprecation warning was issued
|
||||
assert len(w) == 1
|
||||
assert issubclass(w[0].category, DeprecationWarning)
|
||||
assert "text/yaml" in str(w[0].message)
|
||||
|
||||
|
||||
async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item():
|
||||
"""Test that text/yaml works with TextContentItem and emits warning."""
|
||||
yaml_content = "key: value\nlist:\n - item1\n - item2"
|
||||
|
||||
document = Document(content=TextContentItem(text=yaml_content), mime_type="text/yaml")
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = await get_raw_document_text(document)
|
||||
|
||||
# Check that result is correct
|
||||
assert result == yaml_content
|
||||
|
||||
# Check that deprecation warning was issued
|
||||
assert len(w) == 1
|
||||
assert issubclass(w[0].category, DeprecationWarning)
|
||||
assert "text/yaml" in str(w[0].message)
|
||||
|
||||
|
||||
async def test_get_raw_document_text_supports_json_mime_type():
|
||||
"""Test that the function accepts application/json mime type."""
|
||||
json_content = '{"name": "test", "version": "1.0", "items": ["item1", "item2"]}'
|
||||
|
||||
document = Document(content=json_content, mime_type="application/json")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == json_content
|
||||
|
||||
|
||||
async def test_get_raw_document_text_with_json_text_content_item():
|
||||
"""Test that the function handles JSON TextContentItem correctly."""
|
||||
json_content = '{"key": "value", "nested": {"array": [1, 2, 3]}}'
|
||||
|
||||
document = Document(content=TextContentItem(text=json_content), mime_type="application/json")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == json_content
|
||||
|
||||
|
||||
async def test_get_raw_document_text_rejects_unsupported_mime_types():
|
||||
"""Test that the function rejects unsupported mime types."""
|
||||
document = Document(
|
||||
content="Some content",
|
||||
mime_type="application/pdf", # Not supported
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unexpected document mime type: application/pdf"):
|
||||
await get_raw_document_text(document)
|
||||
|
||||
|
||||
async def test_get_raw_document_text_with_url_content():
|
||||
"""Test that the function handles URL content correctly."""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.text = "Content from URL"
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
|
||||
mock_load.return_value = "Content from URL"
|
||||
|
||||
document = Document(content=URL(uri="https://example.com/test.txt"), mime_type="text/plain")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == "Content from URL"
|
||||
mock_load.assert_called_once_with("https://example.com/test.txt")
|
||||
|
||||
|
||||
async def test_get_raw_document_text_with_yaml_url():
|
||||
"""Test that the function handles YAML URLs correctly."""
|
||||
yaml_content = "name: test\nversion: 1.0"
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
|
||||
mock_load.return_value = yaml_content
|
||||
|
||||
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="application/yaml")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == yaml_content
|
||||
mock_load.assert_called_once_with("https://example.com/config.yaml")
|
||||
|
||||
|
||||
async def test_get_raw_document_text_with_text_content_item():
|
||||
"""Test that the function handles TextContentItem correctly."""
|
||||
document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == "Text content item"
|
||||
|
||||
|
||||
async def test_get_raw_document_text_with_yaml_text_content_item():
|
||||
"""Test that the function handles YAML TextContentItem correctly."""
|
||||
yaml_content = "key: value\nlist:\n - item1\n - item2"
|
||||
|
||||
document = Document(content=TextContentItem(text=yaml_content), mime_type="application/yaml")
|
||||
|
||||
result = await get_raw_document_text(document)
|
||||
assert result == yaml_content
|
||||
|
||||
|
||||
async def test_get_raw_document_text_rejects_unexpected_content_type():
|
||||
"""Test that the function rejects unexpected document content types."""
|
||||
# Create a mock document that bypasses Pydantic validation
|
||||
mock_document = MagicMock(spec=Document)
|
||||
mock_document.mime_type = "text/plain"
|
||||
mock_document.content = 123 # Unexpected content type (not str, URL, or TextContentItem)
|
||||
|
||||
with pytest.raises(ValueError, match="Unexpected document content type: <class 'int'>"):
|
||||
await get_raw_document_text(mock_document)
|
||||
|
|
@ -1,325 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
)
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
|
||||
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
|
||||
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
|
||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_backends(tmp_path):
|
||||
"""Register KV and SQL store backends for testing."""
|
||||
from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
kv_path = str(tmp_path / "test_kv.db")
|
||||
sql_path = str(tmp_path / "test_sql.db")
|
||||
|
||||
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=kv_path)})
|
||||
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=sql_path)})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_apis():
|
||||
return {
|
||||
"inference_api": AsyncMock(spec=Inference),
|
||||
"vector_io_api": AsyncMock(spec=VectorIO),
|
||||
"safety_api": AsyncMock(spec=Safety),
|
||||
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
|
||||
"tool_groups_api": AsyncMock(spec=ToolGroups),
|
||||
"conversations_api": AsyncMock(spec=Conversations),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config(tmp_path):
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
|
||||
from llama_stack.providers.inline.agents.meta_reference.config import AgentPersistenceConfig
|
||||
|
||||
return MetaReferenceAgentsImplConfig(
|
||||
persistence=AgentPersistenceConfig(
|
||||
agent_state=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="agents",
|
||||
),
|
||||
responses=ResponsesStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="responses",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def agents_impl(config, mock_apis):
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
config,
|
||||
mock_apis["inference_api"],
|
||||
mock_apis["vector_io_api"],
|
||||
mock_apis["safety_api"],
|
||||
mock_apis["tool_runtime_api"],
|
||||
mock_apis["tool_groups_api"],
|
||||
mock_apis["conversations_api"],
|
||||
[],
|
||||
)
|
||||
await impl.initialize()
|
||||
yield impl
|
||||
await impl.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_config():
|
||||
return AgentConfig(
|
||||
sampling_params={
|
||||
"strategy": {"type": "greedy"},
|
||||
"max_tokens": 0,
|
||||
"repetition_penalty": 1.0,
|
||||
},
|
||||
input_shields=["string"],
|
||||
output_shields=["string"],
|
||||
toolgroups=["mcp::my_mcp_server"],
|
||||
client_tools=[
|
||||
{
|
||||
"name": "client_tool",
|
||||
"description": "Client Tool",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "string",
|
||||
"parameter_type": "string",
|
||||
"description": "string",
|
||||
"required": True,
|
||||
"default": None,
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"property1": None,
|
||||
"property2": None,
|
||||
},
|
||||
}
|
||||
],
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="json",
|
||||
tool_config={
|
||||
"tool_choice": "auto",
|
||||
"tool_prompt_format": "json",
|
||||
"system_message_behavior": "append",
|
||||
},
|
||||
max_infer_iters=10,
|
||||
model="string",
|
||||
instructions="string",
|
||||
enable_session_persistence=False,
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"property1": None,
|
||||
"property2": None,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def test_create_agent(agents_impl, sample_agent_config):
|
||||
response = await agents_impl.create_agent(sample_agent_config)
|
||||
|
||||
assert isinstance(response, AgentCreateResponse)
|
||||
assert response.agent_id is not None
|
||||
|
||||
stored_agent = await agents_impl.persistence_store.get(f"agent:{response.agent_id}")
|
||||
assert stored_agent is not None
|
||||
agent_info = AgentInfo.model_validate_json(stored_agent)
|
||||
assert agent_info.model == sample_agent_config.model
|
||||
assert agent_info.created_at is not None
|
||||
assert isinstance(agent_info.created_at, datetime)
|
||||
|
||||
|
||||
async def test_get_agent(agents_impl, sample_agent_config):
|
||||
create_response = await agents_impl.create_agent(sample_agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
agent = await agents_impl.get_agent(agent_id)
|
||||
|
||||
assert isinstance(agent, Agent)
|
||||
assert agent.agent_id == agent_id
|
||||
assert agent.agent_config.model == sample_agent_config.model
|
||||
assert agent.created_at is not None
|
||||
assert isinstance(agent.created_at, datetime)
|
||||
|
||||
|
||||
async def test_list_agents(agents_impl, sample_agent_config):
|
||||
agent1_response = await agents_impl.create_agent(sample_agent_config)
|
||||
agent2_response = await agents_impl.create_agent(sample_agent_config)
|
||||
|
||||
response = await agents_impl.list_agents()
|
||||
|
||||
assert isinstance(response, PaginatedResponse)
|
||||
assert len(response.data) == 2
|
||||
agent_ids = {agent["agent_id"] for agent in response.data}
|
||||
assert agent1_response.agent_id in agent_ids
|
||||
assert agent2_response.agent_id in agent_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
||||
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
|
||||
# Create an agent with specified persistence setting
|
||||
config = sample_agent_config.model_copy()
|
||||
config.enable_session_persistence = enable_session_persistence
|
||||
response = await agents_impl.create_agent(config)
|
||||
agent_id = response.agent_id
|
||||
|
||||
# Create a session
|
||||
session_response = await agents_impl.create_agent_session(agent_id, "test_session")
|
||||
assert session_response.session_id is not None
|
||||
|
||||
# Verify the session was stored
|
||||
session = await agents_impl.get_agents_session(session_response.session_id, agent_id)
|
||||
assert session.session_name == "test_session"
|
||||
assert session.session_id == session_response.session_id
|
||||
assert session.started_at is not None
|
||||
assert session.turns == []
|
||||
|
||||
# Delete the session
|
||||
await agents_impl.delete_agents_session(session_response.session_id, agent_id)
|
||||
|
||||
# Verify the session was deleted
|
||||
with pytest.raises(ValueError):
|
||||
await agents_impl.get_agents_session(session_response.session_id, agent_id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
||||
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
|
||||
# Create an agent with specified persistence setting
|
||||
config = sample_agent_config.model_copy()
|
||||
config.enable_session_persistence = enable_session_persistence
|
||||
response = await agents_impl.create_agent(config)
|
||||
agent_id = response.agent_id
|
||||
|
||||
# Create multiple sessions
|
||||
session1 = await agents_impl.create_agent_session(agent_id, "session1")
|
||||
session2 = await agents_impl.create_agent_session(agent_id, "session2")
|
||||
|
||||
# List sessions
|
||||
sessions = await agents_impl.list_agent_sessions(agent_id)
|
||||
assert len(sessions.data) == 2
|
||||
session_ids = {s["session_id"] for s in sessions.data}
|
||||
assert session1.session_id in session_ids
|
||||
assert session2.session_id in session_ids
|
||||
|
||||
# Delete one session
|
||||
await agents_impl.delete_agents_session(session1.session_id, agent_id)
|
||||
|
||||
# Verify the session was deleted
|
||||
with pytest.raises(ValueError):
|
||||
await agents_impl.get_agents_session(session1.session_id, agent_id)
|
||||
|
||||
# List sessions again
|
||||
sessions = await agents_impl.list_agent_sessions(agent_id)
|
||||
assert len(sessions.data) == 1
|
||||
assert session2.session_id in {s["session_id"] for s in sessions.data}
|
||||
|
||||
|
||||
async def test_delete_agent(agents_impl, sample_agent_config):
|
||||
# Create an agent
|
||||
response = await agents_impl.create_agent(sample_agent_config)
|
||||
agent_id = response.agent_id
|
||||
|
||||
# Delete the agent
|
||||
await agents_impl.delete_agent(agent_id)
|
||||
|
||||
# Verify the agent was deleted
|
||||
with pytest.raises(ValueError):
|
||||
await agents_impl.get_agent(agent_id)
|
||||
|
||||
|
||||
async def test__initialize_tools(agents_impl, sample_agent_config):
|
||||
# Mock tool_groups_api.list_tools()
|
||||
agents_impl.tool_groups_api.list_tools.return_value = ListToolDefsResponse(
|
||||
data=[
|
||||
ToolDef(
|
||||
name="story_maker",
|
||||
toolgroup_id="mcp::my_mcp_server",
|
||||
description="Make a story",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"story_title": {"type": "string", "description": "Title of the story", "title": "Story Title"},
|
||||
"input_words": {
|
||||
"type": "array",
|
||||
"description": "Input words",
|
||||
"items": {"type": "string"},
|
||||
"title": "Input Words",
|
||||
"default": [],
|
||||
},
|
||||
},
|
||||
"required": ["story_title"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
create_response = await agents_impl.create_agent(sample_agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Get an instance of ChatAgent
|
||||
chat_agent = await agents_impl._get_agent_impl(agent_id)
|
||||
assert chat_agent is not None
|
||||
assert isinstance(chat_agent, ChatAgent)
|
||||
|
||||
# Initialize tool definitions
|
||||
await chat_agent._initialize_tools()
|
||||
assert len(chat_agent.tool_defs) == 2
|
||||
|
||||
# Verify the first tool, which is a client tool
|
||||
first_tool = chat_agent.tool_defs[0]
|
||||
assert first_tool.tool_name == "client_tool"
|
||||
assert first_tool.description == "Client Tool"
|
||||
|
||||
# Verify the second tool, which is an MCP tool that has an array-type property
|
||||
second_tool = chat_agent.tool_defs[1]
|
||||
assert second_tool.tool_name == "story_maker"
|
||||
assert second_tool.description == "Make a story"
|
||||
|
||||
# Verify the input schema
|
||||
input_schema = second_tool.input_schema
|
||||
assert input_schema is not None
|
||||
assert input_schema["type"] == "object"
|
||||
|
||||
properties = input_schema["properties"]
|
||||
assert len(properties) == 2
|
||||
|
||||
# Verify a string property
|
||||
story_title = properties["story_title"]
|
||||
assert story_title["type"] == "string"
|
||||
assert story_title["description"] == "Title of the story"
|
||||
assert story_title["title"] == "Story Title"
|
||||
|
||||
# Verify an array property
|
||||
input_words = properties["input_words"]
|
||||
assert input_words["type"] == "array"
|
||||
assert input_words["description"] == "Input words"
|
||||
assert input_words["items"]["type"] == "string"
|
||||
assert input_words["title"] == "Input Words"
|
||||
assert input_words["default"] == []
|
||||
|
||||
# Verify required fields
|
||||
assert input_schema["required"] == ["story_title"]
|
||||
|
|
@@ -8,7 +8,7 @@ import os
|
|||
|
||||
import yaml
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
from llama_stack_api.inference import (
|
||||
OpenAIChatCompletion,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@@ -15,23 +15,25 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
ChoiceDeltaToolCallFunction,
|
||||
)
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
WebSearchToolTypes,
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
from llama_stack.providers.utils.responses.responses_store import (
|
||||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
from llama_stack_api import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIFile,
|
||||
OpenAIFileObject,
|
||||
OpenAISystemMessageParam,
|
||||
Prompt,
|
||||
)
|
||||
from llama_stack_api.agents import Order
|
||||
from llama_stack_api.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
|
|
@@ -41,17 +43,25 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
from llama_stack_api.openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
OpenAIResponseInputMessageContentFile,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponsePrompt,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
from llama_stack.providers.utils.responses.responses_store import (
|
||||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from llama_stack_api.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
|
||||
|
||||
|
||||
|
|
@@ -98,6 +108,19 @@ def mock_safety_api():
|
|||
return safety_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompts_api():
|
||||
prompts_api = AsyncMock()
|
||||
return prompts_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_files_api():
|
||||
"""Mock files API for testing."""
|
||||
files_api = AsyncMock()
|
||||
return files_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_responses_impl(
|
||||
mock_inference_api,
|
||||
|
|
@@ -107,6 +130,8 @@ def openai_responses_impl(
|
|||
mock_vector_io_api,
|
||||
mock_safety_api,
|
||||
mock_conversations_api,
|
||||
mock_prompts_api,
|
||||
mock_files_api,
|
||||
):
|
||||
return OpenAIResponsesImpl(
|
||||
inference_api=mock_inference_api,
|
||||
|
|
@@ -116,6 +141,8 @@ def openai_responses_impl(
|
|||
vector_io_api=mock_vector_io_api,
|
||||
safety_api=mock_safety_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
prompts_api=mock_prompts_api,
|
||||
files_api=mock_files_api,
|
||||
)
|
||||
|
||||
|
||||
|
|
@@ -499,7 +526,7 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
|
|||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
|
||||
|
||||
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
|
||||
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api, mock_files_api):
|
||||
"""Test creating an OpenAI response with multiple messages."""
|
||||
# Setup
|
||||
input_messages = [
|
||||
|
|
@@ -710,7 +737,7 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
|
|||
|
||||
|
||||
async def test_create_openai_response_with_instructions_and_multiple_messages(
|
||||
openai_responses_impl, mock_inference_api
|
||||
openai_responses_impl, mock_inference_api, mock_files_api
|
||||
):
|
||||
# Setup
|
||||
input_messages = [
|
||||
|
|
@@ -1242,3 +1269,489 @@ async def test_create_openai_response_with_output_types_as_input(
|
|||
|
||||
assert stored_with_outputs.input == input_with_output_types
|
||||
assert len(stored_with_outputs.input) == 3
|
||||
|
||||
|
||||
async def test_create_openai_response_with_prompt(openai_responses_impl, mock_inference_api, mock_prompts_api):
|
||||
"""Test creating an OpenAI response with a prompt."""
|
||||
input_text = "What is the capital of Ireland?"
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="You are a helpful {{ area_name }} assistant at {{ company_name }}. Always provide accurate information.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["area_name", "company_name"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"area_name": OpenAIResponseInputMessageContentText(text="geography"),
|
||||
"company_name": OpenAIResponseInputMessageContentText(text="Dummy Company"),
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
prompt=openai_response_prompt,
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.assert_called_with(prompt_id, 1)
|
||||
mock_inference_api.openai_chat_completion.assert_called()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.args[0].messages
|
||||
assert len(sent_messages) == 2
|
||||
|
||||
system_messages = [msg for msg in sent_messages if msg.role == "system"]
|
||||
assert len(system_messages) == 1
|
||||
assert (
|
||||
system_messages[0].content
|
||||
== "You are a helpful geography assistant at Dummy Company. Always provide accurate information."
|
||||
)
|
||||
|
||||
user_messages = [msg for msg in sent_messages if msg.role == "user"]
|
||||
assert len(user_messages) == 1
|
||||
assert user_messages[0].content == input_text
|
||||
|
||||
assert result.model == model
|
||||
assert result.status == "completed"
|
||||
assert isinstance(result.prompt, OpenAIResponsePrompt)
|
||||
assert result.prompt.id == prompt_id
|
||||
assert result.prompt.variables == openai_response_prompt.variables
|
||||
assert result.prompt.version == "1"
|
||||
|
||||
|
||||
async def test_prepend_prompt_successful_without_variables(openai_responses_impl, mock_prompts_api, mock_inference_api):
|
||||
"""Test prepend_prompt function without variables."""
|
||||
input_text = "What is the capital of Ireland?"
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="You are a helpful assistant. Always provide accurate information.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=[],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1")
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
|
||||
await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
prompt=openai_response_prompt,
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.assert_called_with(prompt_id, 1)
|
||||
mock_inference_api.openai_chat_completion.assert_called()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.args[0].messages
|
||||
assert len(sent_messages) == 2
|
||||
system_messages = [msg for msg in sent_messages if msg.role == "system"]
|
||||
assert system_messages[0].content == "You are a helpful assistant. Always provide accurate information."
|
||||
|
||||
|
||||
async def test_prepend_prompt_invalid_variable(openai_responses_impl, mock_prompts_api):
|
||||
"""Test error handling in prepend_prompt function when prompt parameters contain invalid variables."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="You are a {{ role }} assistant.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["role"], # Only "role" is valid
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"role": OpenAIResponseInputMessageContentText(text="helpful"),
|
||||
"company": OpenAIResponseInputMessageContentText(
|
||||
text="Dummy Company"
|
||||
), # company is not in prompt.variables
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="Test prompt")]
|
||||
|
||||
# Execute - should raise ValueError for invalid variable
|
||||
with pytest.raises(ValueError, match="Variable company not found in prompt"):
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
# Verify
|
||||
mock_prompts_api.get_prompt.assert_called_once_with(prompt_id, 1)
|
||||
|
||||
|
||||
async def test_prepend_prompt_not_found(openai_responses_impl, mock_prompts_api):
|
||||
"""Test prepend_prompt function when prompt is not found."""
|
||||
prompt_id = "pmpt_nonexistent"
|
||||
openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1")
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = None # Prompt not found
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="Test prompt")]
|
||||
initial_length = len(messages)
|
||||
|
||||
# Execute
|
||||
result = await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
# Verify
|
||||
mock_prompts_api.get_prompt.assert_called_once_with(prompt_id, 1)
|
||||
|
||||
# Should return None when prompt not found
|
||||
assert result is None
|
||||
|
||||
# Messages should not be modified
|
||||
assert len(messages) == initial_length
|
||||
assert messages[0].content == "Test prompt"
|
||||
|
||||
|
||||
async def test_prepend_prompt_variable_substitution(openai_responses_impl, mock_prompts_api):
|
||||
"""Test complex variable substitution with multiple occurrences and special characters in prepend_prompt function."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
|
||||
# Support all whitespace variations: {{name}}, {{ name }}, {{ name}}, {{name }}, etc.
|
||||
prompt = Prompt(
|
||||
prompt="Hello {{name}}! You are working at {{ company}}. Your role is {{role}} at {{company}}. Remember, {{ name }}, to be {{ tone }}.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["name", "company", "role", "tone"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"name": OpenAIResponseInputMessageContentText(text="Alice"),
|
||||
"company": OpenAIResponseInputMessageContentText(text="Dummy Company"),
|
||||
"role": OpenAIResponseInputMessageContentText(text="AI Assistant"),
|
||||
"tone": OpenAIResponseInputMessageContentText(text="professional"),
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="Test")]
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
# Verify
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], OpenAISystemMessageParam)
|
||||
expected_content = "Hello Alice! You are working at Dummy Company. Your role is AI Assistant at Dummy Company. Remember, Alice, to be professional."
|
||||
assert messages[0].content == expected_content
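As a rough aid to reading the expectation above, variable substitution tolerant of the {{name}} / {{ name }} whitespace variations can be sketched with a small regex helper; the function name and pattern here are illustrative assumptions, not the provider's actual implementation:

import re

def substitute_prompt_variables(template: str, variables: dict[str, str]) -> str:
    # Illustrative sketch only: replace {{name}}, {{ name }}, {{name }} and similar
    # occurrences with the corresponding text value; unknown names are left as-is.
    def _replace(match: re.Match) -> str:
        return variables.get(match.group(1), match.group(0))

    return re.sub(r"\{\{\s*(\w+)\s*\}\}", _replace, template)

# Mirrors the kind of expectation asserted above.
assert (
    substitute_prompt_variables("Hello {{name}}! Be {{ tone }}.", {"name": "Alice", "tone": "professional"})
    == "Hello Alice! Be professional."
)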
|
||||
|
||||
|
||||
async def test_prepend_prompt_with_image_variable(openai_responses_impl, mock_prompts_api, mock_files_api):
|
||||
"""Test prepend_prompt with image variable - should create placeholder in system message and append image as separate user message."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="Analyze this {{product_image}} and describe what you see.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["product_image"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
# Mock file content and file metadata
|
||||
mock_file_content = b"fake_image_data"
|
||||
mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})()
|
||||
mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject(
|
||||
object="file",
|
||||
id="file-abc123",
|
||||
bytes=len(mock_file_content),
|
||||
created_at=1234567890,
|
||||
expires_at=1234567890,
|
||||
filename="product.jpg",
|
||||
purpose="assistants",
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"product_image": OpenAIResponseInputMessageContentImage(
|
||||
file_id="file-abc123",
|
||||
detail="high",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="What do you think?")]
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
assert len(messages) == 3
|
||||
|
||||
# Check system message has placeholder
|
||||
assert isinstance(messages[0], OpenAISystemMessageParam)
|
||||
assert messages[0].content == "Analyze this [Image: product_image] and describe what you see."
|
||||
|
||||
# Check original user message is still there
|
||||
assert isinstance(messages[1], OpenAIUserMessageParam)
|
||||
assert messages[1].content == "What do you think?"
|
||||
|
||||
# Check new user message with image is appended
|
||||
assert isinstance(messages[2], OpenAIUserMessageParam)
|
||||
assert isinstance(messages[2].content, list)
|
||||
assert len(messages[2].content) == 1
|
||||
|
||||
# Should be image with data URL
|
||||
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
|
||||
assert messages[2].content[0].image_url.url.startswith("data:image/")
|
||||
assert messages[2].content[0].image_url.detail == "high"
|
||||
|
||||
|
||||
async def test_prepend_prompt_with_file_variable(openai_responses_impl, mock_prompts_api, mock_files_api):
|
||||
"""Test prepend_prompt with file variable - should create placeholder in system message and append file as separate user message."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="Review the document {{contract_file}} and summarize key points.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["contract_file"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
# Mock file retrieval
|
||||
mock_file_content = b"fake_pdf_content"
|
||||
mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})()
|
||||
mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject(
|
||||
object="file",
|
||||
id="file-contract-789",
|
||||
bytes=len(mock_file_content),
|
||||
created_at=1234567890,
|
||||
expires_at=1234567890,
|
||||
filename="contract.pdf",
|
||||
purpose="assistants",
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"contract_file": OpenAIResponseInputMessageContentFile(
|
||||
file_id="file-contract-789",
|
||||
filename="contract.pdf",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="Please review this.")]
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
assert len(messages) == 3
|
||||
|
||||
# Check system message has placeholder
|
||||
assert isinstance(messages[0], OpenAISystemMessageParam)
|
||||
assert messages[0].content == "Review the document [File: contract_file] and summarize key points."
|
||||
|
||||
# Check original user message is still there
|
||||
assert isinstance(messages[1], OpenAIUserMessageParam)
|
||||
assert messages[1].content == "Please review this."
|
||||
|
||||
# Check new user message with file is appended
|
||||
assert isinstance(messages[2], OpenAIUserMessageParam)
|
||||
assert isinstance(messages[2].content, list)
|
||||
assert len(messages[2].content) == 1
|
||||
|
||||
# The single content part should be the file with a data URL
|
||||
assert isinstance(messages[2].content[0], OpenAIFile)
|
||||
assert messages[2].content[0].file.file_data.startswith("data:application/pdf;base64,")
|
||||
assert messages[2].content[0].file.filename == "contract.pdf"
|
||||
assert messages[2].content[0].file.file_id is None
|
||||
|
||||
|
||||
async def test_prepend_prompt_with_mixed_variables(openai_responses_impl, mock_prompts_api, mock_files_api):
|
||||
"""Test prepend_prompt with text, image, and file variables mixed together."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="Hello {{name}}! Analyze {{photo}} and review {{document}}. Provide insights for {{company}}.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["name", "photo", "document", "company"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
# Mock file retrieval for image and file
|
||||
mock_image_content = b"fake_image_data"
|
||||
mock_file_content = b"fake_doc_content"
|
||||
|
||||
async def mock_retrieve_file_content(file_id):
|
||||
if file_id == "file-photo-123":
|
||||
return type("obj", (object,), {"body": mock_image_content})()
|
||||
elif file_id == "file-doc-456":
|
||||
return type("obj", (object,), {"body": mock_file_content})()
|
||||
|
||||
mock_files_api.openai_retrieve_file_content.side_effect = mock_retrieve_file_content
|
||||
|
||||
def mock_retrieve_file(file_id):
|
||||
if file_id == "file-photo-123":
|
||||
return OpenAIFileObject(
|
||||
object="file",
|
||||
id="file-photo-123",
|
||||
bytes=len(mock_image_content),
|
||||
created_at=1234567890,
|
||||
expires_at=1234567890,
|
||||
filename="photo.jpg",
|
||||
purpose="assistants",
|
||||
)
|
||||
elif file_id == "file-doc-456":
|
||||
return OpenAIFileObject(
|
||||
object="file",
|
||||
id="file-doc-456",
|
||||
bytes=len(mock_file_content),
|
||||
created_at=1234567890,
|
||||
expires_at=1234567890,
|
||||
filename="doc.pdf",
|
||||
purpose="assistants",
|
||||
)
|
||||
|
||||
mock_files_api.openai_retrieve_file.side_effect = mock_retrieve_file
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"name": OpenAIResponseInputMessageContentText(text="Alice"),
|
||||
"photo": OpenAIResponseInputMessageContentImage(file_id="file-photo-123", detail="auto"),
|
||||
"document": OpenAIResponseInputMessageContentFile(file_id="file-doc-456", filename="doc.pdf"),
|
||||
"company": OpenAIResponseInputMessageContentText(text="Acme Corp"),
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="Here's my question.")]
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
assert len(messages) == 3
|
||||
|
||||
# Check system message has text and placeholders
|
||||
assert isinstance(messages[0], OpenAISystemMessageParam)
|
||||
expected_system = "Hello Alice! Analyze [Image: photo] and review [File: document]. Provide insights for Acme Corp."
|
||||
assert messages[0].content == expected_system
|
||||
|
||||
# Check original user message is still there
|
||||
assert isinstance(messages[1], OpenAIUserMessageParam)
|
||||
assert messages[1].content == "Here's my question."
|
||||
|
||||
# Check new user message with media is appended (2 media items)
|
||||
assert isinstance(messages[2], OpenAIUserMessageParam)
|
||||
assert isinstance(messages[2].content, list)
|
||||
assert len(messages[2].content) == 2
|
||||
|
||||
# First part should be image with data URL
|
||||
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
|
||||
assert messages[2].content[0].image_url.url.startswith("data:image/")
|
||||
|
||||
# Second part should be file with data URL
|
||||
assert isinstance(messages[2].content[1], OpenAIFile)
|
||||
assert messages[2].content[1].file.file_data.startswith("data:application/pdf;base64,")
|
||||
assert messages[2].content[1].file.filename == "doc.pdf"
|
||||
assert messages[2].content[1].file.file_id is None
|
||||
|
||||
|
||||
async def test_prepend_prompt_with_image_using_image_url(openai_responses_impl, mock_prompts_api):
|
||||
"""Test prepend_prompt with image variable using image_url instead of file_id."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="Describe {{screenshot}}.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["screenshot"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"screenshot": OpenAIResponseInputMessageContentImage(
|
||||
image_url="https://example.com/screenshot.png",
|
||||
detail="low",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="What is this?")]
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
assert len(messages) == 3
|
||||
|
||||
# Check system message has placeholder
|
||||
assert isinstance(messages[0], OpenAISystemMessageParam)
|
||||
assert messages[0].content == "Describe [Image: screenshot]."
|
||||
|
||||
# Check original user message is still there
|
||||
assert isinstance(messages[1], OpenAIUserMessageParam)
|
||||
assert messages[1].content == "What is this?"
|
||||
|
||||
# Check new user message with image is appended
|
||||
assert isinstance(messages[2], OpenAIUserMessageParam)
|
||||
assert isinstance(messages[2].content, list)
|
||||
|
||||
# Image should use the provided URL
|
||||
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
|
||||
assert messages[2].content[0].image_url.url == "https://example.com/screenshot.png"
|
||||
assert messages[2].content[0].image_url.detail == "low"
|
||||
|
||||
|
||||
async def test_prepend_prompt_image_variable_missing_required_fields(openai_responses_impl, mock_prompts_api):
|
||||
"""Test prepend_prompt with image variable that has neither file_id nor image_url - should raise error."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="Analyze {{bad_image}}.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["bad_image"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
# Create image content with neither file_id nor image_url
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={"bad_image": OpenAIResponseInputMessageContentImage()}, # No file_id or image_url
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
messages = [OpenAIUserMessageParam(content="Test")]
|
||||
|
||||
# Execute - should raise ValueError
|
||||
with pytest.raises(ValueError, match="Image content must have either 'image_url' or 'file_id'"):
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
|
|
|||
|
|
@@ -7,20 +7,20 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
from llama_stack_api.common.errors import (
|
||||
ConversationNotFoundError,
|
||||
InvalidConversationIdError,
|
||||
)
|
||||
from llama_stack_api.conversations import (
|
||||
ConversationItemList,
|
||||
)
|
||||
from llama_stack_api.openai_responses import (
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseOutputItemDone,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
)
|
||||
from llama_stack.apis.common.errors import (
|
||||
ConversationNotFoundError,
|
||||
InvalidConversationIdError,
|
||||
)
|
||||
from llama_stack.apis.conversations.conversations import (
|
||||
ConversationItemList,
|
||||
)
|
||||
|
||||
# Import existing fixtures from the main responses test file
|
||||
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]
|
||||
|
|
@@ -39,6 +39,8 @@ def responses_impl_with_conversations(
|
|||
mock_vector_io_api,
|
||||
mock_conversations_api,
|
||||
mock_safety_api,
|
||||
mock_prompts_api,
|
||||
mock_files_api,
|
||||
):
|
||||
"""Create OpenAIResponsesImpl instance with conversations API."""
|
||||
return OpenAIResponsesImpl(
|
||||
|
|
@@ -49,6 +51,8 @@ def responses_impl_with_conversations(
|
|||
vector_io_api=mock_vector_io_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
safety_api=mock_safety_api,
|
||||
prompts_api=mock_prompts_api,
|
||||
files_api=mock_files_api,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@@ -5,22 +5,20 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseAnnotationFileCitation,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
_extract_citations_from_text,
|
||||
convert_chat_choice_to_response_message,
|
||||
convert_response_content_to_chat_content,
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
get_message_type_by_role,
|
||||
is_function_tool_call,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
from llama_stack_api.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
|
|
@@ -35,17 +33,27 @@ from llama_stack.apis.inference import (
|
|||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
_extract_citations_from_text,
|
||||
convert_chat_choice_to_response_message,
|
||||
convert_response_content_to_chat_content,
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
get_message_type_by_role,
|
||||
is_function_tool_call,
|
||||
from llama_stack_api.openai_responses import (
|
||||
OpenAIResponseAnnotationFileCitation,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_files_api():
|
||||
"""Mock files API for testing."""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
class TestConvertChatChoiceToResponseMessage:
|
||||
async def test_convert_string_content(self):
|
||||
choice = OpenAIChoice(
|
||||
|
|
@@ -78,17 +86,17 @@ class TestConvertChatChoiceToResponseMessage:
|
|||
|
||||
|
||||
class TestConvertResponseContentToChatContent:
|
||||
async def test_convert_string_content(self):
|
||||
result = await convert_response_content_to_chat_content("Simple string")
|
||||
async def test_convert_string_content(self, mock_files_api):
|
||||
result = await convert_response_content_to_chat_content("Simple string", mock_files_api)
|
||||
assert result == "Simple string"
|
||||
|
||||
async def test_convert_text_content_parts(self):
|
||||
async def test_convert_text_content_parts(self, mock_files_api):
|
||||
content = [
|
||||
OpenAIResponseInputMessageContentText(text="First part"),
|
||||
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
|
||||
]
|
||||
|
||||
result = await convert_response_content_to_chat_content(content)
|
||||
result = await convert_response_content_to_chat_content(content, mock_files_api)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
|
||||
|
|
@@ -96,10 +104,10 @@ class TestConvertResponseContentToChatContent:
|
|||
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
|
||||
assert result[1].text == "Second part"
|
||||
|
||||
async def test_convert_image_content(self):
|
||||
async def test_convert_image_content(self, mock_files_api):
|
||||
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
|
||||
|
||||
result = await convert_response_content_to_chat_content(content)
|
||||
result = await convert_response_content_to_chat_content(content, mock_files_api)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
|
||||
|
|
|
|||
|
|
@@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
|
||||
from llama_stack_api.openai_responses import (
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
|
|
@@ -15,7 +16,6 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseToolMCP,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
|
||||
|
||||
|
||||
class TestToolContext:
|
||||
|
|
|
|||
|
|
@@ -8,8 +8,6 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
|
@@ -17,6 +15,8 @@ from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
|||
extract_guardrail_ids,
|
||||
run_guardrails,
|
||||
)
|
||||
from llama_stack_api.agents import ResponseGuardrailSpec
|
||||
from llama_stack_api.safety import ModerationObject, ModerationObjectResults
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@@ -30,6 +30,8 @@ def mock_apis():
|
|||
"vector_io_api": AsyncMock(),
|
||||
"conversations_api": AsyncMock(),
|
||||
"safety_api": AsyncMock(),
|
||||
"prompts_api": AsyncMock(),
|
||||
"files_api": AsyncMock(),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Tests for making Safety API optional in meta-reference agents provider.
|
||||
|
||||
This test suite validates the changes introduced to fix issue #4165, which
|
||||
allows running the meta-reference agents provider without the Safety API.
|
||||
The Safety API is now an optional dependency, and errors are raised at request time
|
||||
when guardrails are explicitly requested but the Safety API is not configured.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
|
||||
from llama_stack.providers.inline.agents.meta_reference import get_provider_impl
|
||||
from llama_stack.providers.inline.agents.meta_reference.config import (
|
||||
AgentPersistenceConfig,
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
run_guardrails,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_persistence_config():
|
||||
"""Create a mock persistence configuration."""
|
||||
return AgentPersistenceConfig(
|
||||
agent_state=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="agents",
|
||||
),
|
||||
responses=ResponsesStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="responses",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_deps():
|
||||
"""Create mock dependencies for the agents provider."""
|
||||
# Create mock APIs
|
||||
inference_api = AsyncMock()
|
||||
vector_io_api = AsyncMock()
|
||||
tool_runtime_api = AsyncMock()
|
||||
tool_groups_api = AsyncMock()
|
||||
conversations_api = AsyncMock()
|
||||
prompts_api = AsyncMock()
|
||||
files_api = AsyncMock()
|
||||
|
||||
return {
|
||||
Api.inference: inference_api,
|
||||
Api.vector_io: vector_io_api,
|
||||
Api.tool_runtime: tool_runtime_api,
|
||||
Api.tool_groups: tool_groups_api,
|
||||
Api.conversations: conversations_api,
|
||||
Api.prompts: prompts_api,
|
||||
Api.files: files_api,
|
||||
}
|
||||
|
||||
|
||||
class TestProviderInitialization:
|
||||
"""Test provider initialization with different safety API configurations."""
|
||||
|
||||
async def test_initialization_with_safety_api_present(self, mock_persistence_config, mock_deps):
|
||||
"""Test successful initialization when Safety API is configured."""
|
||||
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
|
||||
|
||||
# Add safety API to deps
|
||||
safety_api = AsyncMock()
|
||||
mock_deps[Api.safety] = safety_api
|
||||
|
||||
# Mock the initialize method to avoid actual initialization
|
||||
with patch(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
# Should not raise any exception
|
||||
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
|
||||
assert provider is not None
|
||||
|
||||
async def test_initialization_without_safety_api(self, mock_persistence_config, mock_deps):
|
||||
"""Test successful initialization when Safety API is not configured."""
|
||||
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
|
||||
|
||||
# Safety API is NOT in mock_deps - provider should still start
|
||||
# Mock the initialize method to avoid actual initialization
|
||||
with patch(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
# Should not raise any exception
|
||||
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
|
||||
assert provider is not None
|
||||
assert provider.safety_api is None
|
||||
|
||||
|
||||
class TestGuardrailsFunctionality:
|
||||
"""Test run_guardrails function with optional safety API."""
|
||||
|
||||
async def test_run_guardrails_with_none_safety_api(self):
|
||||
"""Test that run_guardrails returns None when safety_api is None."""
|
||||
result = await run_guardrails(safety_api=None, messages="test message", guardrail_ids=["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
async def test_run_guardrails_with_empty_messages(self):
|
||||
"""Test that run_guardrails returns None for empty messages."""
|
||||
# Test with None safety API
|
||||
result = await run_guardrails(safety_api=None, messages="", guardrail_ids=["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
# Test with mock safety API
|
||||
mock_safety_api = AsyncMock()
|
||||
result = await run_guardrails(safety_api=mock_safety_api, messages="", guardrail_ids=["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
async def test_run_guardrails_with_none_safety_api_ignores_guardrails(self):
|
||||
"""Test that guardrails are skipped when safety_api is None, even if guardrail_ids are provided."""
|
||||
# Should not raise exception, just return None
|
||||
result = await run_guardrails(
|
||||
safety_api=None,
|
||||
messages="potentially harmful content",
|
||||
guardrail_ids=["llama-guard", "content-filter"],
|
||||
)
|
||||
assert result is None
|
||||
|
||||
async def test_create_response_rejects_guardrails_without_safety_api(self, mock_persistence_config, mock_deps):
|
||||
"""Test that create_openai_response raises error when guardrails requested but Safety API unavailable."""
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack_api import ResponseGuardrailSpec
|
||||
|
||||
# Create OpenAIResponsesImpl with no safety API
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"):
|
||||
impl = OpenAIResponsesImpl(
|
||||
inference_api=mock_deps[Api.inference],
|
||||
tool_groups_api=mock_deps[Api.tool_groups],
|
||||
tool_runtime_api=mock_deps[Api.tool_runtime],
|
||||
responses_store=MagicMock(),
|
||||
vector_io_api=mock_deps[Api.vector_io],
|
||||
safety_api=None, # No Safety API
|
||||
conversations_api=mock_deps[Api.conversations],
|
||||
prompts_api=mock_deps[Api.prompts],
|
||||
files_api=mock_deps[Api.files],
|
||||
)
|
||||
|
||||
# Test with string guardrail
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await impl.create_openai_response(
|
||||
input="test input",
|
||||
model="test-model",
|
||||
guardrails=["llama-guard"],
|
||||
)
|
||||
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
|
||||
|
||||
# Test with ResponseGuardrailSpec
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await impl.create_openai_response(
|
||||
input="test input",
|
||||
model="test-model",
|
||||
guardrails=[ResponseGuardrailSpec(type="llama-guard")],
|
||||
)
|
||||
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
|
||||
|
||||
async def test_create_response_succeeds_without_guardrails_and_no_safety_api(
|
||||
self, mock_persistence_config, mock_deps
|
||||
):
|
||||
"""Test that create_openai_response works when no guardrails requested and Safety API unavailable."""
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
||||
# Create OpenAIResponsesImpl with no safety API
|
||||
with (
|
||||
patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"),
|
||||
patch.object(OpenAIResponsesImpl, "_create_streaming_response", new_callable=AsyncMock) as mock_stream,
|
||||
):
|
||||
# Mock the streaming response to return a simple async generator
|
||||
async def mock_generator():
|
||||
yield MagicMock()
|
||||
|
||||
mock_stream.return_value = mock_generator()
|
||||
|
||||
impl = OpenAIResponsesImpl(
|
||||
inference_api=mock_deps[Api.inference],
|
||||
tool_groups_api=mock_deps[Api.tool_groups],
|
||||
tool_runtime_api=mock_deps[Api.tool_runtime],
|
||||
responses_store=MagicMock(),
|
||||
vector_io_api=mock_deps[Api.vector_io],
|
||||
safety_api=None, # No Safety API
|
||||
conversations_api=mock_deps[Api.conversations],
|
||||
prompts_api=mock_deps[Api.prompts],
|
||||
files_api=mock_deps[Api.files],
|
||||
)
|
||||
|
||||
# Should not raise when no guardrails requested
|
||||
# Note: This will still fail later in execution due to mocking, but should pass the validation
|
||||
try:
|
||||
await impl.create_openai_response(
|
||||
input="test input",
|
||||
model="test-model",
|
||||
guardrails=None, # No guardrails
|
||||
)
|
||||
except Exception as e:
|
||||
# Ensure the error is NOT about missing Safety API
|
||||
assert "Cannot process guardrails: Safety API is not configured" not in str(e)
|
||||
|
|
@@ -1,169 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import Turn
|
||||
from llama_stack.apis.inference import CompletionMessage, StopReason
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_setup(sqlite_kvstore):
|
||||
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
|
||||
yield agent_persistence
|
||||
|
||||
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Set creator's attributes for the session
|
||||
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
|
||||
mock_get_authenticated_user.return_value = User("test_user", creator_attributes)
|
||||
|
||||
# Create a session
|
||||
session_id = await agent_persistence.create_session("Test Session")
|
||||
|
||||
# Get the session and verify access attributes were set
|
||||
session_info = await agent_persistence.get_session_info(session_id)
|
||||
assert session_info is not None
|
||||
assert session_info.owner is not None
|
||||
assert session_info.owner.attributes is not None
|
||||
assert session_info.owner.attributes["roles"] == ["researcher"]
|
||||
assert session_info.owner.attributes["teams"] == ["ai-team"]
|
||||
|
||||
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_session_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with specific access attributes
|
||||
session_id = str(uuid.uuid4())
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name="Restricted Session",
|
||||
started_at=datetime.now(),
|
||||
owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
key=f"session:{agent_persistence.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
# User with matching attributes can access
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
|
||||
)
|
||||
retrieved_session = await agent_persistence.get_session_info(session_id)
|
||||
assert retrieved_session is not None
|
||||
assert retrieved_session.session_id == session_id
|
||||
|
||||
# User without matching attributes cannot access
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
|
||||
retrieved_session = await agent_persistence.get_session_info(session_id)
|
||||
assert retrieved_session is None
|
||||
|
||||
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with restricted access
|
||||
session_id = str(uuid.uuid4())
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name="Restricted Session",
|
||||
started_at=datetime.now(),
|
||||
owner=User("someone", {"roles": ["admin"]}),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
key=f"session:{agent_persistence.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
# Create a turn for this session
|
||||
turn_id = str(uuid.uuid4())
|
||||
turn = Turn(
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
steps=[],
|
||||
started_at=datetime.now(),
|
||||
input_messages=[],
|
||||
output_message=CompletionMessage(
|
||||
content="Hello",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
),
|
||||
)
|
||||
|
||||
# Admin can add turn
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
|
||||
await agent_persistence.add_turn_to_session(session_id, turn)
|
||||
|
||||
# Admin can get turn
|
||||
retrieved_turn = await agent_persistence.get_session_turn(session_id, turn_id)
|
||||
assert retrieved_turn is not None
|
||||
assert retrieved_turn.turn_id == turn_id
|
||||
|
||||
# Regular user cannot get turn
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
|
||||
with pytest.raises(ValueError):
|
||||
await agent_persistence.get_session_turn(session_id, turn_id)
|
||||
|
||||
# Regular user cannot get turns for session
|
||||
with pytest.raises(ValueError):
|
||||
await agent_persistence.get_session_turns(session_id)
|
||||
|
||||
|
||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
|
||||
agent_persistence = test_setup
|
||||
|
||||
# Create a session with restricted access
|
||||
session_id = str(uuid.uuid4())
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name="Restricted Session",
|
||||
started_at=datetime.now(),
|
||||
owner=User("someone", {"roles": ["admin"]}),
|
||||
turns=[],
|
||||
identifier="Restricted Session",
|
||||
)
|
||||
|
||||
await agent_persistence.kvstore.set(
|
||||
key=f"session:{agent_persistence.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
|
||||
# Admin user can set inference iterations
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
|
||||
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
|
||||
|
||||
# Admin user can get inference iterations
|
||||
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
|
||||
assert infer_iters == 5
|
||||
|
||||
# Regular user cannot get inference iterations
|
||||
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
|
||||
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
|
||||
assert infer_iters is None
|
||||
|
||||
# Regular user cannot set inference iterations (should raise ValueError)
|
||||
with pytest.raises(ValueError):
|
||||
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 10)
|
||||
|
|
@ -13,9 +13,9 @@ from unittest.mock import AsyncMock
|
|||
import pytest
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -59,8 +59,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.batches import BatchObject
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack_api import BatchObject, ConflictError, ResourceNotFoundError
|
||||
|
||||
|
||||
class TestReferenceBatchesImpl:
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ import asyncio
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import ConflictError
|
||||
from llama_stack_api import ConflictError
|
||||
|
||||
|
||||
class TestReferenceBatchesIdempotency:
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ import pytest
|
|||
from moto import mock_aws
|
||||
|
||||
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack_api import OpenAIFilePurpose, ResourceNotFoundError
|
||||
|
||||
|
||||
class TestS3FilesImpl:
|
||||
|
|
@ -228,7 +227,7 @@ class TestS3FilesImpl:
|
|||
|
||||
mock_now.return_value = 0
|
||||
|
||||
from llama_stack.apis.files import ExpiresAfter
|
||||
from llama_stack_api import ExpiresAfter
|
||||
|
||||
sample_text_file.filename = "test_expired_file"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
|
|
@ -260,7 +259,7 @@ class TestS3FilesImpl:
|
|||
|
||||
async def test_unsupported_expires_after_anchor(self, s3_provider, sample_text_file):
|
||||
"""Unsupported anchor value should raise ValueError."""
|
||||
from llama_stack.apis.files import ExpiresAfter
|
||||
from llama_stack_api import ExpiresAfter
|
||||
|
||||
sample_text_file.filename = "test_unsupported_expires_after_anchor"
|
||||
|
||||
|
|
@ -273,7 +272,7 @@ class TestS3FilesImpl:
|
|||
|
||||
async def test_nonint_expires_after_seconds(self, s3_provider, sample_text_file):
|
||||
"""Non-integer seconds in expires_after should raise ValueError."""
|
||||
from llama_stack.apis.files import ExpiresAfter
|
||||
from llama_stack_api import ExpiresAfter
|
||||
|
||||
sample_text_file.filename = "test_nonint_expires_after_seconds"
|
||||
|
||||
|
|
@ -286,7 +285,7 @@ class TestS3FilesImpl:
|
|||
|
||||
async def test_expires_after_seconds_out_of_bounds(self, s3_provider, sample_text_file):
|
||||
"""Seconds outside allowed range should raise ValueError."""
|
||||
from llama_stack.apis.files import ExpiresAfter
|
||||
from llama_stack_api import ExpiresAfter
|
||||
|
||||
with pytest.raises(ValueError, match="greater than or equal to 3600"):
|
||||
await s3_provider.openai_upload_file(
|
||||
|
|
|
|||
|
|
@ -8,10 +8,9 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.providers.remote.files.s3.files import S3FilesImpl
|
||||
from llama_stack_api import OpenAIFilePurpose, ResourceNotFoundError
|
||||
|
||||
|
||||
async def test_listing_hides_other_users_file(s3_provider, sample_text_file):
|
||||
|
|
@ -19,11 +18,11 @@ async def test_listing_hides_other_users_file(s3_provider, sample_text_file):
|
|||
user_a = User("user-a", {"roles": ["team-a"]})
|
||||
user_b = User("user-b", {"roles": ["team-b"]})
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_a
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_b
|
||||
listed = await s3_provider.openai_list_files()
|
||||
assert all(f.id != uploaded.id for f in listed.data)
|
||||
|
|
@ -42,11 +41,11 @@ async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op):
|
|||
user_a = User("user-a", {"roles": ["team-a"]})
|
||||
user_b = User("user-b", {"roles": ["team-b"]})
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_a
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_b
|
||||
with pytest.raises(ResourceNotFoundError):
|
||||
await op(s3_provider, uploaded.id)
|
||||
|
|
@ -57,11 +56,11 @@ async def test_shared_role_allows_listing(s3_provider, sample_text_file):
|
|||
user_a = User("user-a", {"roles": ["shared-role"]})
|
||||
user_b = User("user-b", {"roles": ["shared-role"]})
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_a
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_b
|
||||
listed = await s3_provider.openai_list_files()
|
||||
assert any(f.id == uploaded.id for f in listed.data)
|
||||
|
|
@ -80,10 +79,10 @@ async def test_shared_role_allows_access(s3_provider, sample_text_file, op):
|
|||
user_x = User("user-x", {"roles": ["shared-role"]})
|
||||
user_y = User("user-y", {"roles": ["shared-role"]})
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_x
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_y
|
||||
await op(s3_provider, uploaded.id)
|
||||
|
|
|
|||
78
tests/unit/providers/inference/test_bedrock_adapter.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from openai import AuthenticationError
|
||||
|
||||
from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
from llama_stack_api import OpenAIChatCompletionRequestWithExtraBody
|
||||
|
||||
|
||||
def test_adapter_initialization():
|
||||
config = BedrockConfig(api_key="test-key", region_name="us-east-1")
|
||||
adapter = BedrockInferenceAdapter(config=config)
|
||||
|
||||
assert adapter.config.auth_credential.get_secret_value() == "test-key"
|
||||
assert adapter.config.region_name == "us-east-1"
|
||||
|
||||
|
||||
def test_client_url_construction():
|
||||
config = BedrockConfig(api_key="test-key", region_name="us-west-2")
|
||||
adapter = BedrockInferenceAdapter(config=config)
|
||||
|
||||
assert adapter.get_base_url() == "https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1"
|
||||
|
||||
|
||||
def test_api_key_from_config():
|
||||
config = BedrockConfig(api_key="config-key", region_name="us-east-1")
|
||||
adapter = BedrockInferenceAdapter(config=config)
|
||||
assert adapter.config.auth_credential.get_secret_value() == "config-key"
|
||||
|
||||
|
||||
def test_api_key_from_header_overrides_config():
|
||||
"""Test API key from request header overrides config via client property"""
|
||||
config = BedrockConfig(api_key="config-key", region_name="us-east-1")
|
||||
adapter = BedrockInferenceAdapter(config=config)
|
||||
adapter.provider_data_api_key_field = "aws_bearer_token_bedrock"
|
||||
adapter.get_request_provider_data = MagicMock(return_value=SimpleNamespace(aws_bearer_token_bedrock="header-key"))
|
||||
|
||||
# The client property is where header override happens (in OpenAIMixin)
|
||||
assert adapter.client.api_key == "header-key"
|
||||
|
||||
|
||||
async def test_authentication_error_handling():
|
||||
"""Test that AuthenticationError from OpenAI client is converted to ValueError with helpful message"""
|
||||
config = BedrockConfig(api_key="invalid-key", region_name="us-east-1")
|
||||
adapter = BedrockInferenceAdapter(config=config)
|
||||
|
||||
# Mock the parent class method to raise AuthenticationError
|
||||
mock_response = MagicMock()
|
||||
mock_response.message = "Invalid authentication credentials"
|
||||
auth_error = AuthenticationError(message="Invalid authentication credentials", response=mock_response, body=None)
|
||||
|
||||
# Create a mock that raises the error
|
||||
mock_super = AsyncMock(side_effect=auth_error)
|
||||
|
||||
# Patch the parent class method
|
||||
original_method = BedrockInferenceAdapter.__bases__[0].openai_chat_completion
|
||||
BedrockInferenceAdapter.__bases__[0].openai_chat_completion = mock_super
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="test-model", messages=[{"role": "user", "content": "test"}]
|
||||
)
|
||||
await adapter.openai_chat_completion(params=params)
|
||||
|
||||
assert "AWS Bedrock authentication failed" in str(exc_info.value)
|
||||
assert "Please verify your API key" in str(exc_info.value)
|
||||
finally:
|
||||
# Restore original method
|
||||
BedrockInferenceAdapter.__bases__[0].openai_chat_completion = original_method
|
||||
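For readers skimming these adapter tests, the URL shape they assert can be reproduced with a tiny standalone helper. This is a hypothetical sketch for illustration, not the adapter's actual get_base_url implementation:

def bedrock_openai_base_url(region_name: str) -> str:
    # Mirrors the URL asserted above: regional Bedrock runtime host plus the
    # OpenAI-compatible path segment.
    return f"https://bedrock-runtime.{region_name}.amazonaws.com/openai/v1"

assert bedrock_openai_base_url("us-west-2") == "https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1"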
39
tests/unit/providers/inference/test_bedrock_config.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
def test_bedrock_config_defaults_no_env(monkeypatch):
|
||||
"""Test BedrockConfig defaults when env vars are not set"""
|
||||
monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False)
|
||||
monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False)
|
||||
config = BedrockConfig()
|
||||
assert config.auth_credential is None
|
||||
assert config.region_name == "us-east-2"
|
||||
|
||||
|
||||
def test_bedrock_config_reads_from_env(monkeypatch):
|
||||
"""Test BedrockConfig field initialization reads from environment variables"""
|
||||
monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1")
|
||||
config = BedrockConfig()
|
||||
assert config.region_name == "eu-west-1"
|
||||
|
||||
|
||||
def test_bedrock_config_with_values():
|
||||
"""Test BedrockConfig accepts explicit values via alias"""
|
||||
config = BedrockConfig(api_key="test-key", region_name="us-west-2")
|
||||
assert config.auth_credential.get_secret_value() == "test-key"
|
||||
assert config.region_name == "us-west-2"
|
||||
|
||||
|
||||
def test_bedrock_config_sample():
|
||||
"""Test BedrockConfig sample_run_config returns correct format"""
|
||||
sample = BedrockConfig.sample_run_config()
|
||||
assert "api_key" in sample
|
||||
assert "region_name" in sample
|
||||
assert sample["api_key"] == "${env.AWS_BEARER_TOKEN_BEDROCK:=}"
|
||||
assert sample["region_name"] == "${env.AWS_DEFAULT_REGION:=us-east-2}"
|
||||
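As a rough sketch of the configuration behavior these tests pin down (env-var defaults, the api_key alias stored as a secret, and the sample run config placeholders), a self-contained Pydantic stand-in might look like the following. It is illustrative only and not the actual BedrockConfig:

import os

from pydantic import BaseModel, Field, SecretStr


class SketchBedrockConfig(BaseModel):
    # "api_key" is accepted as an alias and stored as a secret credential.
    auth_credential: SecretStr | None = Field(
        default_factory=lambda: (
            SecretStr(token) if (token := os.getenv("AWS_BEARER_TOKEN_BEDROCK")) else None
        ),
        alias="api_key",
    )
    # Region falls back to us-east-2 when AWS_DEFAULT_REGION is unset.
    region_name: str = Field(default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"))

    @classmethod
    def sample_run_config(cls, **kwargs) -> dict:
        return {
            "api_key": "${env.AWS_BEARER_TOKEN_BEDROCK:=}",
            "region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}",
        }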
|
|
@ -120,7 +120,7 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere
|
|||
VLLMInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
|
||||
{
|
||||
"url": "http://fake",
|
||||
"base_url": "http://fake",
|
||||
},
|
||||
),
|
||||
],
|
||||
|
|
@ -153,7 +153,7 @@ def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_valid
|
|||
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
|
||||
assumption that there is an OpenAI-compatible client object."""
|
||||
|
||||
inference_adapter = adapter_cls(config=config_cls())
|
||||
inference_adapter = adapter_cls(config=config_cls(base_url="http://fake"))
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
|
||||
|
|
|
|||
|
|
@ -10,7 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
from llama_stack.core.routers.inference import InferenceRouter
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
|
||||
from llama_stack_api import (
|
||||
HealthStatus,
|
||||
Model,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
|
|
@ -20,12 +26,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAICompletionRequestWithExtraBody,
|
||||
ToolChoice,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.routers.inference import InferenceRouter
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
|
||||
|
||||
# These are unit tests for the remote vllm provider
|
||||
# implementation. This should only contain tests which are specific to
|
||||
|
|
@ -40,7 +40,7 @@ from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapte
|
|||
|
||||
@pytest.fixture(scope="function")
|
||||
async def vllm_inference_adapter():
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
|
||||
inference_adapter = VLLMInferenceAdapter(config=config)
|
||||
inference_adapter.model_store = AsyncMock()
|
||||
await inference_adapter.initialize()
|
||||
|
|
@ -204,7 +204,7 @@ async def test_vllm_completion_extra_body():
|
|||
via extra_body to the underlying OpenAI client through the InferenceRouter.
|
||||
"""
|
||||
# Set up the vLLM adapter
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
|
||||
vllm_adapter = VLLMInferenceAdapter(config=config)
|
||||
vllm_adapter.__provider_id__ = "vllm"
|
||||
await vllm_adapter.initialize()
|
||||
|
|
@ -277,7 +277,7 @@ async def test_vllm_chat_completion_extra_body():
|
|||
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
|
||||
"""
|
||||
# Set up the vLLM adapter
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
|
||||
vllm_adapter = VLLMInferenceAdapter(config=config)
|
||||
vllm_adapter.__provider_id__ = "vllm"
|
||||
await vllm_adapter.initialize()
|
||||
|
|
|
|||
|
|
@ -8,11 +8,11 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
|
||||
convert_tooldef_to_chat_tool,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.types import ChatCompletionContext
|
||||
from llama_stack_api import ToolDef
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
5
tests/unit/providers/inline/inference/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
44
tests/unit/providers/inline/inference/test_meta_reference.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
|
||||
ModelRunner,
|
||||
)
|
||||
|
||||
|
||||
class TestModelRunner:
|
||||
"""Test ModelRunner task dispatching for model-parallel inference."""
|
||||
|
||||
def test_chat_completion_task_dispatch(self):
|
||||
"""Verify ModelRunner correctly dispatches chat_completion tasks."""
|
||||
# Create a mock generator
|
||||
mock_generator = Mock()
|
||||
mock_generator.chat_completion = Mock(return_value=iter([]))
|
||||
|
||||
runner = ModelRunner(mock_generator)
|
||||
|
||||
# Create a chat_completion task
|
||||
fake_params = {"model": "test"}
|
||||
fake_messages = [{"role": "user", "content": "test"}]
|
||||
task = ("chat_completion", [fake_params, fake_messages])
|
||||
|
||||
# Execute task
|
||||
runner(task)
|
||||
|
||||
# Verify chat_completion was called with correct arguments
|
||||
mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)
|
||||
|
||||
def test_invalid_task_type_raises_error(self):
|
||||
"""Verify ModelRunner rejects invalid task types."""
|
||||
mock_generator = Mock()
|
||||
runner = ModelRunner(mock_generator)
|
||||
|
||||
with pytest.raises(ValueError, match="Unexpected task type"):
|
||||
runner(("invalid_task", []))
|
||||
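The dispatch contract these tests rely on is simple enough to sketch independently. The snippet below is a hypothetical stand-in for illustration, not the meta-reference ModelRunner itself:

class SketchModelRunner:
    def __init__(self, generator):
        self.generator = generator

    def __call__(self, task):
        # A task is a (task_type, args) tuple; only chat_completion is dispatched,
        # anything else is rejected, matching the behavior asserted above.
        task_type, args = task
        if task_type == "chat_completion":
            return self.generator.chat_completion(*args)
        raise ValueError(f"Unexpected task type: {task_type}")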
|
|
@ -9,10 +9,9 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig
|
||||
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
|
||||
from llama_stack_api import Dataset, DatasetPurpose, ResourceType, URIDataSource
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -9,14 +9,20 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.common.job_types import Job, JobStatus
|
||||
from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
|
||||
from llama_stack.apis.inference.inference import TopPSamplingStrategy
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
|
||||
from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl
|
||||
from llama_stack_api import (
|
||||
Benchmark,
|
||||
BenchmarkConfig,
|
||||
EvaluateResponse,
|
||||
Job,
|
||||
JobStatus,
|
||||
ModelCandidate,
|
||||
ResourceType,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
|
||||
MOCK_DATASET_ID = "default/test-dataset"
|
||||
MOCK_BENCHMARK_ID = "test-benchmark"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,12 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.post_training.post_training import (
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||
NvidiaPostTrainingAdapter,
|
||||
NvidiaPostTrainingConfig,
|
||||
)
|
||||
from llama_stack_api import (
|
||||
DataConfig,
|
||||
DatasetFormat,
|
||||
EfficiencyConfig,
|
||||
|
|
@ -19,11 +24,6 @@ from llama_stack.apis.post_training.post_training import (
|
|||
OptimizerType,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||
NvidiaPostTrainingAdapter,
|
||||
NvidiaPostTrainingConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestNvidiaParameters:
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack_api import ModelType
|
||||
|
||||
|
||||
class MockResponse:
|
||||
|
|
@ -146,7 +146,7 @@ async def test_hosted_model_not_in_endpoint_mapping():
|
|||
|
||||
async def test_self_hosted_ignores_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
|
||||
config=NVIDIAConfig(base_url="http://localhost:8000", api_key=None),
|
||||
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
|
|
|||
|
|
@ -10,13 +10,16 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import CompletionMessage, UserMessage
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
|
||||
from llama_stack_api import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
ResourceType,
|
||||
RunShieldResponse,
|
||||
Shield,
|
||||
ViolationLevel,
|
||||
)
|
||||
|
||||
|
||||
class FakeNVIDIASafetyAdapter(NVIDIASafetyAdapter):
|
||||
|
|
@ -136,11 +139,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
|
|||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
OpenAIUserMessageParam(content="Hello, how are you?"),
|
||||
OpenAIAssistantMessageParam(
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
|
|
@ -191,13 +192,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
|
|||
# Mock Guardrails API response
|
||||
mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
OpenAIUserMessageParam(content="Hello, how are you?"),
|
||||
OpenAIAssistantMessageParam(
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
|
|
@ -243,7 +241,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
|
|||
adapter.shield_store.get_shield.return_value = None
|
||||
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
OpenAIUserMessageParam(content="Hello, how are you?"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
|
@ -274,11 +272,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
|
|||
|
||||
# Running the shield should raise an exception
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
OpenAIUserMessageParam(content="Hello, how are you?"),
|
||||
OpenAIAssistantMessageParam(
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -10,15 +10,6 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.post_training.post_training import (
|
||||
DataConfig,
|
||||
DatasetFormat,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
OptimizerType,
|
||||
QATFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||
ListNvidiaPostTrainingJobs,
|
||||
|
|
@ -27,6 +18,15 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
|||
NvidiaPostTrainingJob,
|
||||
NvidiaPostTrainingJobStatusResponse,
|
||||
)
|
||||
from llama_stack_api import (
|
||||
DataConfig,
|
||||
DatasetFormat,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
OptimizerType,
|
||||
QATFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -4,50 +4,66 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.bedrock.bedrock import (
|
||||
_get_region_prefix,
|
||||
_to_inference_profile_id,
|
||||
)
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, PropertyMock, patch
|
||||
|
||||
from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
from llama_stack_api import OpenAIChatCompletionRequestWithExtraBody
|
||||
|
||||
|
||||
def test_region_prefixes():
|
||||
assert _get_region_prefix("us-east-1") == "us."
|
||||
assert _get_region_prefix("eu-west-1") == "eu."
|
||||
assert _get_region_prefix("ap-south-1") == "ap."
|
||||
assert _get_region_prefix("ca-central-1") == "us."
|
||||
def test_can_create_adapter():
|
||||
config = BedrockConfig(api_key="test-key", region_name="us-east-1")
|
||||
adapter = BedrockInferenceAdapter(config=config)
|
||||
|
||||
# Test case insensitive
|
||||
assert _get_region_prefix("US-EAST-1") == "us."
|
||||
assert _get_region_prefix("EU-WEST-1") == "eu."
|
||||
assert _get_region_prefix("Ap-South-1") == "ap."
|
||||
|
||||
# Test None region
|
||||
assert _get_region_prefix(None) == "us."
|
||||
assert adapter is not None
|
||||
assert adapter.config.region_name == "us-east-1"
|
||||
assert adapter.get_api_key() == "test-key"
|
||||
|
||||
|
||||
def test_model_id_conversion():
|
||||
# Basic conversion
|
||||
assert (
|
||||
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0"
|
||||
def test_different_aws_regions():
|
||||
# just check a couple regions to verify URL construction works
|
||||
config = BedrockConfig(api_key="key", region_name="us-east-1")
|
||||
adapter = BedrockInferenceAdapter(config=config)
|
||||
assert adapter.get_base_url() == "https://bedrock-runtime.us-east-1.amazonaws.com/openai/v1"
|
||||
|
||||
config = BedrockConfig(api_key="key", region_name="eu-west-1")
|
||||
adapter = BedrockInferenceAdapter(config=config)
|
||||
assert adapter.get_base_url() == "https://bedrock-runtime.eu-west-1.amazonaws.com/openai/v1"
|
||||
|
||||
|
||||
async def test_basic_chat_completion():
|
||||
"""Test basic chat completion works with OpenAIMixin"""
|
||||
config = BedrockConfig(api_key="k", region_name="us-east-1")
|
||||
adapter = BedrockInferenceAdapter(config=config)
|
||||
|
||||
class FakeModelStore:
|
||||
async def has_model(self, model_id):
|
||||
return True
|
||||
|
||||
async def get_model(self, model_id):
|
||||
return SimpleNamespace(provider_resource_id="meta.llama3-1-8b-instruct-v1:0")
|
||||
|
||||
adapter.model_store = FakeModelStore()
|
||||
|
||||
fake_response = SimpleNamespace(
|
||||
id="chatcmpl-123",
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content="Hello!", role="assistant"), finish_reason="stop")],
|
||||
)
|
||||
|
||||
# Already has prefix
|
||||
assert (
|
||||
_to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1")
|
||||
== "us.meta.llama3-1-70b-instruct-v1:0"
|
||||
)
|
||||
mock_create = AsyncMock(return_value=fake_response)
|
||||
|
||||
# ARN should be returned unchanged
|
||||
arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0"
|
||||
assert _to_inference_profile_id(arn, "us-east-1") == arn
|
||||
class FakeClient:
|
||||
def __init__(self):
|
||||
self.chat = SimpleNamespace(completions=SimpleNamespace(create=mock_create))
|
||||
|
||||
# ARN should be returned unchanged even without region
|
||||
assert _to_inference_profile_id(arn) == arn
|
||||
with patch.object(type(adapter), "client", new_callable=PropertyMock, return_value=FakeClient()):
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="llama3-1-8b",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
stream=False,
|
||||
)
|
||||
response = await adapter.openai_chat_completion(params=params)
|
||||
|
||||
# Optional region parameter defaults to us-east-1
|
||||
assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0") == "us.meta.llama3-1-70b-instruct-v1:0"
|
||||
|
||||
# Different regions work with optional parameter
|
||||
assert (
|
||||
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0"
|
||||
)
|
||||
assert response.id == "chatcmpl-123"
|
||||
assert mock_create.await_count == 1
|
||||
|
|
|
|||
|
|
@ -4,8 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import get_args, get_origin
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from llama_stack.core.distribution import get_provider_registry, providable_apis
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
|
|
@ -41,3 +43,55 @@ class TestProviderConfigurations:
|
|||
|
||||
sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz")
|
||||
assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict"
|
||||
|
||||
def test_remote_inference_url_standardization(self):
|
||||
"""Verify all remote inference providers use standardized base_url configuration."""
|
||||
provider_registry = get_provider_registry()
|
||||
inference_providers = provider_registry.get("inference", {})
|
||||
|
||||
# Filter for remote providers only
|
||||
remote_providers = {k: v for k, v in inference_providers.items() if k.startswith("remote::")}
|
||||
|
||||
failures = []
|
||||
for provider_type, provider_spec in remote_providers.items():
|
||||
try:
|
||||
config_class_name = provider_spec.config_class
|
||||
config_type = instantiate_class_type(config_class_name)
|
||||
|
||||
# Check that config has base_url field (not url)
|
||||
if hasattr(config_type, "model_fields"):
|
||||
fields = config_type.model_fields
|
||||
|
||||
# Should NOT have 'url' field (old pattern)
|
||||
if "url" in fields:
|
||||
failures.append(
|
||||
f"{provider_type}: Uses deprecated 'url' field instead of 'base_url'. "
|
||||
f"Please rename to 'base_url' for consistency."
|
||||
)
|
||||
|
||||
# Should have 'base_url' field with HttpUrl | None type
|
||||
if "base_url" in fields:
|
||||
field_info = fields["base_url"]
|
||||
annotation = field_info.annotation
|
||||
|
||||
# Check if it's HttpUrl or HttpUrl | None
|
||||
# get_origin() returns Union for (X | Y), None for plain types
|
||||
# get_args() returns the types inside Union, e.g. (HttpUrl, NoneType)
|
||||
is_valid = False
|
||||
if get_origin(annotation) is not None: # It's a Union/Optional
|
||||
if HttpUrl in get_args(annotation):
|
||||
is_valid = True
|
||||
elif annotation == HttpUrl: # Plain HttpUrl without | None
|
||||
is_valid = True
|
||||
|
||||
if not is_valid:
|
||||
failures.append(
|
||||
f"{provider_type}: base_url field has incorrect type annotation. "
|
||||
f"Expected 'HttpUrl | None', got '{annotation}'"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failures.append(f"{provider_type}: Error checking URL standardization: {str(e)}")
|
||||
|
||||
if failures:
|
||||
pytest.fail("URL standardization violations found:\n" + "\n".join(f" - {f}" for f in failures))
|
||||
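The get_origin/get_args checks above are easy to verify in isolation. Here is a small standalone illustration (independent of the provider registry) of how an HttpUrl | None annotation is inspected:

from typing import get_args, get_origin

from pydantic import BaseModel, HttpUrl


class ExampleRemoteConfig(BaseModel):
    base_url: HttpUrl | None = None


annotation = ExampleRemoteConfig.model_fields["base_url"].annotation
# For "HttpUrl | None", get_origin() reports a Union and get_args() yields
# (HttpUrl, NoneType), so a simple membership check suffices.
assert get_origin(annotation) is not None
assert HttpUrl in get_args(annotation)
assert type(None) in get_args(annotation)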
|
|
|
|||
|
|
@ -1,220 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from llama_stack.apis.common.content_types import TextContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
convert_message_to_openai_dict_new,
|
||||
openai_messages_to_messages,
|
||||
)
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict():
|
||||
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
|
||||
assert await convert_message_to_openai_dict(message) == {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello, world!"}],
|
||||
}
|
||||
|
||||
|
||||
# Test convert_message_to_openai_dict with a tool call
|
||||
async def test_convert_message_to_openai_dict_with_tool_call():
|
||||
message = CompletionMessage(
|
||||
content="",
|
||||
tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
|
||||
openai_dict = await convert_message_to_openai_dict(message)
|
||||
|
||||
assert openai_dict == {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": ""}],
|
||||
"tool_calls": [
|
||||
{"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
|
||||
message = CompletionMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="123",
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
arguments='{"foo": "bar"}',
|
||||
)
|
||||
],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
|
||||
openai_dict = await convert_message_to_openai_dict(message)
|
||||
|
||||
assert openai_dict == {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": ""}],
|
||||
"tool_calls": [
|
||||
{"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def test_openai_messages_to_messages_with_content_str():
|
||||
openai_messages = [
|
||||
OpenAISystemMessageParam(content="system message"),
|
||||
OpenAIUserMessageParam(content="user message"),
|
||||
OpenAIAssistantMessageParam(content="assistant message"),
|
||||
]
|
||||
|
||||
llama_messages = openai_messages_to_messages(openai_messages)
|
||||
assert len(llama_messages) == 3
|
||||
assert isinstance(llama_messages[0], SystemMessage)
|
||||
assert isinstance(llama_messages[1], UserMessage)
|
||||
assert isinstance(llama_messages[2], CompletionMessage)
|
||||
assert llama_messages[0].content == "system message"
|
||||
assert llama_messages[1].content == "user message"
|
||||
assert llama_messages[2].content == "assistant message"
|
||||
|
||||
|
||||
async def test_openai_messages_to_messages_with_content_list():
|
||||
openai_messages = [
|
||||
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
|
||||
OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
|
||||
OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
|
||||
]
|
||||
|
||||
llama_messages = openai_messages_to_messages(openai_messages)
|
||||
assert len(llama_messages) == 3
|
||||
assert isinstance(llama_messages[0], SystemMessage)
|
||||
assert isinstance(llama_messages[1], UserMessage)
|
||||
assert isinstance(llama_messages[2], CompletionMessage)
|
||||
assert llama_messages[0].content[0].text == "system message"
|
||||
assert llama_messages[1].content[0].text == "user message"
|
||||
assert llama_messages[2].content[0].text == "assistant message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message_class,kwargs",
|
||||
[
|
||||
(OpenAISystemMessageParam, {}),
|
||||
(OpenAIAssistantMessageParam, {}),
|
||||
(OpenAIDeveloperMessageParam, {}),
|
||||
(OpenAIUserMessageParam, {}),
|
||||
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
|
||||
],
|
||||
)
|
||||
def test_message_accepts_text_string(message_class, kwargs):
|
||||
"""Test that messages accept string text content."""
|
||||
msg = message_class(content="Test message", **kwargs)
|
||||
assert msg.content == "Test message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message_class,kwargs",
|
||||
[
|
||||
(OpenAISystemMessageParam, {}),
|
||||
(OpenAIAssistantMessageParam, {}),
|
||||
(OpenAIDeveloperMessageParam, {}),
|
||||
(OpenAIUserMessageParam, {}),
|
||||
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
|
||||
],
|
||||
)
|
||||
def test_message_accepts_text_list(message_class, kwargs):
|
||||
"""Test that messages accept list of text content parts."""
|
||||
content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
|
||||
msg = message_class(content=content_list, **kwargs)
|
||||
assert len(msg.content) == 1
|
||||
assert msg.content[0].text == "Test message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message_class,kwargs",
|
||||
[
|
||||
(OpenAISystemMessageParam, {}),
|
||||
(OpenAIAssistantMessageParam, {}),
|
||||
(OpenAIDeveloperMessageParam, {}),
|
||||
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
|
||||
],
|
||||
)
|
||||
def test_message_rejects_images(message_class, kwargs):
|
||||
"""Test that system, assistant, developer, and tool messages reject image content."""
|
||||
with pytest.raises(ValidationError):
|
||||
message_class(
|
||||
content=[
|
||||
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
|
||||
],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def test_user_message_accepts_images():
|
||||
"""Test that user messages accept image content (unlike other message types)."""
|
||||
# List with images should work
|
||||
msg = OpenAIUserMessageParam(
|
||||
content=[
|
||||
OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
|
||||
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
|
||||
]
|
||||
)
|
||||
assert len(msg.content) == 2
|
||||
assert msg.content[0].text == "Describe this image:"
|
||||
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_new_user_message():
|
||||
"""Test convert_message_to_openai_dict_new with UserMessage."""
|
||||
message = UserMessage(content="Hello, world!", role="user")
|
||||
result = await convert_message_to_openai_dict_new(message)
|
||||
|
||||
assert result["role"] == "user"
|
||||
assert result["content"] == "Hello, world!"
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
|
||||
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
|
||||
message = CompletionMessage(
|
||||
content="I'll help you find the weather.",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="call_123",
|
||||
tool_name="get_weather",
|
||||
arguments='{"city": "Sligo"}',
|
||||
)
|
||||
],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
result = await convert_message_to_openai_dict_new(message)
|
||||
|
||||
# This would have failed with "Cannot instantiate typing.Union" before the fix
|
||||
assert result["role"] == "assistant"
|
||||
assert result["content"] == "I'll help you find the weather."
|
||||
assert "tool_calls" in result
|
||||
assert result["tool_calls"] is not None
|
||||
assert len(result["tool_calls"]) == 1
|
||||
|
||||
tool_call = result["tool_calls"][0]
|
||||
assert tool_call.id == "call_123"
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function.name == "get_weather"
|
||||
assert tool_call.function.arguments == '{"city": "Sligo"}'
|
||||
|
|
@ -12,11 +12,17 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
|||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack_api import (
|
||||
Model,
|
||||
ModelType,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIMixinImpl(OpenAIMixin):
|
||||
|
|
@ -835,3 +841,96 @@ class TestOpenAIMixinProviderDataApiKey:
|
|||
error_message = str(exc_info.value)
|
||||
assert "test_api_key" in error_message
|
||||
assert "x-llamastack-provider-data" in error_message
|
||||
|
||||
|
||||
class TestOpenAIMixinAllowedModelsInference:
|
||||
"""Test cases for allowed_models enforcement during inference requests"""
|
||||
|
||||
async def test_inference_with_allowed_models(self, mixin, mock_client_context):
|
||||
"""Test that all inference methods succeed with allowed models"""
|
||||
mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
|
||||
mock_client.completions.create = AsyncMock(return_value=MagicMock())
|
||||
mock_embedding_response = MagicMock()
|
||||
mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
|
||||
mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
# Test chat completion
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
|
||||
# Test completion
|
||||
await mixin.openai_completion(
|
||||
OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
|
||||
)
|
||||
mock_client.completions.create.assert_called_once()
|
||||
|
||||
# Test embeddings
|
||||
await mixin.openai_embeddings(
|
||||
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
|
||||
)
|
||||
mock_client.embeddings.create.assert_called_once()
|
||||
|
||||
async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
|
||||
"""Test that all inference methods fail with disallowed models"""
|
||||
mixin.config.allowed_models = ["gpt-4"]
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
# Test chat completion with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
|
||||
# Test completion with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
|
||||
await mixin.openai_completion(
|
||||
OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
|
||||
)
|
||||
|
||||
# Test embeddings with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
|
||||
await mixin.openai_embeddings(
|
||||
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
|
||||
)
|
||||
|
||||
mock_client.chat.completions.create.assert_not_called()
|
||||
mock_client.completions.create.assert_not_called()
|
||||
mock_client.embeddings.create.assert_not_called()
|
||||
|
||||
async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
|
||||
"""Test that inference succeeds when allowed_models is None or empty list blocks all"""
|
||||
# Test with None (no restrictions)
|
||||
assert mixin.config.allowed_models is None
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
|
||||
# Test with empty list (blocks all models)
|
||||
mixin.config.allowed_models = []
|
||||
with mock_client_context(mixin, mock_client):
|
||||
with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
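The allow-list semantics exercised here (None means unrestricted, an empty list blocks everything) can be captured by a guard along these lines. This is a hypothetical helper for illustration, not the OpenAIMixin code itself:

def check_model_allowed(model: str, allowed_models: list[str] | None) -> None:
    # None means no restriction; an empty list rejects every model.
    if allowed_models is not None and model not in allowed_models:
        raise ValueError(f"Model '{model}' is not in the allowed models list")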
|
|
|
|||
32
tests/unit/providers/utils/inference/test_prompt_adapter.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.datatypes import RawTextItem
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_openai_message_to_raw_message,
|
||||
)
|
||||
from llama_stack_api import OpenAIAssistantMessageParam, OpenAIUserMessageParam
|
||||
|
||||
|
||||
class TestConvertOpenAIMessageToRawMessage:
|
||||
"""Test conversion of OpenAI message types to RawMessage format."""
|
||||
|
||||
async def test_user_message_conversion(self):
|
||||
msg = OpenAIUserMessageParam(role="user", content="Hello world")
|
||||
raw_msg = await convert_openai_message_to_raw_message(msg)
|
||||
|
||||
assert raw_msg.role == "user"
|
||||
assert isinstance(raw_msg.content, RawTextItem)
|
||||
assert raw_msg.content.text == "Hello world"
|
||||
|
||||
async def test_assistant_message_conversion(self):
|
||||
msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
|
||||
raw_msg = await convert_openai_message_to_raw_message(msg)
|
||||
|
||||
assert raw_msg.role == "assistant"
|
||||
assert isinstance(raw_msg.content, RawTextItem)
|
||||
assert raw_msg.content.text == "Hi there!"
|
||||
assert raw_msg.tool_calls == []
|
||||
|
|
@ -8,9 +8,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, TextContentItem
|
||||
from llama_stack.apis.tools import RAGDocument
|
||||
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
|
||||
from llama_stack_api import URL, RAGDocument, TextContentItem
|
||||
|
||||
|
||||
async def test_content_from_doc_with_url():
|
||||
|
|
|
|||
|
|
@ -35,8 +35,8 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
|
||||
from llama_stack_api import Model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -10,16 +10,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest

from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.storage.kvstore import register_kvstore_backends
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.utils.kvstore import register_kvstore_backends
from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, VectorStore

EMBEDDING_DIMENSION = 768
COLLECTION_PREFIX = "test_collection"
@ -280,7 +279,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
    ) as mock_check_version:
        mock_check_version.return_value = "0.5.1"

        with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl:
        with patch("llama_stack.core.storage.kvstore.kvstore_impl") as mock_kvstore_impl:
            mock_kvstore = AsyncMock()
            mock_kvstore_impl.return_value = mock_kvstore
@ -10,15 +10,12 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest

from llama_stack.apis.files import Files
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import (
    FaissIndex,
    FaissVectorIOAdapter,
)
from llama_stack_api import Chunk, Files, HealthStatus, QueryChunksResponse, VectorStore

# This test is a unit test for the FaissVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
@ -9,12 +9,12 @@ import asyncio
import numpy as np
import pytest

from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
    SQLiteVecIndex,
    SQLiteVecVectorIOAdapter,
    _create_sqlite_connection,
)
from llama_stack_api import Chunk, QueryChunksResponse

# This test is a unit test for the SQLiteVecVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
@ -11,17 +11,17 @@ from unittest.mock import AsyncMock, patch
import numpy as np
import pytest

from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.vector_io import (
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
from llama_stack_api import (
    Chunk,
    OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
    OpenAICreateVectorStoreRequestWithExtraBody,
    QueryChunksResponse,
    VectorStore,
    VectorStoreChunkingStrategyAuto,
    VectorStoreFileObject,
    VectorStoreNotFoundError,
)
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX

# This test is a unit test for the inline VectorIO providers. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
@ -92,6 +92,99 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
    await vector_io_adapter.shutdown()


async def test_vector_store_lazy_loading_from_kvstore(vector_io_adapter):
    """
    Test that vector stores can be lazy-loaded from KV store when not in cache.

    Verifies that clearing the cache doesn't break vector store access - they
    can be loaded on-demand from persistent storage.
    """
    await vector_io_adapter.initialize()

    vector_store_id = f"lazy_load_test_{np.random.randint(1e6)}"
    vector_store = VectorStore(
        identifier=vector_store_id,
        provider_id="test_provider",
        embedding_model="test_model",
        embedding_dimension=128,
    )
    await vector_io_adapter.register_vector_store(vector_store)
    assert vector_store_id in vector_io_adapter.cache

    vector_io_adapter.cache.clear()
    assert vector_store_id not in vector_io_adapter.cache

    loaded_index = await vector_io_adapter._get_and_cache_vector_store_index(vector_store_id)
    assert loaded_index is not None
    assert loaded_index.vector_store.identifier == vector_store_id
    assert vector_store_id in vector_io_adapter.cache

    cached_index = await vector_io_adapter._get_and_cache_vector_store_index(vector_store_id)
    assert cached_index is loaded_index

    await vector_io_adapter.shutdown()


async def test_vector_store_preloading_on_initialization(vector_io_adapter):
    """
    Test that vector stores are preloaded from KV store during initialization.

    Verifies that after restart, all vector stores are automatically loaded into
    cache and immediately accessible without requiring lazy loading.
    """
    await vector_io_adapter.initialize()

    vector_store_ids = [f"preload_test_{i}_{np.random.randint(1e6)}" for i in range(3)]
    for vs_id in vector_store_ids:
        vector_store = VectorStore(
            identifier=vs_id,
            provider_id="test_provider",
            embedding_model="test_model",
            embedding_dimension=128,
        )
        await vector_io_adapter.register_vector_store(vector_store)

    for vs_id in vector_store_ids:
        assert vs_id in vector_io_adapter.cache

    await vector_io_adapter.shutdown()
    await vector_io_adapter.initialize()

    for vs_id in vector_store_ids:
        assert vs_id in vector_io_adapter.cache

    for vs_id in vector_store_ids:
        loaded_index = await vector_io_adapter._get_and_cache_vector_store_index(vs_id)
        assert loaded_index is not None
        assert loaded_index.vector_store.identifier == vs_id

    await vector_io_adapter.shutdown()


async def test_kvstore_none_raises_runtime_error(vector_io_adapter):
    """
    Test that accessing vector stores with uninitialized kvstore raises RuntimeError.

    Verifies proper RuntimeError is raised instead of assertions when kvstore is None.
    """
    await vector_io_adapter.initialize()

    vector_store_id = f"kvstore_none_test_{np.random.randint(1e6)}"
    vector_store = VectorStore(
        identifier=vector_store_id,
        provider_id="test_provider",
        embedding_model="test_model",
        embedding_dimension=128,
    )
    await vector_io_adapter.register_vector_store(vector_store)

    vector_io_adapter.cache.clear()
    vector_io_adapter.kvstore = None

    with pytest.raises(RuntimeError, match="KVStore not initialized"):
        await vector_io_adapter._get_and_cache_vector_store_index(vector_store_id)


async def test_register_and_unregister_vector_store(vector_io_adapter):
    unique_id = f"foo_db_{np.random.randint(1e6)}"
    dummy = VectorStore(
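The new tests above pin down a cache-then-kvstore lookup pattern: serve from the in-memory cache when possible, fall back to persistent storage on a miss, and raise instead of asserting when the KV store was never initialized. The following is only a schematic sketch of that behaviour under assumed names (`cache`, `kvstore`, `_get_and_cache_vector_store_index`, `_build_index_from_record`); it is not the adapter's actual implementation.

```python
# Schematic only: the lookup behaviour the new tests assert, with hypothetical
# helper names. Real adapters implement this inside their provider classes.
class LazyLoadingAdapterSketch:
    def __init__(self, kvstore):
        self.cache: dict[str, object] = {}
        self.kvstore = kvstore  # any object with an async get(key) method

    async def _get_and_cache_vector_store_index(self, vector_store_id: str):
        # 1. Cache hit: return the already-built index object.
        if vector_store_id in self.cache:
            return self.cache[vector_store_id]
        # 2. No KV store means persisted state cannot be recovered: fail loudly.
        if self.kvstore is None:
            raise RuntimeError("KVStore not initialized")
        # 3. Cache miss: rebuild the index from the persisted record, then cache it.
        stored = await self.kvstore.get(f"vector_store:{vector_store_id}")
        if stored is None:
            return None
        index = self._build_index_from_record(stored)  # hypothetical helper
        self.cache[vector_store_id] = index
        return index

    def _build_index_from_record(self, record):
        # Placeholder for provider-specific reconstruction (FAISS, sqlite-vec, ...).
        return record
```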
@ -129,7 +222,7 @@ async def test_insert_chunks_missing_db_raises(vector_io_adapter):

async def test_insert_chunks_with_missing_document_id(vector_io_adapter):
    """Ensure no KeyError when document_id is missing or in different places."""
    from llama_stack.apis.vector_io import Chunk, ChunkMetadata
    from llama_stack_api import Chunk, ChunkMetadata

    fake_index = AsyncMock()
    vector_io_adapter.cache["db1"] = fake_index
@ -162,10 +255,9 @@ async def test_insert_chunks_with_missing_document_id(vector_io_adapter):

async def test_document_id_with_invalid_type_raises_error():
    """Ensure TypeError is raised when document_id is not a string."""
    from llama_stack.apis.vector_io import Chunk

    # Integer document_id should raise TypeError
    from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
    from llama_stack_api import Chunk

    chunk = Chunk(content="test", chunk_id=generate_chunk_id("test", "test"), metadata={"document_id": 12345})
    with pytest.raises(TypeError) as exc_info:
@ -4,8 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack_api import Chunk, ChunkMetadata, VectorStoreFileObject

# This test is a unit test for the chunk_utils.py helpers. This should only contain
# tests which are specific to this file. More general (API-level) tests should be placed in
@ -78,3 +78,77 @@ def test_chunk_serialization():
    serialized_chunk = chunk.model_dump()
    assert serialized_chunk["chunk_id"] == "test-chunk-id"
    assert "chunk_id" in serialized_chunk


def test_vector_store_file_object_attributes_validation():
    """Test VectorStoreFileObject validates and sanitizes attributes at input boundary."""
    # Test with metadata containing lists, nested dicts, and primitives
    from llama_stack_api.vector_io import VectorStoreChunkingStrategyAuto

    file_obj = VectorStoreFileObject(
        id="file-123",
        attributes={
            "tags": ["transformers", "h100-compatible", "region:us"],  # List -> string
            "model_name": "granite-3.3-8b",  # String preserved
            "score": 0.95,  # Float preserved
            "active": True,  # Bool preserved
            "count": 42,  # Int -> float
            "nested": {"key": "value"},  # Dict filtered out
        },
        chunking_strategy=VectorStoreChunkingStrategyAuto(),
        created_at=1234567890,
        status="completed",
        vector_store_id="vs-123",
    )

    # Lists converted to comma-separated strings
    assert file_obj.attributes["tags"] == "transformers, h100-compatible, region:us"
    # Primitives preserved
    assert file_obj.attributes["model_name"] == "granite-3.3-8b"
    assert file_obj.attributes["score"] == 0.95
    assert file_obj.attributes["active"] is True
    assert file_obj.attributes["count"] == 42.0  # int -> float
    # Complex types filtered out
    assert "nested" not in file_obj.attributes


def test_vector_store_file_object_attributes_constraints():
    """Test VectorStoreFileObject enforces OpenAPI constraints on attributes."""
    from llama_stack_api.vector_io import VectorStoreChunkingStrategyAuto

    # Test max 16 properties
    many_attrs = {f"key{i}": f"value{i}" for i in range(20)}
    file_obj = VectorStoreFileObject(
        id="file-123",
        attributes=many_attrs,
        chunking_strategy=VectorStoreChunkingStrategyAuto(),
        created_at=1234567890,
        status="completed",
        vector_store_id="vs-123",
    )
    assert len(file_obj.attributes) == 16  # Max 16 properties

    # Test max 64 char keys are filtered
    long_key_attrs = {"a" * 65: "value", "valid_key": "value"}
    file_obj = VectorStoreFileObject(
        id="file-124",
        attributes=long_key_attrs,
        chunking_strategy=VectorStoreChunkingStrategyAuto(),
        created_at=1234567890,
        status="completed",
        vector_store_id="vs-123",
    )
    assert "a" * 65 not in file_obj.attributes
    assert "valid_key" in file_obj.attributes

    # Test max 512 char string values are truncated
    long_value_attrs = {"key": "x" * 600}
    file_obj = VectorStoreFileObject(
        id="file-125",
        attributes=long_value_attrs,
        chunking_strategy=VectorStoreChunkingStrategyAuto(),
        created_at=1234567890,
        status="completed",
        vector_store_id="vs-123",
    )
    assert len(file_obj.attributes["key"]) == 512
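Taken together, the assertions above describe a concrete sanitization policy for attributes. Here is a standalone sketch of that policy; the function name and its placement are assumptions for illustration only — in the real change the behaviour lives inside the VectorStoreFileObject validation in llama_stack_api.

```python
# Sketch of the attribute-sanitization rules the two tests above assert.
# sanitize_attributes is a hypothetical free function, not the library's API.
def sanitize_attributes(attributes: dict) -> dict:
    sanitized: dict[str, object] = {}
    for key, value in attributes.items():
        if len(sanitized) >= 16:          # OpenAPI constraint: at most 16 properties
            break
        if len(key) > 64:                 # keys longer than 64 chars are dropped
            continue
        if isinstance(value, bool):       # bools kept as-is (checked before int)
            sanitized[key] = value
        elif isinstance(value, (int, float)):
            sanitized[key] = float(value) # ints normalized to float
        elif isinstance(value, str):
            sanitized[key] = value[:512]  # string values truncated to 512 chars
        elif isinstance(value, list):
            sanitized[key] = ", ".join(str(v) for v in value)  # lists joined
        # dicts and other complex types are filtered out
    return sanitized


assert sanitize_attributes({"count": 42, "nested": {"k": "v"}}) == {"count": 42.0}
```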
@ -8,13 +8,8 @@ from unittest.mock import AsyncMock, MagicMock

import pytest

from llama_stack.apis.tools.rag_tool import RAGQueryConfig
from llama_stack.apis.vector_io import (
    Chunk,
    ChunkMetadata,
    QueryChunksResponse,
)
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, RAGQueryConfig


class TestRagQuery:
@ -13,12 +13,6 @@ from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest

from llama_stack.apis.inference.inference import (
    OpenAIEmbeddingData,
    OpenAIEmbeddingsRequestWithExtraBody,
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import (
    URL,
    VectorStoreWithIndex,
@ -27,6 +21,7 @@ from llama_stack.providers.utils.memory.vector_store import (
    make_overlapped_chunks,
)
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack_api import Chunk, OpenAIEmbeddingData, OpenAIEmbeddingsRequestWithExtraBody, RAGDocument

DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
# Depending on the machine, this can get parsed a couple of ways
@ -7,16 +7,15 @@

import pytest

from llama_stack.apis.inference import Model
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.datatypes import VectorStoreWithOwner
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends
from llama_stack.core.store.registry import (
    KEY_FORMAT,
    CachedDiskDistributionRegistry,
    DiskDistributionRegistry,
)
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
from llama_stack_api import Model, VectorStore


@pytest.fixture
@ -304,8 +303,8 @@ async def test_double_registration_different_objects(disk_dist_registry):

async def test_double_registration_with_cache(cached_disk_dist_registry):
    """Test double registration behavior with caching enabled."""
    from llama_stack.apis.models import ModelType
    from llama_stack.core.datatypes import ModelWithOwner
    from llama_stack_api import ModelType

    model1 = ModelWithOwner(
        identifier="test_model",
@ -5,9 +5,9 @@
# the root directory of this source tree.


from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import ModelWithOwner, User
from llama_stack.core.store.registry import CachedDiskDistributionRegistry
from llama_stack_api import ModelType


async def test_registry_cache_with_acl(cached_disk_dist_registry):
@ -10,11 +10,10 @@ import pytest
import yaml
from pydantic import TypeAdapter, ValidationError

from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.datatypes import AccessRule, ModelWithOwner, User
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack_api import Api, ModelType


class AsyncMock(MagicMock):
@ -144,7 +144,7 @@ def middleware_with_mocks(mock_auth_endpoint):
    middleware = AuthenticationMiddleware(mock_app, auth_config, {})

    # Mock the route_impls to simulate finding routes with required scopes
    from llama_stack.schema_utils import WebMethod
    from llama_stack_api import WebMethod

    routes = {
        ("POST", "/test/scoped"): WebMethod(route="/test/scoped", method="POST", required_scope="test.read"),
@ -15,7 +15,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod
from llama_stack.core.server.quota import QuotaMiddleware
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import register_kvstore_backends
from llama_stack.core.storage.kvstore import register_kvstore_backends


@pytest.fixture
@ -11,7 +11,6 @@ from unittest.mock import AsyncMock, MagicMock

from pydantic import BaseModel, Field

from llama_stack.apis.inference import Inference
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
from llama_stack.core.resolver import resolve_impls
from llama_stack.core.routers.inference import InferenceRouter
@ -25,9 +24,9 @@ from llama_stack.core.storage.datatypes import (
    SqlStoreReference,
    StorageConfig,
)
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
from llama_stack.providers.utils.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.core.storage.kvstore import register_kvstore_backends
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack_api import Inference, InlineProviderSpec, ProviderSpec


def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
48 tests/unit/server/test_schema_registry.py Normal file
@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from llama_stack_api import Conversation, SamplingStrategy
from llama_stack_api.schema_utils import (
    clear_dynamic_schema_types,
    get_registered_schema_info,
    iter_dynamic_schema_types,
    iter_json_schema_types,
    iter_registered_schema_types,
    register_dynamic_schema_type,
)


def test_json_schema_registry_contains_known_model() -> None:
    assert Conversation in iter_json_schema_types()


def test_registered_schema_registry_contains_sampling_strategy() -> None:
    registered_names = {info.name for info in iter_registered_schema_types()}
    assert "SamplingStrategy" in registered_names

    schema_info = get_registered_schema_info(SamplingStrategy)
    assert schema_info is not None
    assert schema_info.name == "SamplingStrategy"


def test_dynamic_schema_registration_round_trip() -> None:
    existing_models = tuple(iter_dynamic_schema_types())
    clear_dynamic_schema_types()
    try:

        class TemporaryModel(BaseModel):
            foo: str

        register_dynamic_schema_type(TemporaryModel)
        assert TemporaryModel in iter_dynamic_schema_types()

        clear_dynamic_schema_types()
        assert TemporaryModel not in iter_dynamic_schema_types()
    finally:
        for model in existing_models:
            register_dynamic_schema_type(model)
@ -12,7 +12,7 @@ from pydantic import ValidationError

from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import AuthenticationRequiredError
from llama_stack.core.server.server import translate_exception
from llama_stack.core.server.server import remove_disabled_providers, translate_exception


class TestTranslateException:
@ -194,3 +194,70 @@ class TestTranslateException:
        assert isinstance(result3, HTTPException)
        assert result3.status_code == 403
        assert result3.detail == "Permission denied: Access denied"


class TestRemoveDisabledProviders:
    """Test cases for the remove_disabled_providers function."""

    def test_remove_explicitly_disabled_provider(self):
        """Test that providers with provider_id='__disabled__' are removed."""
        config = {
            "providers": {
                "inference": [
                    {"provider_id": "openai", "provider_type": "remote::openai", "config": {}},
                    {"provider_id": "__disabled__", "provider_type": "remote::vllm", "config": {}},
                ]
            }
        }
        result = remove_disabled_providers(config)
        assert len(result["providers"]["inference"]) == 1
        assert result["providers"]["inference"][0]["provider_id"] == "openai"

    def test_remove_empty_provider_id(self):
        """Test that providers with empty provider_id are removed."""
        config = {
            "providers": {
                "inference": [
                    {"provider_id": "openai", "provider_type": "remote::openai", "config": {}},
                    {"provider_id": "", "provider_type": "remote::vllm", "config": {}},
                ]
            }
        }
        result = remove_disabled_providers(config)
        assert len(result["providers"]["inference"]) == 1
        assert result["providers"]["inference"][0]["provider_id"] == "openai"

    def test_keep_models_with_none_provider_model_id(self):
        """Test that models with None provider_model_id are NOT removed."""
        config = {
            "registered_resources": {
                "models": [
                    {
                        "model_id": "llama-3-2-3b",
                        "provider_id": "vllm-inference",
                        "model_type": "llm",
                        "provider_model_id": None,
                        "metadata": {},
                    },
                    {
                        "model_id": "gpt-4o-mini",
                        "provider_id": "openai",
                        "model_type": "llm",
                        "provider_model_id": None,
                        "metadata": {},
                    },
                    {
                        "model_id": "granite-embedding-125m",
                        "provider_id": "sentence-transformers",
                        "model_type": "embedding",
                        "provider_model_id": "ibm-granite/granite-embedding-125m-english",
                        "metadata": {"embedding_dimension": 768},
                    },
                ]
            }
        }
        result = remove_disabled_providers(config)
        assert len(result["registered_resources"]["models"]) == 3
        assert result["registered_resources"]["models"][0]["model_id"] == "llama-3-2-3b"
        assert result["registered_resources"]["models"][1]["model_id"] == "gpt-4o-mini"
        assert result["registered_resources"]["models"][2]["model_id"] == "granite-embedding-125m"
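For orientation, this is roughly the filtering behaviour the new test class pins down — drop provider entries whose provider_id is empty or the literal "__disabled__", while leaving other config sections (such as registered models with provider_model_id=None) untouched. The sketch below is an assumption-based illustration, not the server's actual remove_disabled_providers implementation.

```python
# Illustrative sketch of the behaviour asserted above; hypothetical, not the
# real llama_stack.core.server.server.remove_disabled_providers.
def remove_disabled_providers_sketch(config: dict) -> dict:
    providers = config.get("providers", {})
    filtered = {
        api: [
            entry
            for entry in entries
            # Keep only entries with a real provider_id; "" and "__disabled__" are dropped.
            if entry.get("provider_id") not in ("", "__disabled__", None)
        ]
        for api, entries in providers.items()
    }
    # Everything else in the config (e.g. registered_resources.models with
    # provider_model_id=None) passes through unchanged.
    return {**config, "providers": filtered} if providers else config
```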
@ -10,8 +10,8 @@ from unittest.mock import AsyncMock, MagicMock

import pytest

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.server.server import create_dynamic_typed_route, create_sse_event, sse_generator
from llama_stack_api import PaginatedResponse


@pytest.fixture
@ -104,12 +104,18 @@ async def test_paginated_response_url_setting():

    route_handler = create_dynamic_typed_route(mock_api_method, "get", "/test/route")

    # Mock minimal request
    # Mock minimal request with proper state object
    request = MagicMock()
    request.scope = {"user_attributes": {}, "principal": ""}
    request.headers = {}
    request.body = AsyncMock(return_value=b"")

    # Create a simple state object without auto-generating attributes
    class MockState:
        pass

    request.state = MockState()

    result = await route_handler(request)

    assert isinstance(result, PaginatedResponse)
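A short aside on the MockState change above: MagicMock fabricates any attribute on access, so attribute-presence checks against request.state would always succeed and hide the case the test wants to exercise, whereas a plain empty class behaves like a real, attribute-free state object. A minimal standalone demonstration of the difference (not part of the diff):

```python
# Why a bare class is used for request.state instead of MagicMock:
# MagicMock auto-creates attributes, so presence checks are always true.
from unittest.mock import MagicMock


class MockState:
    pass


mock_state = MagicMock()
plain_state = MockState()

assert hasattr(mock_state, "anything_at_all")        # MagicMock invents the attribute
assert not hasattr(plain_state, "anything_at_all")   # plain object stays empty
```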
@ -11,8 +11,8 @@ Tests the new input_schema and output_schema fields.

from pydantic import ValidationError

from llama_stack.apis.tools import ToolDef
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
from llama_stack_api import ToolDef


class TestToolDefValidation:
@ -8,16 +8,16 @@ import time

import pytest

from llama_stack.apis.inference import (
from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import (
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChoice,
    OpenAIUserMessageParam,
    Order,
)
from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends


@pytest.fixture(autouse=True)
@ -5,8 +5,8 @@
# the root directory of this source tree.


from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
from llama_stack.core.storage.kvstore.sqlite.sqlite import SqliteKVStoreImpl


async def test_memory_kvstore_persistence_behavior():
@ -10,15 +10,10 @@ from uuid import uuid4

import pytest

from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInput,
    OpenAIResponseObject,
)
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack_api import OpenAIMessageParam, OpenAIResponseInput, OpenAIResponseObject, OpenAIUserMessageParam, Order


def build_store(db_path: str, policy: list | None = None) -> ResponsesStore:
@ -46,7 +41,7 @@ def create_test_response_object(

def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInput:
    """Helper to create a test response input."""
    from llama_stack.apis.agents.openai_responses import OpenAIResponseMessage
    from llama_stack_api import OpenAIResponseMessage

    return OpenAIResponseMessage(
        id=input_id,
@ -9,9 +9,9 @@ from tempfile import TemporaryDirectory

import pytest

from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType


async def test_sqlite_sqlstore():
@ -65,6 +65,38 @@ async def test_sqlite_sqlstore():
        assert result.has_more is False


async def test_sqlstore_upsert_support():
    with TemporaryDirectory() as tmp_dir:
        db_path = tmp_dir + "/upsert.db"
        store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))

        await store.create_table(
            "items",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "value": ColumnType.STRING,
                "updated_at": ColumnType.INTEGER,
            },
        )

        await store.upsert(
            table="items",
            data={"id": "item_1", "value": "first", "updated_at": 1},
            conflict_columns=["id"],
        )
        row = await store.fetch_one("items", {"id": "item_1"})
        assert row == {"id": "item_1", "value": "first", "updated_at": 1}

        await store.upsert(
            table="items",
            data={"id": "item_1", "value": "second", "updated_at": 2},
            conflict_columns=["id"],
            update_columns=["value", "updated_at"],
        )
        row = await store.fetch_one("items", {"id": "item_1"})
        assert row == {"id": "item_1", "value": "second", "updated_at": 2}


async def test_sqlstore_pagination_basic():
    """Test basic pagination functionality at the SQL store level."""
    with TemporaryDirectory() as tmp_dir:
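The new upsert test exercises insert-or-update keyed on conflict_columns, with update_columns limiting which fields change on conflict. For readers less familiar with the pattern, here is what such an upsert boils down to in plain SQLite, written via the stdlib sqlite3 module to stay in Python; this only illustrates the semantics and is not how SqlAlchemySqlStoreImpl is implemented.

```python
# Plain-sqlite3 illustration of the upsert semantics tested above
# (insert on first write, update selected columns on key conflict).
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE items (id TEXT PRIMARY KEY, value TEXT, updated_at INTEGER)")

upsert_sql = """
    INSERT INTO items (id, value, updated_at) VALUES (?, ?, ?)
    ON CONFLICT(id) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at
"""
conn.execute(upsert_sql, ("item_1", "first", 1))
conn.execute(upsert_sql, ("item_1", "second", 2))

row = conn.execute("SELECT id, value, updated_at FROM items WHERE id = ?", ("item_1",)).fetchone()
assert row == ("item_1", "second", 2)
conn.close()
```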
@ -10,13 +10,13 @@ from unittest.mock import patch
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import User
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord
from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack_api.internal.sqlstore import ColumnType


@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
    """Test that fetch_all works correctly with where_sql for access control"""
    with TemporaryDirectory() as tmp_dir:
@ -78,7 +78,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
        assert row["title"] == "User Document"


@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_sql_policy_consistency(mock_get_authenticated_user):
    """Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic"""
    with TemporaryDirectory() as tmp_dir:
@ -164,7 +164,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
    )


@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
    """Test that user attributes are properly captured during insert"""
    with TemporaryDirectory() as tmp_dir: