mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
Merge origin/main into add-missing-provider-data-impls
Resolved conflicts in: - benchmarking/k8s-benchmark/stack_run_config.yaml (accepted new storage schema) - llama_stack/providers/remote/inference/cerebras/cerebras.py (kept provider data support) - llama_stack/providers/remote/inference/cerebras/config.py (kept provider data support) - llama_stack/providers/remote/inference/nvidia/config.py (kept provider data support) - llama_stack/providers/remote/inference/runpod/config.py (merged imports) - pyproject.toml (kept databricks-sdk dependency)
This commit is contained in:
commit
9eb9a37ee4
1880 changed files with 804868 additions and 70533 deletions
|
|
@ -23,6 +23,30 @@ def config_with_image_name_int():
|
|||
image_name: 1234
|
||||
apis_to_serve: []
|
||||
built_at: {datetime.now().isoformat()}
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_sqlite
|
||||
db_path: /tmp/test_kv.db
|
||||
sql_default:
|
||||
type: sql_sqlite
|
||||
db_path: /tmp/test_sql.db
|
||||
stores:
|
||||
metadata:
|
||||
backend: kv_default
|
||||
namespace: metadata
|
||||
inference:
|
||||
backend: sql_default
|
||||
table_name: inference
|
||||
conversations:
|
||||
backend: sql_default
|
||||
table_name: conversations
|
||||
responses:
|
||||
backend: sql_default
|
||||
table_name: responses
|
||||
prompts:
|
||||
backend: kv_default
|
||||
namespace: prompts
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: provider1
|
||||
|
|
@ -54,6 +78,27 @@ def up_to_date_config():
|
|||
image_name: foo
|
||||
apis_to_serve: []
|
||||
built_at: {datetime.now().isoformat()}
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_sqlite
|
||||
db_path: /tmp/test_kv.db
|
||||
sql_default:
|
||||
type: sql_sqlite
|
||||
db_path: /tmp/test_sql.db
|
||||
stores:
|
||||
metadata:
|
||||
backend: kv_default
|
||||
namespace: metadata
|
||||
inference:
|
||||
backend: sql_default
|
||||
table_name: inference
|
||||
conversations:
|
||||
backend: sql_default
|
||||
table_name: conversations
|
||||
responses:
|
||||
backend: sql_default
|
||||
table_name: responses
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: provider1
|
||||
|
|
|
|||
|
|
@ -4,17 +4,26 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest_socket
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import warnings
|
||||
|
||||
# We need to import the fixtures here so that pytest can find them
|
||||
# but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed
|
||||
from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore # noqa: F401
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
"""Setup for each test - check if network access should be allowed."""
|
||||
if "allow_network" in item.keywords:
|
||||
pytest_socket.enable_socket()
|
||||
else:
|
||||
# Allowing Unix sockets is necessary for some tests that use local servers and mocks
|
||||
pytest_socket.disable_socket(allow_unix_socket=True)
|
||||
def pytest_sessionstart(session) -> None:
|
||||
if "LLAMA_STACK_LOGGING" not in os.environ:
|
||||
os.environ["LLAMA_STACK_LOGGING"] = "all=WARNING"
|
||||
|
||||
# Silence common deprecation spam during unit tests.
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def suppress_httpx_logs(caplog):
|
||||
"""Suppress httpx INFO logs for all unit tests"""
|
||||
caplog.set_level(logging.WARNING, logger="httpx")
|
||||
|
||||
|
||||
pytest_plugins = ["tests.unit.fixtures"]
|
||||
|
|
|
|||
|
|
@ -20,7 +20,14 @@ from llama_stack.core.conversations.conversations import (
|
|||
ConversationServiceConfig,
|
||||
ConversationServiceImpl,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
ServerStoresConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -28,7 +35,18 @@ async def service():
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_conversations.db"
|
||||
|
||||
config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"sql_test": SqliteSqlStoreConfig(db_path=str(db_path)),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
|
||||
),
|
||||
)
|
||||
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
|
||||
run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
|
||||
|
||||
config = ConversationServiceConfig(run_config=run_config, policy=[])
|
||||
service = ConversationServiceImpl(config, {})
|
||||
await service.initialize()
|
||||
yield service
|
||||
|
|
@ -64,7 +82,7 @@ async def test_conversation_items(service):
|
|||
assert len(item_list.data) == 1
|
||||
assert item_list.data[0].id == "msg_test123"
|
||||
|
||||
items = await service.list(conversation.id)
|
||||
items = await service.list_items(conversation.id)
|
||||
assert len(items.data) == 1
|
||||
|
||||
|
||||
|
|
@ -102,7 +120,7 @@ async def test_openai_type_compatibility(service):
|
|||
assert hasattr(item_list, attr)
|
||||
assert item_list.object == "list"
|
||||
|
||||
items = await service.list(conversation.id)
|
||||
items = await service.list_items(conversation.id)
|
||||
item = await service.retrieve(conversation.id, items.data[0].id)
|
||||
item_dict = item.model_dump()
|
||||
|
||||
|
|
@ -121,9 +139,18 @@ async def test_policy_configuration():
|
|||
AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
|
||||
]
|
||||
|
||||
config = ConversationServiceConfig(
|
||||
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"sql_test": SqliteSqlStoreConfig(db_path=str(db_path)),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
conversations=SqlStoreReference(backend="sql_test", table_name="openai_conversations"),
|
||||
),
|
||||
)
|
||||
register_sqlstore_backends({"sql_test": storage.backends["sql_test"]})
|
||||
run_config = StackRunConfig(image_name="test", apis=[], providers={}, storage=storage)
|
||||
|
||||
config = ConversationServiceConfig(run_config=run_config, policy=restrictive_policy)
|
||||
service = ConversationServiceImpl(config, {})
|
||||
await service.initialize()
|
||||
|
||||
|
|
|
|||
43
tests/unit/core/routers/test_safety_router.py
Normal file
43
tests/unit/core/routers/test_safety_router.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import ListShieldsResponse, Shield
|
||||
from llama_stack.core.datatypes import SafetyConfig
|
||||
from llama_stack.core.routers.safety import SafetyRouter
|
||||
|
||||
|
||||
async def test_run_moderation_uses_default_shield_when_model_missing():
|
||||
routing_table = AsyncMock()
|
||||
shield = Shield(
|
||||
identifier="shield-1",
|
||||
provider_resource_id="provider/shield-model",
|
||||
provider_id="provider-id",
|
||||
params={},
|
||||
)
|
||||
routing_table.list_shields.return_value = ListShieldsResponse(data=[shield])
|
||||
|
||||
moderation_response = ModerationObject(
|
||||
id="mid",
|
||||
model="shield-1",
|
||||
results=[ModerationObjectResults(flagged=False)],
|
||||
)
|
||||
provider = AsyncMock()
|
||||
provider.run_moderation.return_value = moderation_response
|
||||
routing_table.get_provider_impl.return_value = provider
|
||||
|
||||
router = SafetyRouter(routing_table=routing_table, safety_config=SafetyConfig(default_shield_id="shield-1"))
|
||||
|
||||
result = await router.run_moderation("hello world")
|
||||
|
||||
assert result is moderation_response
|
||||
routing_table.get_provider_impl.assert_awaited_once_with("shield-1")
|
||||
provider.run_moderation.assert_awaited_once()
|
||||
_, kwargs = provider.run_moderation.call_args
|
||||
assert kwargs["model"] == "provider/shield-model"
|
||||
assert kwargs["input"] == "hello world"
|
||||
57
tests/unit/core/routers/test_vector_io.py
Normal file
57
tests/unit/core/routers/test_vector_io.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody
|
||||
from llama_stack.core.routers.vector_io import VectorIORouter
|
||||
|
||||
|
||||
async def test_single_provider_auto_selection():
|
||||
# provider_id automatically selected during vector store create() when only one provider available
|
||||
mock_routing_table = Mock()
|
||||
mock_routing_table.impls_by_provider_id = {"inline::faiss": "mock_provider"}
|
||||
mock_routing_table.get_all_with_type = AsyncMock(
|
||||
return_value=[
|
||||
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
|
||||
]
|
||||
)
|
||||
mock_routing_table.register_vector_store = AsyncMock(
|
||||
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
|
||||
)
|
||||
mock_routing_table.get_provider_impl = AsyncMock(
|
||||
return_value=Mock(openai_create_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
|
||||
)
|
||||
router = VectorIORouter(mock_routing_table)
|
||||
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
|
||||
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
|
||||
)
|
||||
|
||||
result = await router.openai_create_vector_store(request)
|
||||
assert result.id == "vs_123"
|
||||
|
||||
|
||||
async def test_create_vector_stores_multiple_providers_missing_provider_id_error():
|
||||
# if multiple providers are available, vector store create will error without provider_id
|
||||
mock_routing_table = Mock()
|
||||
mock_routing_table.impls_by_provider_id = {
|
||||
"inline::faiss": "mock_provider_1",
|
||||
"inline::sqlite-vec": "mock_provider_2",
|
||||
}
|
||||
mock_routing_table.get_all_with_type = AsyncMock(
|
||||
return_value=[
|
||||
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
|
||||
]
|
||||
)
|
||||
router = VectorIORouter(mock_routing_table)
|
||||
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
|
||||
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Multiple vector_io providers available"):
|
||||
await router.openai_create_vector_store(request)
|
||||
102
tests/unit/core/test_stack_validation.py
Normal file
102
tests/unit/core/test_stack_validation.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Unit tests for Stack validation functions."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, ModelType
|
||||
from llama_stack.apis.shields import ListShieldsResponse, Shield
|
||||
from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, StorageConfig, VectorStoresConfig
|
||||
from llama_stack.core.stack import validate_safety_config, validate_vector_stores_config
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
class TestVectorStoresValidation:
|
||||
async def test_validate_missing_model(self):
|
||||
"""Test validation fails when model not found."""
|
||||
run_config = StackRunConfig(
|
||||
image_name="test",
|
||||
providers={},
|
||||
storage=StorageConfig(backends={}, stores={}),
|
||||
vector_stores=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=QualifiedModel(
|
||||
provider_id="p",
|
||||
model_id="missing",
|
||||
),
|
||||
),
|
||||
)
|
||||
mock_models = AsyncMock()
|
||||
mock_models.list_models.return_value = ListModelsResponse(data=[])
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
|
||||
|
||||
async def test_validate_success(self):
|
||||
"""Test validation passes with valid model."""
|
||||
run_config = StackRunConfig(
|
||||
image_name="test",
|
||||
providers={},
|
||||
storage=StorageConfig(backends={}, stores={}),
|
||||
vector_stores=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=QualifiedModel(
|
||||
provider_id="p",
|
||||
model_id="valid",
|
||||
),
|
||||
),
|
||||
)
|
||||
mock_models = AsyncMock()
|
||||
mock_models.list_models.return_value = ListModelsResponse(
|
||||
data=[
|
||||
Model(
|
||||
identifier="p/valid", # Must match provider_id/model_id format
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768},
|
||||
provider_id="p",
|
||||
provider_resource_id="valid",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
|
||||
|
||||
|
||||
class TestSafetyConfigValidation:
|
||||
async def test_validate_success(self):
|
||||
safety_config = SafetyConfig(default_shield_id="shield-1")
|
||||
|
||||
shield = Shield(
|
||||
identifier="shield-1",
|
||||
provider_id="provider-x",
|
||||
provider_resource_id="model-x",
|
||||
params={},
|
||||
)
|
||||
|
||||
shields_impl = AsyncMock()
|
||||
shields_impl.list_shields.return_value = ListShieldsResponse(data=[shield])
|
||||
|
||||
await validate_safety_config(safety_config, {Api.shields: shields_impl, Api.safety: AsyncMock()})
|
||||
|
||||
async def test_validate_wrong_shield_id(self):
|
||||
safety_config = SafetyConfig(default_shield_id="wrong-shield-id")
|
||||
|
||||
shields_impl = AsyncMock()
|
||||
shields_impl.list_shields.return_value = ListShieldsResponse(
|
||||
data=[
|
||||
Shield(
|
||||
identifier="shield-1",
|
||||
provider_resource_id="model-x",
|
||||
provider_id="provider-x",
|
||||
params={},
|
||||
)
|
||||
]
|
||||
)
|
||||
with pytest.raises(ValueError, match="wrong-shield-id"):
|
||||
await validate_safety_config(safety_config, {Api.shields: shields_impl, Api.safety: AsyncMock()})
|
||||
84
tests/unit/core/test_storage_references.py
Normal file
84
tests/unit/core/test_storage_references.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Unit tests for storage backend/reference validation."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from llama_stack.core.datatypes import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
|
||||
|
||||
def _base_run_config(**overrides):
|
||||
metadata_reference = overrides.pop(
|
||||
"metadata_reference",
|
||||
KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
)
|
||||
inference_reference = overrides.pop(
|
||||
"inference_reference",
|
||||
InferenceStoreReference(backend="sql_default", table_name="inference"),
|
||||
)
|
||||
conversations_reference = overrides.pop(
|
||||
"conversations_reference",
|
||||
SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
)
|
||||
storage = overrides.pop(
|
||||
"storage",
|
||||
StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path="/tmp/kv.db"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path="/tmp/sql.db"),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=metadata_reference,
|
||||
inference=inference_reference,
|
||||
conversations=conversations_reference,
|
||||
),
|
||||
),
|
||||
)
|
||||
return StackRunConfig(
|
||||
version=LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
image_name="test-distro",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def test_references_require_known_backend():
|
||||
with pytest.raises(ValidationError, match="unknown backend 'missing'"):
|
||||
_base_run_config(metadata_reference=KVStoreReference(backend="missing", namespace="registry"))
|
||||
|
||||
|
||||
def test_references_must_match_backend_family():
|
||||
with pytest.raises(ValidationError, match="kv_.* is required"):
|
||||
_base_run_config(metadata_reference=KVStoreReference(backend="sql_default", namespace="registry"))
|
||||
|
||||
with pytest.raises(ValidationError, match="sql_.* is required"):
|
||||
_base_run_config(
|
||||
inference_reference=InferenceStoreReference(backend="kv_default", table_name="inference"),
|
||||
)
|
||||
|
||||
|
||||
def test_valid_configuration_passes_validation():
|
||||
config = _base_run_config()
|
||||
stores = config.storage.stores
|
||||
assert stores.metadata is not None and stores.metadata.backend == "kv_default"
|
||||
assert stores.inference is not None and stores.inference.backend == "sql_default"
|
||||
assert stores.conversations is not None and stores.conversations.backend == "sql_default"
|
||||
|
|
@ -11,13 +11,13 @@ from unittest.mock import AsyncMock
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.shields.shields import Shield
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.core.datatypes import RegistryEntrySource
|
||||
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable
|
||||
|
|
@ -25,7 +25,6 @@ from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
|||
from llama_stack.core.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
||||
from llama_stack.core.routing_tables.shields import ShieldsRoutingTable
|
||||
from llama_stack.core.routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
|
||||
|
||||
|
||||
class Impl:
|
||||
|
|
@ -146,31 +145,6 @@ class ToolGroupsImpl(Impl):
|
|||
)
|
||||
|
||||
|
||||
class VectorDBImpl(Impl):
|
||||
def __init__(self):
|
||||
super().__init__(Api.vector_io)
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB):
|
||||
return vector_db
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str):
|
||||
return vector_db_id
|
||||
|
||||
async def openai_create_vector_store(self, **kwargs):
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
|
||||
|
||||
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
|
||||
return VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name=kwargs.get("name", vector_store_id),
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
|
||||
|
||||
async def test_models_routing_table(cached_disk_dist_registry):
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
|
@ -263,40 +237,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
|||
await table.unregister_shield(identifier="non-existent")
|
||||
|
||||
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
# Register multiple vector databases and verify listing
|
||||
vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
||||
vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
|
||||
assert len(vector_dbs.data) == 2
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
assert vdb1.identifier in vector_db_ids
|
||||
assert vdb2.identifier in vector_db_ids
|
||||
|
||||
# Verify they have UUID-based identifiers
|
||||
assert vdb1.identifier.startswith("vs_")
|
||||
assert vdb2.identifier.startswith("vs_")
|
||||
|
||||
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
|
||||
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 0
|
||||
|
||||
|
||||
async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
|
@ -354,6 +294,111 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
|||
assert len(scoring_functions_list_after_deletion.data) == 0
|
||||
|
||||
|
||||
async def test_double_registration_models_positive(cached_disk_dist_registry):
|
||||
"""Test that registering the same model twice with identical data succeeds."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a model
|
||||
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||
|
||||
# Register the exact same model again - should succeed (idempotent)
|
||||
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||
|
||||
# Verify only one model exists
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
assert models.data[0].identifier == "test_provider/test-model"
|
||||
|
||||
|
||||
async def test_double_registration_models_negative(cached_disk_dist_registry):
|
||||
"""Test that registering the same model with different data fails."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a model with specific metadata
|
||||
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||
|
||||
# Try to register the same model with different metadata - should fail
|
||||
with pytest.raises(
|
||||
ValueError, match="Object of type 'model' and identifier 'test_provider/test-model' already exists"
|
||||
):
|
||||
await table.register_model(
|
||||
model_id="test-model", provider_id="test_provider", metadata={"param1": "different_value"}
|
||||
)
|
||||
|
||||
|
||||
async def test_double_registration_scoring_functions_positive(cached_disk_dist_registry):
|
||||
"""Test that registering the same scoring function twice with identical data succeeds."""
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a scoring function
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Test scoring function",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
# Register the exact same scoring function again - should succeed (idempotent)
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Test scoring function",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
# Verify only one scoring function exists
|
||||
scoring_functions = await table.list_scoring_functions()
|
||||
assert len(scoring_functions.data) == 1
|
||||
assert scoring_functions.data[0].identifier == "test-scoring-fn"
|
||||
|
||||
|
||||
async def test_double_registration_scoring_functions_negative(cached_disk_dist_registry):
|
||||
"""Test that registering the same scoring function with different data fails."""
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a scoring function
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Test scoring function",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
# Try to register the same scoring function with different description - should fail
|
||||
with pytest.raises(
|
||||
ValueError, match="Object of type 'scoring_function' and identifier 'test-scoring-fn' already exists"
|
||||
):
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Different description",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
|
||||
async def test_double_registration_different_providers(cached_disk_dist_registry):
|
||||
"""Test that registering objects with same ID but different providers succeeds."""
|
||||
impl1 = InferenceImpl()
|
||||
impl2 = InferenceImpl()
|
||||
table = ModelsRoutingTable({"provider1": impl1, "provider2": impl2}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register same model ID with different providers - should succeed
|
||||
await table.register_model(model_id="shared-model", provider_id="provider1")
|
||||
await table.register_model(model_id="shared-model", provider_id="provider2")
|
||||
|
||||
# Verify both models exist with different identifiers
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
model_ids = {m.identifier for m in models.data}
|
||||
assert "provider1/shared-model" in model_ids
|
||||
assert "provider2/shared-model" in model_ids
|
||||
|
||||
|
||||
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
|
@ -406,6 +451,7 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
|
|||
await table.initialize()
|
||||
|
||||
# Register model with alias (model_id different from provider_model_id)
|
||||
# NOTE: Aliases are not supported anymore, so this is a no-op
|
||||
await table.register_model(
|
||||
model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider"
|
||||
)
|
||||
|
|
@ -414,12 +460,15 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
|
|||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
model = models.data[0]
|
||||
assert model.identifier == "my-alias" # Uses alias as identifier
|
||||
assert model.identifier == "test_provider/actual-provider-model"
|
||||
assert model.provider_resource_id == "actual-provider-model"
|
||||
|
||||
# Test lookup by alias works
|
||||
retrieved_model = await table.get_model("my-alias")
|
||||
assert retrieved_model.identifier == "my-alias"
|
||||
# Test lookup by alias fails
|
||||
with pytest.raises(ModelNotFoundError, match="Model 'my-alias' not found"):
|
||||
await table.get_model("my-alias")
|
||||
|
||||
retrieved_model = await table.get_model("test_provider/actual-provider-model")
|
||||
assert retrieved_model.identifier == "test_provider/actual-provider-model"
|
||||
assert retrieved_model.provider_resource_id == "actual-provider-model"
|
||||
|
||||
|
||||
|
|
@ -450,12 +499,8 @@ async def test_models_multi_provider_disambiguation(cached_disk_dist_registry):
|
|||
assert model2.provider_resource_id == "common-model"
|
||||
|
||||
# Test lookup by unscoped provider_model_id fails with multiple providers error
|
||||
try:
|
||||
with pytest.raises(ModelNotFoundError, match="Model 'common-model' not found"):
|
||||
await table.get_model("common-model")
|
||||
raise AssertionError("Should have raised ValueError for multiple providers")
|
||||
except ValueError as e:
|
||||
assert "Multiple providers found" in str(e)
|
||||
assert "provider1" in str(e) and "provider2" in str(e)
|
||||
|
||||
|
||||
async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
|
||||
|
|
@ -478,16 +523,12 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
|
|||
assert retrieved_model.identifier == "test_provider/test-model"
|
||||
|
||||
# Test lookup by unscoped provider_model_id (fallback via iteration)
|
||||
retrieved_model = await table.get_model("test-model")
|
||||
assert retrieved_model.identifier == "test_provider/test-model"
|
||||
assert retrieved_model.provider_resource_id == "test-model"
|
||||
with pytest.raises(ModelNotFoundError, match="Model 'test-model' not found"):
|
||||
await table.get_model("test-model")
|
||||
|
||||
# Test lookup of non-existent model fails
|
||||
try:
|
||||
with pytest.raises(ModelNotFoundError, match="Model 'non-existent' not found"):
|
||||
await table.get_model("non-existent")
|
||||
raise AssertionError("Should have raised ValueError for non-existent model")
|
||||
except ValueError as e:
|
||||
assert "not found" in str(e)
|
||||
|
||||
|
||||
async def test_models_source_tracking_default(cached_disk_dist_registry):
|
||||
|
|
@ -559,7 +600,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
|
|||
assert len(models.data) == 1
|
||||
user_model = models.data[0]
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert user_model.identifier == "my-custom-alias"
|
||||
assert user_model.identifier == "test_provider/provider-model-1"
|
||||
assert user_model.provider_resource_id == "provider-model-1"
|
||||
|
||||
# Now simulate provider refresh
|
||||
|
|
@ -586,7 +627,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
|
|||
assert len(models.data) == 2
|
||||
|
||||
# Find the user model and provider model
|
||||
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
||||
user_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-1"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
|
||||
|
||||
assert user_model is not None
|
||||
|
|
|
|||
|
|
@ -1,381 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Unit tests for the routing tables vector_dbs
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io.vector_io import (
|
||||
VectorStoreContent,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileCounts,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.access_control.datatypes import AccessRule, Scope
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
|
||||
from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
|
||||
|
||||
|
||||
class VectorDBImpl(Impl):
|
||||
def __init__(self):
|
||||
super().__init__(Api.vector_io)
|
||||
self.vector_stores = {}
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB):
|
||||
return vector_db
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str):
|
||||
return vector_db_id
|
||||
|
||||
async def openai_retrieve_vector_store(self, vector_store_id):
|
||||
return VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name="Test Store",
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
|
||||
async def openai_update_vector_store(self, vector_store_id, **kwargs):
|
||||
return VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name="Updated Store",
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store(self, vector_store_id):
|
||||
return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
|
||||
|
||||
async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
|
||||
return VectorStoreSearchResponsePage(
|
||||
object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
|
||||
)
|
||||
|
||||
async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
|
||||
return VectorStoreFileObject(
|
||||
id=file_id,
|
||||
status="completed",
|
||||
chunking_strategy={"type": "auto"},
|
||||
created_at=int(time.time()),
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
||||
async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
|
||||
return [
|
||||
VectorStoreFileObject(
|
||||
id="1",
|
||||
status="completed",
|
||||
chunking_strategy={"type": "auto"},
|
||||
created_at=int(time.time()),
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
]
|
||||
|
||||
async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
|
||||
return VectorStoreFileObject(
|
||||
id=file_id,
|
||||
status="completed",
|
||||
chunking_strategy={"type": "auto"},
|
||||
created_at=int(time.time()),
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
|
||||
return VectorStoreFileContentsResponse(
|
||||
file_id=file_id,
|
||||
filename="Sample File name",
|
||||
attributes={"key": "value"},
|
||||
content=[VectorStoreContent(type="text", text="Sample content")],
|
||||
)
|
||||
|
||||
async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
|
||||
return VectorStoreFileObject(
|
||||
id=file_id,
|
||||
status="completed",
|
||||
chunking_strategy={"type": "auto"},
|
||||
created_at=int(time.time()),
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
|
||||
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name=None,
|
||||
embedding_model=None,
|
||||
embedding_dimension=None,
|
||||
provider_id=None,
|
||||
provider_vector_db_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
|
||||
vector_store = VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name=name or vector_store_id,
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
self.vector_stores[vector_store_id] = vector_store
|
||||
return vector_store
|
||||
|
||||
async def openai_list_vector_stores(self, **kwargs):
|
||||
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
|
||||
|
||||
return VectorStoreListResponse(
|
||||
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
|
||||
)
|
||||
|
||||
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
n = 10
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
# Register multiple vector databases and verify listing
|
||||
vdb_dict = {}
|
||||
for i in range(n):
|
||||
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
|
||||
assert len(vector_dbs.data) == len(vdb_dict)
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
for k in vdb_dict:
|
||||
assert vdb_dict[k].identifier in vector_db_ids
|
||||
for k in vdb_dict:
|
||||
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 0
|
||||
|
||||
|
||||
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
|
||||
n = 10
|
||||
impl = VectorDBImpl()
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
vdb_dict = {}
|
||||
for i in range(n):
|
||||
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
|
||||
vector_stores = await impl.openai_list_vector_stores()
|
||||
vector_store_ids = {v.id for v in vector_stores.data}
|
||||
|
||||
assert vector_db_ids == vector_store_ids, (
|
||||
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
|
||||
)
|
||||
|
||||
for vector_store in vector_stores.data:
|
||||
vector_db = await table.get_vector_db(vector_store.id)
|
||||
assert vector_store.name == vector_db.vector_db_name, (
|
||||
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
|
||||
)
|
||||
|
||||
for vector_db_id in vector_db_ids:
|
||||
await table.unregister_vector_db(vector_db_id)
|
||||
|
||||
assert len((await table.list_vector_dbs()).data) == 0
|
||||
|
||||
|
||||
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
user_provided_id = "my-custom-vector-db"
|
||||
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
|
||||
|
||||
vector_stores = await impl.openai_list_vector_stores()
|
||||
assert len(vector_stores.data) == 1
|
||||
|
||||
vector_store = vector_stores.data[0]
|
||||
|
||||
assert vector_store.name == user_provided_id
|
||||
|
||||
assert vector_store.id.startswith("vs_")
|
||||
assert vector_store.id != user_provided_id
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 1
|
||||
assert vector_dbs.data[0].identifier == vector_store.id
|
||||
|
||||
await table.unregister_vector_db(vector_store.id)
|
||||
|
||||
|
||||
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
|
||||
authorized_table = "vs1"
|
||||
authorized_team = "team1"
|
||||
unauthorized_team = "team2"
|
||||
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
||||
authorized_table = registered_vdb.identifier # Use the actual generated ID
|
||||
|
||||
# Authorized reader
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
res = await table.openai_retrieve_vector_store(authorized_table)
|
||||
assert res == "OK"
|
||||
|
||||
# Authorized updater
|
||||
impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
|
||||
assert res == "UPDATED"
|
||||
|
||||
# Unauthorized reader
|
||||
unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
|
||||
with request_provider_data_context({}, unauthorized_user):
|
||||
with pytest.raises(ValueError):
|
||||
await table.openai_retrieve_vector_store(authorized_table)
|
||||
|
||||
# Unauthorized updater
|
||||
with request_provider_data_context({}, unauthorized_user):
|
||||
with pytest.raises(ValueError):
|
||||
await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
|
||||
|
||||
# Authorized deleter
|
||||
impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
|
||||
assert res == "DELETED"
|
||||
|
||||
# Unauthorized deleter
|
||||
with request_provider_data_context({}, unauthorized_user):
|
||||
with pytest.raises(ValueError):
|
||||
await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
|
||||
|
||||
|
||||
async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
|
||||
policy = [
|
||||
AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
|
||||
AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
|
||||
]
|
||||
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
|
||||
|
||||
vector_db_id = "vs1"
|
||||
file_id = "file-1"
|
||||
|
||||
admin_user = User(principal="admin", attributes={"roles": ["admin"]})
|
||||
read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
|
||||
no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})
|
||||
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
with request_provider_data_context({}, admin_user):
|
||||
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
||||
vector_db_id = registered_vdb.identifier # Use the actual generated ID
|
||||
|
||||
read_methods = [
|
||||
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
|
||||
(table.openai_search_vector_store, (vector_db_id, "query"), {}),
|
||||
(table.openai_list_files_in_vector_store, (vector_db_id,), {}),
|
||||
(table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
|
||||
(table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
|
||||
]
|
||||
update_methods = [
|
||||
(table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
|
||||
(table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
|
||||
(table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
|
||||
]
|
||||
delete_methods = [
|
||||
(table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
|
||||
(table.openai_delete_vector_store, (vector_db_id,), {}),
|
||||
]
|
||||
|
||||
for user in [admin_user, read_only_user]:
|
||||
with request_provider_data_context({}, user):
|
||||
for method, args, kwargs in read_methods:
|
||||
result = await method(*args, **kwargs)
|
||||
assert result is not None, f"Read operation failed with user {user.principal}"
|
||||
|
||||
with request_provider_data_context({}, no_access_user):
|
||||
for method, args, kwargs in read_methods:
|
||||
with pytest.raises(ValueError):
|
||||
await method(*args, **kwargs)
|
||||
|
||||
with request_provider_data_context({}, admin_user):
|
||||
for method, args, kwargs in update_methods:
|
||||
result = await method(*args, **kwargs)
|
||||
assert result is not None, "Update operation failed with admin user"
|
||||
|
||||
with request_provider_data_context({}, admin_user):
|
||||
for method, args, kwargs in delete_methods:
|
||||
result = await method(*args, **kwargs)
|
||||
assert result is not None, "Delete operation failed with admin user"
|
||||
|
||||
for user in [read_only_user, no_access_user]:
|
||||
with request_provider_data_context({}, user):
|
||||
for method, args, kwargs in delete_methods:
|
||||
with pytest.raises(ValueError):
|
||||
await method(*args, **kwargs)
|
||||
318
tests/unit/distribution/test_api_recordings.py
Normal file
318
tests/unit/distribution/test_api_recordings.py
Normal file
|
|
@ -0,0 +1,318 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Import the real Pydantic response types instead of using Mocks
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChoice,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.testing.api_recorder import (
|
||||
APIRecordingMode,
|
||||
ResponseStorage,
|
||||
api_recording,
|
||||
normalize_inference_request,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_storage_dir():
|
||||
"""Create a temporary directory for test recordings."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield Path(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_openai_chat_response():
|
||||
"""Real OpenAI chat completion response using proper Pydantic objects."""
|
||||
return OpenAIChatCompletion(
|
||||
id="chatcmpl-test123",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
index=0,
|
||||
message=OpenAIAssistantMessageParam(
|
||||
role="assistant", content="Hello! I'm doing well, thank you for asking."
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
created=1234567890,
|
||||
model="llama3.2:3b",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_embeddings_response():
|
||||
"""Real OpenAI embeddings response using proper Pydantic objects."""
|
||||
return OpenAIEmbeddingsResponse(
|
||||
object="list",
|
||||
data=[
|
||||
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
|
||||
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
|
||||
],
|
||||
model="nomic-embed-text",
|
||||
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
|
||||
)
|
||||
|
||||
|
||||
class TestInferenceRecording:
|
||||
"""Test the inference recording system."""
|
||||
|
||||
def test_request_normalization(self):
|
||||
"""Test that request normalization produces consistent hashes."""
|
||||
# Test basic normalization
|
||||
hash1 = normalize_inference_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
|
||||
)
|
||||
|
||||
# Same request should produce same hash
|
||||
hash2 = normalize_inference_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
|
||||
)
|
||||
|
||||
assert hash1 == hash2
|
||||
|
||||
# Different content should produce different hash
|
||||
hash3 = normalize_inference_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [{"role": "user", "content": "Different message"}],
|
||||
"temperature": 0.7,
|
||||
},
|
||||
)
|
||||
|
||||
assert hash1 != hash3
|
||||
|
||||
def test_request_normalization_edge_cases(self):
|
||||
"""Test request normalization is precise about request content."""
|
||||
# Test that different whitespace produces different hashes (no normalization)
|
||||
hash1 = normalize_inference_request(
|
||||
"POST",
|
||||
"http://test/v1/chat/completions",
|
||||
{},
|
||||
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
|
||||
)
|
||||
hash2 = normalize_inference_request(
|
||||
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
|
||||
)
|
||||
assert hash1 != hash2 # Different whitespace should produce different hashes
|
||||
|
||||
# Test that different float precision produces different hashes (no rounding)
|
||||
hash3 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
|
||||
hash4 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
|
||||
assert hash3 == hash4 # Small float precision differences should normalize to the same hash
|
||||
|
||||
# String-embedded decimals with excessive precision should also normalize.
|
||||
body_with_precise_scores = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "score: 0.7472640164649847",
|
||||
}
|
||||
]
|
||||
}
|
||||
body_with_precise_scores_variation = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "score: 0.74726414959878",
|
||||
}
|
||||
]
|
||||
}
|
||||
hash5 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, body_with_precise_scores)
|
||||
hash6 = normalize_inference_request(
|
||||
"POST", "http://test/v1/chat/completions", {}, body_with_precise_scores_variation
|
||||
)
|
||||
assert hash5 == hash6
|
||||
|
||||
body_with_close_scores = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "score: 0.662477492560699",
|
||||
}
|
||||
]
|
||||
}
|
||||
body_with_close_scores_variation = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "score: 0.6624775971970099",
|
||||
}
|
||||
]
|
||||
}
|
||||
hash7 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, body_with_close_scores)
|
||||
hash8 = normalize_inference_request(
|
||||
"POST", "http://test/v1/chat/completions", {}, body_with_close_scores_variation
|
||||
)
|
||||
assert hash7 == hash8
|
||||
|
||||
def test_response_storage(self, temp_storage_dir):
|
||||
"""Test the ResponseStorage class."""
|
||||
temp_storage_dir = temp_storage_dir / "test_response_storage"
|
||||
storage = ResponseStorage(temp_storage_dir)
|
||||
|
||||
# Test storing and retrieving a recording
|
||||
request_hash = "test_hash_123"
|
||||
request_data = {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/chat/completions",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b",
|
||||
}
|
||||
response_data = {"body": {"content": "test response"}, "is_streaming": False}
|
||||
|
||||
storage.store_recording(request_hash, request_data, response_data)
|
||||
|
||||
# Verify file storage and retrieval
|
||||
retrieved = storage.find_recording(request_hash)
|
||||
assert retrieved is not None
|
||||
assert retrieved["request"]["model"] == "llama3.2:3b"
|
||||
assert retrieved["response"]["body"]["content"] == "test response"
|
||||
|
||||
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||
"""Test that recording mode captures and stores responses."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_openai_chat_response
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_recording_mode"
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
# Verify the response was returned correctly
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
|
||||
# Verify recording was stored
|
||||
storage = ResponseStorage(temp_storage_dir)
|
||||
assert storage._get_test_dir().exists()
|
||||
|
||||
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||
"""Test that replay mode returns stored responses without making real calls."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_openai_chat_response
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_mode"
|
||||
# First, record a response
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
# Now test replay mode - should not call the original method
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
|
||||
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
# Verify we got the recorded response
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
|
||||
# Verify the original method was NOT called
|
||||
mock_create_patch.assert_not_called()
|
||||
|
||||
async def test_replay_missing_recording(self, temp_storage_dir):
|
||||
"""Test that replay mode fails when no recording is found."""
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
|
||||
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
with pytest.raises(RuntimeError, match="Recording not found"):
|
||||
await client.chat.completions.create(
|
||||
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
|
||||
)
|
||||
|
||||
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
|
||||
"""Test recording and replay of embeddings calls."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_embeddings_response
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
|
||||
# Record
|
||||
with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
|
||||
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.embeddings.create(
|
||||
model="nomic-embed-text", input=["Hello world", "Test embedding"]
|
||||
)
|
||||
|
||||
assert len(response.data) == 2
|
||||
|
||||
# Replay
|
||||
with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
|
||||
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.embeddings.create(
|
||||
model="nomic-embed-text", input=["Hello world", "Test embedding"]
|
||||
)
|
||||
|
||||
# Verify we got the recorded response
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
||||
|
||||
# Verify original method was not called
|
||||
mock_create_patch.assert_not_called()
|
||||
|
||||
async def test_live_mode(self, real_openai_chat_response):
|
||||
"""Test that live mode passes through to original methods."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_openai_chat_response
|
||||
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with api_recording(mode=APIRecordingMode.LIVE, storage_dir="foo"):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
|
||||
# Verify the response was returned
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.stack._build import (
|
||||
_run_stack_build_command_from_build_config,
|
||||
)
|
||||
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
|
||||
|
||||
def test_container_build_passes_path(monkeypatch, tmp_path):
|
||||
called_with = {}
|
||||
|
||||
def spy_build_image(build_config, image_name, distro_or_config, run_config=None):
|
||||
called_with["path"] = distro_or_config
|
||||
called_with["run_config"] = run_config
|
||||
return 0
|
||||
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.cli.stack._build.build_image",
|
||||
spy_build_image,
|
||||
raising=True,
|
||||
)
|
||||
|
||||
cfg = BuildConfig(
|
||||
image_type=LlamaStackImageType.CONTAINER.value,
|
||||
distribution_spec=DistributionSpec(providers={}, description=""),
|
||||
)
|
||||
|
||||
_run_stack_build_command_from_build_config(cfg, image_name="dummy")
|
||||
|
||||
assert "path" in called_with
|
||||
assert isinstance(called_with["path"], str)
|
||||
assert Path(called_with["path"]).exists()
|
||||
assert called_with["run_config"] is None
|
||||
|
|
@ -13,6 +13,15 @@ from pydantic import BaseModel, Field, ValidationError
|
|||
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
||||
|
|
@ -29,6 +38,33 @@ class SampleConfig(BaseModel):
|
|||
}
|
||||
|
||||
|
||||
def _default_storage() -> StorageConfig:
|
||||
return StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
|
||||
conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
prompts=KVStoreReference(backend="kv_default", namespace="prompts"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def make_stack_config(**overrides) -> StackRunConfig:
|
||||
storage = overrides.pop("storage", _default_storage())
|
||||
defaults = dict(
|
||||
image_name="test_image",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return StackRunConfig(**defaults)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_providers():
|
||||
"""Mock the available_providers function to return test providers."""
|
||||
|
|
@ -47,8 +83,8 @@ def mock_providers():
|
|||
@pytest.fixture
|
||||
def base_config(tmp_path):
|
||||
"""Create a base StackRunConfig with common settings."""
|
||||
return StackRunConfig(
|
||||
image_name="test_image",
|
||||
return make_stack_config(
|
||||
apis=["inference"],
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
|
|
@ -220,8 +256,8 @@ class TestProviderRegistry:
|
|||
|
||||
def test_missing_directory(self, mock_providers):
|
||||
"""Test handling of missing external providers directory."""
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
config = make_stack_config(
|
||||
apis=["inference"],
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
|
|
@ -276,7 +312,6 @@ pip_packages:
|
|||
"""Test loading an external provider from a module (success path)."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
# Simulate a provider module with get_provider_spec
|
||||
|
|
@ -291,7 +326,7 @@ pip_packages:
|
|||
import_module_side_effect = make_import_module_side_effect(external_module=fake_module)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import:
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -315,12 +350,11 @@ pip_packages:
|
|||
|
||||
def test_external_provider_from_module_not_found(self, mock_providers):
|
||||
"""Test handling ModuleNotFoundError for missing provider module."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect(raise_for_external=True)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -339,12 +373,11 @@ pip_packages:
|
|||
|
||||
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
|
||||
"""Test handling missing get_provider_spec in provider module (should raise ValueError)."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -397,13 +430,12 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_stackrunconfig_provider_without_module(self, mock_providers):
|
||||
"""Test that providers without module attribute are skipped."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect()
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -424,7 +456,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -442,7 +473,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -562,7 +593,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test when get_provider_spec returns a list of specs."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -587,7 +617,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -611,7 +641,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that list return filters specs by provider_type."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -636,7 +665,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -660,7 +689,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that list return adds multiple different provider_types when config requests them."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -686,7 +714,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -716,7 +744,6 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_module_not_found_raises_value_error(self, mock_providers):
|
||||
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def import_side_effect(name):
|
||||
|
|
@ -725,7 +752,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -749,7 +776,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test that generic exceptions are properly raised."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def bad_spec():
|
||||
|
|
@ -763,7 +789,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
@ -785,10 +811,9 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
def test_empty_provider_list(self, mock_providers):
|
||||
"""Test with empty provider list."""
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={},
|
||||
)
|
||||
|
|
@ -803,7 +828,6 @@ class TestGetExternalProvidersFromModule:
|
|||
"""Test multiple APIs with providers."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
|
@ -828,7 +852,7 @@ class TestGetExternalProvidersFromModule:
|
|||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
config = make_stack_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
|
|||
|
|
@ -1,382 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
from openai.types.model import Model as OpenAIModel
|
||||
|
||||
# Import the real Pydantic response types instead of using Mocks
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChoice,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.testing.inference_recorder import (
|
||||
InferenceMode,
|
||||
ResponseStorage,
|
||||
inference_recording,
|
||||
normalize_request,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_storage_dir():
|
||||
"""Create a temporary directory for test recordings."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield Path(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_openai_chat_response():
|
||||
"""Real OpenAI chat completion response using proper Pydantic objects."""
|
||||
return OpenAIChatCompletion(
|
||||
id="chatcmpl-test123",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
index=0,
|
||||
message=OpenAIAssistantMessageParam(
|
||||
role="assistant", content="Hello! I'm doing well, thank you for asking."
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
created=1234567890,
|
||||
model="llama3.2:3b",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_embeddings_response():
|
||||
"""Real OpenAI embeddings response using proper Pydantic objects."""
|
||||
return OpenAIEmbeddingsResponse(
|
||||
object="list",
|
||||
data=[
|
||||
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
|
||||
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
|
||||
],
|
||||
model="nomic-embed-text",
|
||||
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
|
||||
)
|
||||
|
||||
|
||||
class TestInferenceRecording:
|
||||
"""Test the inference recording system."""
|
||||
|
||||
def test_request_normalization(self):
|
||||
"""Test that request normalization produces consistent hashes."""
|
||||
# Test basic normalization
|
||||
hash1 = normalize_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
|
||||
)
|
||||
|
||||
# Same request should produce same hash
|
||||
hash2 = normalize_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
|
||||
)
|
||||
|
||||
assert hash1 == hash2
|
||||
|
||||
# Different content should produce different hash
|
||||
hash3 = normalize_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [{"role": "user", "content": "Different message"}],
|
||||
"temperature": 0.7,
|
||||
},
|
||||
)
|
||||
|
||||
assert hash1 != hash3
|
||||
|
||||
def test_request_normalization_edge_cases(self):
|
||||
"""Test request normalization is precise about request content."""
|
||||
# Test that different whitespace produces different hashes (no normalization)
|
||||
hash1 = normalize_request(
|
||||
"POST",
|
||||
"http://test/v1/chat/completions",
|
||||
{},
|
||||
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
|
||||
)
|
||||
hash2 = normalize_request(
|
||||
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
|
||||
)
|
||||
assert hash1 != hash2 # Different whitespace should produce different hashes
|
||||
|
||||
# Test that different float precision produces different hashes (no rounding)
|
||||
hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
|
||||
hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
|
||||
assert hash3 != hash4 # Different precision should produce different hashes
|
||||
|
||||
def test_response_storage(self, temp_storage_dir):
|
||||
"""Test the ResponseStorage class."""
|
||||
temp_storage_dir = temp_storage_dir / "test_response_storage"
|
||||
storage = ResponseStorage(temp_storage_dir)
|
||||
|
||||
# Test storing and retrieving a recording
|
||||
request_hash = "test_hash_123"
|
||||
request_data = {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/chat/completions",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b",
|
||||
}
|
||||
response_data = {"body": {"content": "test response"}, "is_streaming": False}
|
||||
|
||||
storage.store_recording(request_hash, request_data, response_data)
|
||||
|
||||
# Verify file storage and retrieval
|
||||
retrieved = storage.find_recording(request_hash)
|
||||
assert retrieved is not None
|
||||
assert retrieved["request"]["model"] == "llama3.2:3b"
|
||||
assert retrieved["response"]["body"]["content"] == "test response"
|
||||
|
||||
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||
"""Test that recording mode captures and stores responses."""
|
||||
temp_storage_dir = temp_storage_dir / "test_recording_mode"
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
|
||||
# Verify the response was returned correctly
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
client.chat.completions._post.assert_called_once()
|
||||
|
||||
# Verify recording was stored
|
||||
storage = ResponseStorage(temp_storage_dir)
|
||||
dir = storage._get_test_dir()
|
||||
assert dir.exists()
|
||||
|
||||
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||
"""Test that replay mode returns stored responses without making real calls."""
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_mode"
|
||||
# First, record a response
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
client.chat.completions._post.assert_called_once()
|
||||
|
||||
# Now test replay mode - should not call the original method
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
# Verify we got the recorded response
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
|
||||
# Verify the original method was NOT called
|
||||
client.chat.completions._post.assert_not_called()
|
||||
|
||||
async def test_replay_mode_models(self, temp_storage_dir):
|
||||
"""Test that replay mode returns stored responses without making real model listing calls."""
|
||||
|
||||
async def _async_iterator(models):
|
||||
for model in models:
|
||||
yield model
|
||||
|
||||
models = [
|
||||
OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
|
||||
OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
|
||||
]
|
||||
|
||||
expected_ids = {m.id for m in models}
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_mode_models"
|
||||
|
||||
# baseline - mock works without recording
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_called_once()
|
||||
|
||||
# record the call
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_called_once()
|
||||
|
||||
# replay the call
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_not_called()
|
||||
|
||||
async def test_replay_missing_recording(self, temp_storage_dir):
|
||||
"""Test that replay mode fails when no recording is found."""
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
with pytest.raises(RuntimeError, match="No recorded response found"):
|
||||
await client.chat.completions.create(
|
||||
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
|
||||
)
|
||||
|
||||
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
|
||||
"""Test recording and replay of embeddings calls."""
|
||||
|
||||
# baseline - mock works without recording
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
|
||||
response = await client.embeddings.create(
|
||||
model=real_embeddings_response.model,
|
||||
input=["Hello world", "Test embedding"],
|
||||
encoding_format=NOT_GIVEN,
|
||||
)
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
||||
client.embeddings._post.assert_called_once()
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
|
||||
# Record
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
|
||||
|
||||
response = await client.embeddings.create(
|
||||
model=real_embeddings_response.model,
|
||||
input=["Hello world", "Test embedding"],
|
||||
encoding_format=NOT_GIVEN,
|
||||
dimensions=NOT_GIVEN,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
|
||||
assert len(response.data) == 2
|
||||
|
||||
# Replay
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
|
||||
|
||||
response = await client.embeddings.create(
|
||||
model=real_embeddings_response.model,
|
||||
input=["Hello world", "Test embedding"],
|
||||
)
|
||||
|
||||
# Verify we got the recorded response
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
||||
|
||||
# Verify original method was not called
|
||||
client.embeddings._post.assert_not_called()
|
||||
|
||||
async def test_completions_recording(self, temp_storage_dir):
|
||||
real_completions_response = OpenAICompletion(
|
||||
id="test_completion",
|
||||
object="text_completion",
|
||||
created=1234567890,
|
||||
model="llama3.2:3b",
|
||||
choices=[
|
||||
{
|
||||
"text": "Hello! I'm doing well, thank you for asking.",
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_completions_recording"
|
||||
|
||||
# baseline - mock works without recording
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.completions._post = AsyncMock(return_value=real_completions_response)
|
||||
response = await client.completions.create(
|
||||
model=real_completions_response.model,
|
||||
prompt="Hello, how are you?",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
assert response.choices[0].text == real_completions_response.choices[0].text
|
||||
client.completions._post.assert_called_once()
|
||||
|
||||
# Record
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.completions._post = AsyncMock(return_value=real_completions_response)
|
||||
|
||||
response = await client.completions.create(
|
||||
model=real_completions_response.model,
|
||||
prompt="Hello, how are you?",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
|
||||
assert response.choices[0].text == real_completions_response.choices[0].text
|
||||
client.completions._post.assert_called_once()
|
||||
|
||||
# Replay
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.completions._post = AsyncMock(return_value=real_completions_response)
|
||||
response = await client.completions.create(
|
||||
model=real_completions_response.model,
|
||||
prompt="Hello, how are you?",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
assert response.choices[0].text == real_completions_response.choices[0].text
|
||||
client.completions._post.assert_not_called()
|
||||
|
||||
async def test_live_mode(self, real_openai_chat_response):
|
||||
"""Test that live mode passes through to original methods."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_openai_chat_response
|
||||
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
|
||||
# Verify the response was returned
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
50
tests/unit/distribution/test_stack_list_deps.py
Normal file
50
tests/unit/distribution/test_stack_list_deps.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
from llama_stack.cli.stack._list_deps import (
|
||||
run_stack_list_deps_command,
|
||||
)
|
||||
|
||||
|
||||
def test_stack_list_deps_basic():
|
||||
args = argparse.Namespace(
|
||||
config=None,
|
||||
env_name="test-env",
|
||||
providers="inference=remote::ollama",
|
||||
format="deps-only",
|
||||
)
|
||||
|
||||
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
|
||||
run_stack_list_deps_command(args)
|
||||
output = mock_stdout.getvalue()
|
||||
|
||||
# deps-only format should NOT include "uv pip install" or "Dependencies for"
|
||||
assert "uv pip install" not in output
|
||||
assert "Dependencies for" not in output
|
||||
|
||||
# Check that expected dependencies are present
|
||||
assert "ollama" in output
|
||||
assert "aiohttp" in output
|
||||
assert "fastapi" in output
|
||||
|
||||
|
||||
def test_stack_list_deps_with_distro_uv():
|
||||
args = argparse.Namespace(
|
||||
config="starter",
|
||||
env_name=None,
|
||||
providers=None,
|
||||
format="uv",
|
||||
)
|
||||
|
||||
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
|
||||
run_stack_list_deps_command(args)
|
||||
output = mock_stdout.getvalue()
|
||||
|
||||
assert "uv pip install" in output
|
||||
|
|
@ -11,11 +11,12 @@ from llama_stack.apis.common.errors import ResourceNotFoundError
|
|||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
|
||||
from llama_stack.providers.inline.files.localfs import (
|
||||
LocalfsFilesImpl,
|
||||
LocalfsFilesImplConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
|
|
@ -36,8 +37,11 @@ async def files_provider(tmp_path):
|
|||
storage_dir = tmp_path / "files"
|
||||
db_path = tmp_path / "files_metadata.db"
|
||||
|
||||
backend_name = "sql_localfs_test"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
|
||||
config = LocalfsFilesImplConfig(
|
||||
storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix())
|
||||
storage_dir=storage_dir.as_posix(),
|
||||
metadata_store=SqlStoreReference(backend=backend_name, table_name="files_metadata"),
|
||||
)
|
||||
|
||||
provider = LocalfsFilesImpl(config, default_policy())
|
||||
|
|
|
|||
|
|
@ -9,7 +9,16 @@ import random
|
|||
import pytest
|
||||
|
||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -19,12 +28,29 @@ async def temp_prompt_store(tmp_path_factory):
|
|||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"kv_test": SqliteKVStoreConfig(db_path=db_path),
|
||||
"sql_test": SqliteSqlStoreConfig(db_path=str(temp_dir / f"{unique_id}_sql.db")),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(backend="kv_test", namespace="registry"),
|
||||
inference=InferenceStoreReference(backend="sql_test", table_name="inference"),
|
||||
conversations=SqlStoreReference(backend="sql_test", table_name="conversations"),
|
||||
prompts=KVStoreReference(backend="kv_test", namespace="prompts"),
|
||||
),
|
||||
)
|
||||
mock_run_config = StackRunConfig(
|
||||
image_name="test-distribution",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
)
|
||||
config = PromptServiceConfig(run_config=mock_run_config)
|
||||
store = PromptServiceImpl(config, deps={})
|
||||
|
||||
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
|
||||
register_kvstore_backends({"kv_test": storage.backends["kv_test"]})
|
||||
await store.initialize()
|
||||
|
||||
yield store
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from llama_stack.apis.agents import (
|
|||
AgentCreateResponse,
|
||||
)
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
|
||||
|
|
@ -25,6 +26,20 @@ from llama_stack.providers.inline.agents.meta_reference.config import MetaRefere
|
|||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_backends(tmp_path):
|
||||
"""Register KV and SQL store backends for testing."""
|
||||
from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
kv_path = str(tmp_path / "test_kv.db")
|
||||
sql_path = str(tmp_path / "test_sql.db")
|
||||
|
||||
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=kv_path)})
|
||||
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=sql_path)})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_apis():
|
||||
return {
|
||||
|
|
@ -33,20 +48,26 @@ def mock_apis():
|
|||
"safety_api": AsyncMock(spec=Safety),
|
||||
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
|
||||
"tool_groups_api": AsyncMock(spec=ToolGroups),
|
||||
"conversations_api": AsyncMock(spec=Conversations),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config(tmp_path):
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
|
||||
from llama_stack.providers.inline.agents.meta_reference.config import AgentPersistenceConfig
|
||||
|
||||
return MetaReferenceAgentsImplConfig(
|
||||
persistence_store={
|
||||
"type": "sqlite",
|
||||
"db_path": str(tmp_path / "test.db"),
|
||||
},
|
||||
responses_store={
|
||||
"type": "sqlite",
|
||||
"db_path": str(tmp_path / "test.db"),
|
||||
},
|
||||
persistence=AgentPersistenceConfig(
|
||||
agent_state=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="agents",
|
||||
),
|
||||
responses=ResponsesStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="responses",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -59,7 +80,8 @@ async def agents_impl(config, mock_apis):
|
|||
mock_apis["safety_api"],
|
||||
mock_apis["tool_runtime_api"],
|
||||
mock_apis["tool_groups_api"],
|
||||
{},
|
||||
mock_apis["conversations_api"],
|
||||
[],
|
||||
)
|
||||
await impl.initialize()
|
||||
yield impl
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
|
|
@ -20,9 +20,11 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
ListOpenAIResponseInputItem,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponseText,
|
||||
|
|
@ -32,15 +34,16 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import ResponsesStoreConfig
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
|
@ -48,7 +51,7 @@ from llama_stack.providers.utils.responses.responses_store import (
|
|||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
|
||||
|
||||
|
||||
|
|
@ -82,9 +85,28 @@ def mock_vector_io_api():
|
|||
return vector_io_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversations_api():
|
||||
"""Mock conversations API for testing."""
|
||||
mock_api = AsyncMock()
|
||||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
safety_api = AsyncMock()
|
||||
return safety_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_responses_impl(
|
||||
mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
|
||||
mock_inference_api,
|
||||
mock_tool_groups_api,
|
||||
mock_tool_runtime_api,
|
||||
mock_responses_store,
|
||||
mock_vector_io_api,
|
||||
mock_safety_api,
|
||||
mock_conversations_api,
|
||||
):
|
||||
return OpenAIResponsesImpl(
|
||||
inference_api=mock_inference_api,
|
||||
|
|
@ -92,6 +114,8 @@ def openai_responses_impl(
|
|||
tool_runtime_api=mock_tool_runtime_api,
|
||||
responses_store=mock_responses_store,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
safety_api=mock_safety_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -147,18 +171,24 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
chunks = [chunk async for chunk in result]
|
||||
|
||||
mock_inference_api.openai_chat_completion.assert_called_once_with(
|
||||
model=model,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||
response_format=None,
|
||||
tools=None,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=model,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||
response_format=None,
|
||||
tools=None,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
stream_options={
|
||||
"include_usage": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Should have content part events for text streaming
|
||||
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
|
||||
assert len(chunks) >= 4
|
||||
# Expected: response.created, response.in_progress, content_part.added, output_text.delta, content_part.done, response.completed
|
||||
assert len(chunks) >= 5
|
||||
assert chunks[0].type == "response.created"
|
||||
assert any(chunk.type == "response.in_progress" for chunk in chunks)
|
||||
|
||||
# Check for content part events
|
||||
content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
|
||||
|
|
@ -169,6 +199,14 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
|
||||
assert len(text_delta_events) >= 1, "Should have text delta events"
|
||||
|
||||
added_event = content_part_added_events[0]
|
||||
done_event = content_part_done_events[0]
|
||||
assert added_event.content_index == 0
|
||||
assert done_event.content_index == 0
|
||||
assert added_event.output_index == done_event.output_index == 0
|
||||
assert added_event.item_id == done_event.item_id
|
||||
assert added_event.response_id == done_event.response_id
|
||||
|
||||
# Verify final event is completion
|
||||
assert chunks[-1].type == "response.completed"
|
||||
|
||||
|
|
@ -177,6 +215,8 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
assert final_response.model == model
|
||||
assert len(final_response.output) == 1
|
||||
assert isinstance(final_response.output[0], OpenAIResponseMessage)
|
||||
assert final_response.output[0].id == added_event.item_id
|
||||
assert final_response.id == added_event.response_id
|
||||
|
||||
openai_responses_impl.responses_store.store_response_object.assert_called_once()
|
||||
assert final_response.output[0].content[0].text == "Dublin"
|
||||
|
|
@ -228,13 +268,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
|
|||
|
||||
# Verify
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?"
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == "What is the capital of Ireland?"
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
|
||||
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
||||
assert second_call.kwargs["messages"][-1].content == "Dublin"
|
||||
assert second_call.kwargs["temperature"] == 0.1
|
||||
second_params = second_call.args[0]
|
||||
assert second_params.messages[-1].content == "Dublin"
|
||||
assert second_params.temperature == 0.1
|
||||
|
||||
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
|
||||
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
|
||||
|
|
@ -303,36 +345,42 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
|||
chunks = [chunk async for chunk in result]
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 6
|
||||
# Should have: response.created, response.in_progress, output_item.added,
|
||||
# function_call_arguments.delta, function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 7
|
||||
|
||||
event_types = [chunk.type for chunk in chunks]
|
||||
assert event_types == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.delta",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
||||
# Check streaming events
|
||||
assert chunks[1].type == "response.output_item.added"
|
||||
assert chunks[2].type == "response.function_call_arguments.delta"
|
||||
assert chunks[3].type == "response.function_call_arguments.done"
|
||||
assert chunks[4].type == "response.output_item.done"
|
||||
|
||||
# Check response.completed event (should have the tool call)
|
||||
assert chunks[5].type == "response.completed"
|
||||
assert len(chunks[5].response.output) == 1
|
||||
assert chunks[5].response.output[0].type == "function_call"
|
||||
assert chunks[5].response.output[0].name == "get_weather"
|
||||
completed_chunk = chunks[-1]
|
||||
assert completed_chunk.type == "response.completed"
|
||||
assert len(completed_chunk.response.output) == 1
|
||||
assert completed_chunk.response.output[0].type == "function_call"
|
||||
assert completed_chunk.response.output[0].name == "get_weather"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_tool_call_function_arguments_none(openai_responses_impl, mock_inference_api):
|
||||
"""Test creating an OpenAI response with a tool call response that has a function that does not accept arguments, or arguments set to None when they are not mandatory."""
|
||||
# Setup
|
||||
"""Test creating an OpenAI response with tool calls that omit arguments."""
|
||||
|
||||
input_text = "What is the time right now?"
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
|
@ -359,9 +407,22 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
|
|||
object="chat.completion.chunk",
|
||||
)
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
def assert_common_expectations(chunks) -> None:
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
assert len(chunks[0].response.output) == 0
|
||||
completed_chunk = chunks[-1]
|
||||
assert completed_chunk.type == "response.completed"
|
||||
assert len(completed_chunk.response.output) == 1
|
||||
assert completed_chunk.response.output[0].type == "function_call"
|
||||
assert completed_chunk.response.output[0].name == "get_current_time"
|
||||
assert completed_chunk.response.output[0].arguments == "{}"
|
||||
|
||||
# Function does not accept arguments
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
|
|
@ -369,46 +430,23 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
|
|||
temperature=0.1,
|
||||
tools=[
|
||||
OpenAIResponseInputToolFunction(
|
||||
name="get_current_time",
|
||||
description="Get current time for system's timezone",
|
||||
parameters={},
|
||||
name="get_current_time", description="Get current time for system's timezone", parameters={}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 5
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
||||
# Check streaming events
|
||||
assert chunks[1].type == "response.output_item.added"
|
||||
assert chunks[2].type == "response.function_call_arguments.done"
|
||||
assert chunks[3].type == "response.output_item.done"
|
||||
|
||||
# Check response.completed event (should have the tool call with arguments set to "{}")
|
||||
assert chunks[4].type == "response.completed"
|
||||
assert len(chunks[4].response.output) == 1
|
||||
assert chunks[4].response.output[0].type == "function_call"
|
||||
assert chunks[4].response.output[0].name == "get_current_time"
|
||||
assert chunks[4].response.output[0].arguments == "{}"
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
assert [chunk.type for chunk in chunks] == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
assert_common_expectations(chunks)
|
||||
|
||||
# Function accepts optional arguments
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
|
|
@ -418,42 +456,47 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
|
|||
OpenAIResponseInputToolFunction(
|
||||
name="get_current_time",
|
||||
description="Get current time for system's timezone",
|
||||
parameters={
|
||||
"timezone": "string",
|
||||
},
|
||||
parameters={"timezone": "string"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
assert [chunk.type for chunk in chunks] == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
assert_common_expectations(chunks)
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 5
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
||||
# Check streaming events
|
||||
assert chunks[1].type == "response.output_item.added"
|
||||
assert chunks[2].type == "response.function_call_arguments.done"
|
||||
assert chunks[3].type == "response.output_item.done"
|
||||
|
||||
# Check response.completed event (should have the tool call with arguments set to "{}")
|
||||
assert chunks[4].type == "response.completed"
|
||||
assert len(chunks[4].response.output) == 1
|
||||
assert chunks[4].response.output[0].type == "function_call"
|
||||
assert chunks[4].response.output[0].name == "get_current_time"
|
||||
assert chunks[4].response.output[0].arguments == "{}"
|
||||
# Function accepts optional arguments with additional optional fields
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
tools=[
|
||||
OpenAIResponseInputToolFunction(
|
||||
name="get_current_time",
|
||||
description="Get current time for system's timezone",
|
||||
parameters={"timezone": "string", "location": "string"},
|
||||
)
|
||||
],
|
||||
)
|
||||
chunks = [chunk async for chunk in result]
|
||||
assert [chunk.type for chunk in chunks] == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
assert_common_expectations(chunks)
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
|
||||
|
||||
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
|
||||
|
|
@ -485,7 +528,9 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
|
|||
|
||||
# Verify the the correct messages were sent to the inference API i.e.
|
||||
# All of the responses message were convered to the chat completion message objects
|
||||
inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"]
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
params = call_args.args[0]
|
||||
inference_messages = params.messages
|
||||
for i, m in enumerate(input_messages):
|
||||
if isinstance(m.content, str):
|
||||
assert inference_messages[i].content == m.content
|
||||
|
|
@ -653,7 +698,8 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 2
|
||||
|
|
@ -691,7 +737,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 4 # 1 system + 3 input messages
|
||||
|
|
@ -751,7 +798,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 4, sent_messages
|
||||
|
|
@ -767,6 +815,69 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
assert sent_messages[3].content == "Which is the largest?"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_previous_response_instructions(
|
||||
openai_responses_impl, mock_responses_store, mock_inference_api
|
||||
):
|
||||
"""Test prepending instructions and previous response with instructions."""
|
||||
|
||||
input_item_message = OpenAIResponseMessage(
|
||||
id="123",
|
||||
content="Name some towns in Ireland",
|
||||
role="user",
|
||||
)
|
||||
response_output_message = OpenAIResponseMessage(
|
||||
id="123",
|
||||
content="Galway, Longford, Sligo",
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
response = _OpenAIResponseObjectWithInputAndMessages(
|
||||
created_at=1,
|
||||
id="resp_123",
|
||||
model="fake_model",
|
||||
output=[response_output_message],
|
||||
status="completed",
|
||||
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
|
||||
input=[input_item_message],
|
||||
messages=[
|
||||
OpenAIUserMessageParam(content="Name some towns in Ireland"),
|
||||
OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
|
||||
],
|
||||
instructions="You are a helpful assistant.",
|
||||
)
|
||||
mock_responses_store.get_response_object.return_value = response
|
||||
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
instructions = "You are a geography expert. Provide concise answers."
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl.create_openai_response(
|
||||
input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123"
|
||||
)
|
||||
|
||||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
# and that the previous response instructions were not carried over
|
||||
assert len(sent_messages) == 4, sent_messages
|
||||
assert sent_messages[0].role == "system"
|
||||
assert sent_messages[0].content == instructions
|
||||
|
||||
# Check the rest of the messages were converted correctly
|
||||
assert sent_messages[1].role == "user"
|
||||
assert sent_messages[1].content == "Name some towns in Ireland"
|
||||
assert sent_messages[2].role == "assistant"
|
||||
assert sent_messages[2].content == "Galway, Longford, Sligo"
|
||||
assert sent_messages[3].role == "user"
|
||||
assert sent_messages[3].content == "Which is the largest?"
|
||||
|
||||
|
||||
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
|
||||
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
|
||||
# Setup
|
||||
|
|
@ -807,8 +918,10 @@ async def test_responses_store_list_input_items_logic():
|
|||
|
||||
# Create mock store and response store
|
||||
mock_sql_store = AsyncMock()
|
||||
backend_name = "sql_responses_test"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path="mock_db_path")})
|
||||
responses_store = ResponsesStore(
|
||||
ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy()
|
||||
ResponsesStoreReference(backend=backend_name, table_name="responses"), policy=default_policy()
|
||||
)
|
||||
responses_store.sql_store = mock_sql_store
|
||||
|
||||
|
|
@ -953,6 +1066,58 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
|
|||
assert result.status == "completed"
|
||||
|
||||
|
||||
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
|
||||
async def test_reuse_mcp_tool_list(
|
||||
mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
|
||||
):
|
||||
"""Test that mcp_list_tools can be reused where appropriate."""
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
mock_list_mcp_tools.return_value = ListToolDefsResponse(
|
||||
data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})]
|
||||
)
|
||||
|
||||
res1 = await openai_responses_impl.create_openai_response(
|
||||
input="What is 2+2?",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
store=True,
|
||||
tools=[
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
],
|
||||
)
|
||||
args = mock_responses_store.store_response_object.call_args
|
||||
data = args.kwargs["response_object"].model_dump()
|
||||
data["input"] = [input_item.model_dump() for input_item in args.kwargs["input"]]
|
||||
data["messages"] = [msg.model_dump() for msg in args.kwargs["messages"]]
|
||||
stored = _OpenAIResponseObjectWithInputAndMessages(**data)
|
||||
mock_responses_store.get_response_object.return_value = stored
|
||||
|
||||
res2 = await openai_responses_impl.create_openai_response(
|
||||
previous_response_id=res1.id,
|
||||
input="Now what is 3+3?",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
store=True,
|
||||
tools=[
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
],
|
||||
)
|
||||
assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
|
||||
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
||||
second_params = second_call.args[0]
|
||||
tools_seen = second_params.tools
|
||||
assert len(tools_seen) == 1
|
||||
assert tools_seen[0]["function"]["name"] == "test_tool"
|
||||
assert tools_seen[0]["function"]["description"] == "a test tool"
|
||||
|
||||
assert mock_list_mcp_tools.call_count == 1
|
||||
listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"]
|
||||
assert len(listings) == 1
|
||||
assert listings[0].server_label == "alabel"
|
||||
assert len(listings[0].tools) == 1
|
||||
assert listings[0].tools[0].name == "test_tool"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_format, response_format",
|
||||
[
|
||||
|
|
@ -987,8 +1152,9 @@ async def test_create_openai_response_with_text_format(
|
|||
|
||||
# Verify
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["response_format"] == response_format
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.response_format == response_format
|
||||
|
||||
|
||||
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
|
||||
|
|
@ -1004,3 +1170,75 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
|
|||
model=model,
|
||||
text=OpenAIResponseText(format={"type": "invalid"}),
|
||||
)
|
||||
|
||||
|
||||
async def test_create_openai_response_with_output_types_as_input(
|
||||
openai_responses_impl, mock_inference_api, mock_responses_store
|
||||
):
|
||||
"""Test that response outputs can be used as inputs in multi-turn conversations.
|
||||
|
||||
Before adding OpenAIResponseOutput types to OpenAIResponseInput,
|
||||
creating a _OpenAIResponseObjectWithInputAndMessages with some output types
|
||||
in the input field would fail with a Pydantic ValidationError.
|
||||
|
||||
This test simulates storing a response where the input contains output message
|
||||
types (MCP calls, function calls), which happens in multi-turn conversations.
|
||||
"""
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
# Mock the inference response
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
|
||||
# Create a response with store=True to trigger the storage path
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input="What's the weather?",
|
||||
model=model,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
# Consume the stream
|
||||
_ = [chunk async for chunk in result]
|
||||
|
||||
# Verify store was called
|
||||
assert mock_responses_store.store_response_object.called
|
||||
|
||||
# Get the stored data
|
||||
store_call_args = mock_responses_store.store_response_object.call_args
|
||||
stored_response = store_call_args.kwargs["response_object"]
|
||||
|
||||
# Now simulate a multi-turn conversation where outputs become inputs
|
||||
input_with_output_types = [
|
||||
OpenAIResponseMessage(role="user", content="What's the weather?", name=None),
|
||||
# These output types need to be valid OpenAIResponseInput
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_123",
|
||||
name="get_weather",
|
||||
arguments='{"city": "Tokyo"}',
|
||||
type="function_call",
|
||||
),
|
||||
OpenAIResponseOutputMessageMCPCall(
|
||||
id="mcp_456",
|
||||
type="mcp_call",
|
||||
server_label="weather_server",
|
||||
name="get_temperature",
|
||||
arguments='{"location": "Tokyo"}',
|
||||
output="25°C",
|
||||
),
|
||||
]
|
||||
|
||||
# This simulates storing a response in a multi-turn conversation
|
||||
# where previous outputs are included in the input.
|
||||
stored_with_outputs = _OpenAIResponseObjectWithInputAndMessages(
|
||||
id=stored_response.id,
|
||||
created_at=stored_response.created_at,
|
||||
model=stored_response.model,
|
||||
status=stored_response.status,
|
||||
output=stored_response.output,
|
||||
input=input_with_output_types, # This will trigger Pydantic validation
|
||||
messages=None,
|
||||
)
|
||||
|
||||
assert stored_with_outputs.input == input_with_output_types
|
||||
assert len(stored_with_outputs.input) == 3
|
||||
|
|
|
|||
|
|
@ -0,0 +1,249 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseOutputItemDone,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
)
|
||||
from llama_stack.apis.common.errors import (
|
||||
ConversationNotFoundError,
|
||||
InvalidConversationIdError,
|
||||
)
|
||||
from llama_stack.apis.conversations.conversations import (
|
||||
ConversationItemList,
|
||||
)
|
||||
|
||||
# Import existing fixtures from the main responses test file
|
||||
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]
|
||||
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def responses_impl_with_conversations(
|
||||
mock_inference_api,
|
||||
mock_tool_groups_api,
|
||||
mock_tool_runtime_api,
|
||||
mock_responses_store,
|
||||
mock_vector_io_api,
|
||||
mock_conversations_api,
|
||||
mock_safety_api,
|
||||
):
|
||||
"""Create OpenAIResponsesImpl instance with conversations API."""
|
||||
return OpenAIResponsesImpl(
|
||||
inference_api=mock_inference_api,
|
||||
tool_groups_api=mock_tool_groups_api,
|
||||
tool_runtime_api=mock_tool_runtime_api,
|
||||
responses_store=mock_responses_store,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
safety_api=mock_safety_api,
|
||||
)
|
||||
|
||||
|
||||
class TestConversationValidation:
|
||||
"""Test conversation ID validation logic."""
|
||||
|
||||
async def test_nonexistent_conversation_raises_error(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test that ConversationNotFoundError is raised for non-existent conversation."""
|
||||
conv_id = "conv_nonexistent"
|
||||
|
||||
# Mock conversation not found
|
||||
mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")
|
||||
|
||||
with pytest.raises(ConversationNotFoundError):
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="Hello", model="test-model", conversation=conv_id, stream=False
|
||||
)
|
||||
|
||||
|
||||
class TestMessageSyncing:
|
||||
"""Test message syncing to conversations."""
|
||||
|
||||
async def test_sync_response_to_conversation_simple(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test syncing simple response to conversation."""
|
||||
conv_id = "conv_test123"
|
||||
input_text = "What are the 5 Ds of dodgeball?"
|
||||
|
||||
# Output items (what the model generated)
|
||||
output_items = [
|
||||
OpenAIResponseMessage(
|
||||
id="msg_response",
|
||||
content=[
|
||||
OpenAIResponseOutputMessageContentOutputText(
|
||||
text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
]
|
||||
|
||||
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, output_items)
|
||||
|
||||
# should call add_items with user input and assistant response
|
||||
mock_conversations_api.add_items.assert_called_once()
|
||||
call_args = mock_conversations_api.add_items.call_args
|
||||
|
||||
assert call_args[0][0] == conv_id # conversation_id
|
||||
items = call_args[0][1] # conversation_items
|
||||
|
||||
assert len(items) == 2
|
||||
# User message
|
||||
assert items[0].type == "message"
|
||||
assert items[0].role == "user"
|
||||
assert items[0].content[0].type == "input_text"
|
||||
assert items[0].content[0].text == input_text
|
||||
|
||||
# Assistant message
|
||||
assert items[1].type == "message"
|
||||
assert items[1].role == "assistant"
|
||||
|
||||
async def test_sync_response_to_conversation_api_error(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
mock_conversations_api.add_items.side_effect = Exception("API Error")
|
||||
output_items = []
|
||||
|
||||
# matching the behavior of OpenAI here
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await responses_impl_with_conversations._sync_response_to_conversation(
|
||||
"conv_test123", "Hello", output_items
|
||||
)
|
||||
|
||||
async def test_sync_with_list_input(self, responses_impl_with_conversations, mock_conversations_api):
|
||||
"""Test syncing with list of input messages."""
|
||||
conv_id = "conv_test123"
|
||||
input_messages = [
|
||||
OpenAIResponseMessage(role="user", content=[{"type": "input_text", "text": "First message"}]),
|
||||
]
|
||||
output_items = [
|
||||
OpenAIResponseMessage(
|
||||
id="msg_response",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text="Response", type="output_text")],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
]
|
||||
|
||||
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_messages, output_items)
|
||||
|
||||
mock_conversations_api.add_items.assert_called_once()
|
||||
call_args = mock_conversations_api.add_items.call_args
|
||||
|
||||
items = call_args[0][1]
|
||||
# Should have input message + output message
|
||||
assert len(items) == 2
|
||||
|
||||
|
||||
class TestIntegrationWorkflow:
|
||||
"""Integration tests for the full conversation workflow."""
|
||||
|
||||
async def test_create_response_with_valid_conversation(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test creating a response with a valid conversation parameter."""
|
||||
mock_conversations_api.list_items.return_value = ConversationItemList(
|
||||
data=[], first_id=None, has_more=False, last_id=None, object="list"
|
||||
)
|
||||
|
||||
async def mock_streaming_response(*args, **kwargs):
|
||||
message_item = OpenAIResponseMessage(
|
||||
id="msg_response",
|
||||
content=[
|
||||
OpenAIResponseOutputMessageContentOutputText(
|
||||
text="Test response", type="output_text", annotations=[]
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
|
||||
# Emit output_item.done event first (needed for conversation sync)
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id="resp_test123",
|
||||
item=message_item,
|
||||
output_index=0,
|
||||
sequence_number=1,
|
||||
type="response.output_item.done",
|
||||
)
|
||||
|
||||
# Then emit response.completed
|
||||
mock_response = OpenAIResponseObject(
|
||||
id="resp_test123",
|
||||
created_at=1234567890,
|
||||
model="test-model",
|
||||
object="response",
|
||||
output=[message_item],
|
||||
status="completed",
|
||||
)
|
||||
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")
|
||||
|
||||
responses_impl_with_conversations._create_streaming_response = mock_streaming_response
|
||||
|
||||
input_text = "Hello, how are you?"
|
||||
conversation_id = "conv_test123"
|
||||
|
||||
response = await responses_impl_with_conversations.create_openai_response(
|
||||
input=input_text, model="test-model", conversation=conversation_id, stream=False
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.id == "resp_test123"
|
||||
|
||||
# Note: conversation sync happens inside _create_streaming_response,
|
||||
# which we're mocking here, so we can't test it in this unit test.
|
||||
# The sync logic is tested separately in TestMessageSyncing.
|
||||
|
||||
async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
|
||||
"""Test creating a response with an invalid conversation ID."""
|
||||
with pytest.raises(InvalidConversationIdError) as exc_info:
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="Hello", model="test-model", conversation="invalid_id", stream=False
|
||||
)
|
||||
|
||||
assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)
|
||||
|
||||
async def test_create_response_with_nonexistent_conversation(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test creating a response with a non-existent conversation."""
|
||||
mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")
|
||||
|
||||
with pytest.raises(ConversationNotFoundError) as exc_info:
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
|
||||
)
|
||||
|
||||
assert "not found" in str(exc_info.value)
|
||||
|
||||
async def test_conversation_and_previous_response_id(
|
||||
self, responses_impl_with_conversations, mock_conversations_api, mock_responses_store
|
||||
):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="test", model="test", conversation="conv_123", previous_response_id="resp_123"
|
||||
)
|
||||
|
||||
assert "Mutually exclusive parameters" in str(exc_info.value)
|
||||
assert "previous_response_id" in str(exc_info.value)
|
||||
assert "conversation" in str(exc_info.value)
|
||||
|
|
@ -0,0 +1,183 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseToolMCP,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
|
||||
|
||||
|
||||
class TestToolContext:
|
||||
def test_no_tools(self):
|
||||
tools = []
|
||||
context = ToolContext(tools)
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 0
|
||||
assert len(context.previous_tools) == 0
|
||||
assert len(context.previous_tool_listings) == 0
|
||||
|
||||
def test_no_previous_tools(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
|
||||
OpenAIResponseInputToolMCP(server_label="label", server_url="url"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 2
|
||||
assert len(context.previous_tools) == 0
|
||||
assert len(context.previous_tool_listings) == 0
|
||||
|
||||
def test_reusable_server(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
|
||||
)
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
|
||||
OpenAIResponseToolMCP(server_label="alabel"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 1
|
||||
assert context.tools_to_process[0].type == "file_search"
|
||||
assert len(context.previous_tools) == 1
|
||||
assert context.previous_tools["test_tool"].server_label == "alabel"
|
||||
assert context.previous_tools["test_tool"].server_url == "aurl"
|
||||
assert len(context.previous_tool_listings) == 1
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "alabel"
|
||||
|
||||
def test_multiple_reusable_servers(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
|
||||
),
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test2",
|
||||
server_label="anotherlabel",
|
||||
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
|
||||
),
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 2
|
||||
assert context.tools_to_process[0].type == "function"
|
||||
assert context.tools_to_process[1].type == "web_search"
|
||||
assert len(context.previous_tools) == 2
|
||||
assert context.previous_tools["test_tool"].server_label == "alabel"
|
||||
assert context.previous_tools["test_tool"].server_url == "aurl"
|
||||
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
|
||||
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
|
||||
assert len(context.previous_tool_listings) == 2
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "alabel"
|
||||
assert len(context.previous_tool_listings[1].tools) == 1
|
||||
assert context.previous_tool_listings[1].server_label == "anotherlabel"
|
||||
|
||||
def test_multiple_servers_only_one_reusable(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test2",
|
||||
server_label="anotherlabel",
|
||||
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
|
||||
)
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 3
|
||||
assert context.tools_to_process[0].type == "function"
|
||||
assert context.tools_to_process[1].type == "web_search"
|
||||
assert context.tools_to_process[2].type == "mcp"
|
||||
assert len(context.previous_tools) == 1
|
||||
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
|
||||
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
|
||||
assert len(context.previous_tool_listings) == 1
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "anotherlabel"
|
||||
|
||||
def test_mismatched_allowed_tools(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl", allowed_tools=["test_tool_2"]),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool_1", input_schema={})]
|
||||
),
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test2",
|
||||
server_label="anotherlabel",
|
||||
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
|
||||
),
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 3
|
||||
assert context.tools_to_process[0].type == "function"
|
||||
assert context.tools_to_process[1].type == "web_search"
|
||||
assert context.tools_to_process[2].type == "mcp"
|
||||
assert len(context.previous_tools) == 1
|
||||
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
|
||||
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
|
||||
assert len(context.previous_tool_listings) == 1
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "anotherlabel"
|
||||
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
extract_guardrail_ids,
|
||||
run_guardrails,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_apis():
|
||||
"""Create mock APIs for testing."""
|
||||
return {
|
||||
"inference_api": AsyncMock(),
|
||||
"tool_groups_api": AsyncMock(),
|
||||
"tool_runtime_api": AsyncMock(),
|
||||
"responses_store": AsyncMock(),
|
||||
"vector_io_api": AsyncMock(),
|
||||
"conversations_api": AsyncMock(),
|
||||
"safety_api": AsyncMock(),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def responses_impl(mock_apis):
|
||||
"""Create OpenAIResponsesImpl instance with mocked dependencies."""
|
||||
return OpenAIResponsesImpl(**mock_apis)
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_from_strings(responses_impl):
|
||||
"""Test extraction from simple string guardrail IDs."""
|
||||
guardrails = ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_from_objects(responses_impl):
|
||||
"""Test extraction from ResponseGuardrailSpec objects."""
|
||||
guardrails = [
|
||||
ResponseGuardrailSpec(type="llama-guard"),
|
||||
ResponseGuardrailSpec(type="content-filter"),
|
||||
]
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter"]
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_mixed_formats(responses_impl):
|
||||
"""Test extraction from mixed string and object formats."""
|
||||
guardrails = [
|
||||
"llama-guard",
|
||||
ResponseGuardrailSpec(type="content-filter"),
|
||||
"nsfw-detector",
|
||||
]
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_none_input(responses_impl):
|
||||
"""Test extraction with None input."""
|
||||
result = extract_guardrail_ids(None)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_empty_list(responses_impl):
|
||||
"""Test extraction with empty list."""
|
||||
result = extract_guardrail_ids([])
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_unknown_format(responses_impl):
|
||||
"""Test extraction with unknown guardrail format raises ValueError."""
|
||||
# Create an object that's neither string nor ResponseGuardrailSpec
|
||||
unknown_object = {"invalid": "format"} # Plain dict, not ResponseGuardrailSpec
|
||||
guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
|
||||
with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
|
||||
extract_guardrail_ids(guardrails)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
"""Create mock safety API for guardrails testing."""
|
||||
safety_api = AsyncMock()
|
||||
# Mock the routing table and shields list for guardrails lookup
|
||||
safety_api.routing_table = AsyncMock()
|
||||
shield = AsyncMock()
|
||||
shield.identifier = "llama-guard"
|
||||
shield.provider_resource_id = "llama-guard-model"
|
||||
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
|
||||
return safety_api
|
||||
|
||||
|
||||
async def test_run_guardrails_no_violation(mock_safety_api):
|
||||
"""Test guardrails validation with no violations."""
|
||||
text = "Hello world"
|
||||
guardrail_ids = ["llama-guard"]
|
||||
|
||||
# Mock moderation to return non-flagged content
|
||||
unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
|
||||
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
|
||||
mock_safety_api.run_moderation.return_value = mock_moderation_object
|
||||
|
||||
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
|
||||
|
||||
assert result is None
|
||||
# Verify run_moderation was called with the correct model
|
||||
mock_safety_api.run_moderation.assert_called_once()
|
||||
call_args = mock_safety_api.run_moderation.call_args
|
||||
assert call_args[1]["model"] == "llama-guard-model"
|
||||
|
||||
|
||||
async def test_run_guardrails_with_violation(mock_safety_api):
|
||||
"""Test guardrails validation with safety violation."""
|
||||
text = "Harmful content"
|
||||
guardrail_ids = ["llama-guard"]
|
||||
|
||||
# Mock moderation to return flagged content
|
||||
flagged_result = ModerationObjectResults(
|
||||
flagged=True,
|
||||
categories={"violence": True},
|
||||
user_message="Content flagged by moderation",
|
||||
metadata={"violation_type": ["S1"]},
|
||||
)
|
||||
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
|
||||
mock_safety_api.run_moderation.return_value = mock_moderation_object
|
||||
|
||||
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
|
||||
|
||||
assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"
|
||||
|
||||
|
||||
async def test_run_guardrails_empty_inputs(mock_safety_api):
|
||||
"""Test guardrails validation with empty inputs."""
|
||||
# Test empty guardrail_ids
|
||||
result = await run_guardrails(mock_safety_api, "test", [])
|
||||
assert result is None
|
||||
|
||||
# Test empty text
|
||||
result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
# Test both empty
|
||||
result = await run_guardrails(mock_safety_api, "", [])
|
||||
assert result is None
|
||||
|
|
@ -12,10 +12,10 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -23,8 +23,10 @@ async def provider():
|
|||
"""Create a test provider instance with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_batches.db"
|
||||
backend_name = "kv_batches_test"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||
register_kvstore_backends({backend_name: kvstore_config})
|
||||
config = ReferenceBatchesImplConfig(kvstore=KVStoreReference(backend=backend_name, namespace="batches"))
|
||||
|
||||
# Create kvstore and mock APIs
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
|
|
|
|||
|
|
@ -213,7 +213,6 @@ class TestReferenceBatchesImpl:
|
|||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
|
|
@ -765,3 +764,12 @@ class TestReferenceBatchesImpl:
|
|||
await asyncio.sleep(0.042) # let tasks start
|
||||
|
||||
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
|
||||
|
||||
async def test_create_batch_embeddings_endpoint(self, provider):
|
||||
"""Test that batch creation succeeds with embeddings endpoint."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/embeddings",
|
||||
completion_window="24h",
|
||||
)
|
||||
assert batch.endpoint == "/v1/embeddings"
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ import boto3
|
|||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
|
||||
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
|
|
@ -38,11 +39,13 @@ def sample_text_file2():
|
|||
def s3_config(tmp_path):
|
||||
db_path = tmp_path / "s3_files_metadata.db"
|
||||
|
||||
backend_name = f"sql_s3_{tmp_path.name}"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
|
||||
return S3FilesImplConfig(
|
||||
bucket_name=f"test-bucket-{tmp_path.name}",
|
||||
region="not-a-region",
|
||||
auto_create_bucket=True,
|
||||
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
|
||||
metadata_store=SqlStoreReference(backend=backend_name, table_name="s3_files_metadata"),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,16 +15,16 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
|
|||
|
||||
|
||||
# Test fixtures and helper classes
|
||||
class TestConfig(BaseModel):
|
||||
class FakeConfig(BaseModel):
|
||||
api_key: str | None = Field(default=None)
|
||||
|
||||
|
||||
class TestProviderDataValidator(BaseModel):
|
||||
class FakeProviderDataValidator(BaseModel):
|
||||
test_api_key: str | None = Field(default=None)
|
||||
|
||||
|
||||
class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: TestConfig):
|
||||
class FakeLiteLLMAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: FakeConfig):
|
||||
super().__init__(
|
||||
litellm_provider_name="test",
|
||||
api_key_from_config=config.api_key,
|
||||
|
|
@ -36,11 +36,11 @@ class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
|
|||
@pytest.fixture
|
||||
def adapter_with_config_key():
|
||||
"""Fixture to create adapter with API key in config"""
|
||||
config = TestConfig(api_key="config-api-key")
|
||||
adapter = TestLiteLLMAdapter(config)
|
||||
config = FakeConfig(api_key="config-api-key")
|
||||
adapter = FakeLiteLLMAdapter(config)
|
||||
adapter.__provider_spec__ = MagicMock()
|
||||
adapter.__provider_spec__.provider_data_validator = (
|
||||
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
|
||||
"tests.unit.providers.inference.test_litellm_openai_mixin.FakeProviderDataValidator"
|
||||
)
|
||||
return adapter
|
||||
|
||||
|
|
@ -48,11 +48,11 @@ def adapter_with_config_key():
|
|||
@pytest.fixture
|
||||
def adapter_without_config_key():
|
||||
"""Fixture to create adapter without API key in config"""
|
||||
config = TestConfig(api_key=None)
|
||||
adapter = TestLiteLLMAdapter(config)
|
||||
config = FakeConfig(api_key=None)
|
||||
adapter = FakeLiteLLMAdapter(config)
|
||||
adapter.__provider_spec__ = MagicMock()
|
||||
adapter.__provider_spec__.provider_data_validator = (
|
||||
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
|
||||
"tests.unit.providers.inference.test_litellm_openai_mixin.FakeProviderDataValidator"
|
||||
)
|
||||
return adapter
|
||||
|
||||
|
|
|
|||
|
|
@ -13,10 +13,16 @@ import pytest
|
|||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChoice,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionChoice,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
ToolChoice,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.routers.inference import InferenceRouter
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
|
||||
|
|
@ -56,13 +62,14 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
|||
mock_client_property.return_value = mock_client
|
||||
|
||||
# No tools but auto tool choice
|
||||
await vllm_inference_adapter.openai_chat_completion(
|
||||
"mock-model",
|
||||
[],
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
stream=False,
|
||||
tools=None,
|
||||
tool_choice=ToolChoice.auto.value,
|
||||
)
|
||||
await vllm_inference_adapter.openai_chat_completion(params)
|
||||
mock_client.chat.completions.create.assert_called()
|
||||
call_args = mock_client.chat.completions.create.call_args
|
||||
# Ensure tool_choice gets converted to none for older vLLM versions
|
||||
|
|
@ -171,9 +178,12 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
|
|||
)
|
||||
|
||||
async def do_inference():
|
||||
await vllm_inference_adapter.openai_chat_completion(
|
||||
"mock-model", messages=["one fish", "two fish"], stream=False
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
messages=[{"role": "user", "content": "one fish two fish"}],
|
||||
stream=False,
|
||||
)
|
||||
await vllm_inference_adapter.openai_chat_completion(params)
|
||||
|
||||
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
|
||||
mock_client = MagicMock()
|
||||
|
|
@ -186,3 +196,148 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
|
|||
|
||||
assert mock_create_client.call_count == 4 # no cheating
|
||||
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
|
||||
|
||||
|
||||
async def test_vllm_completion_extra_body():
|
||||
"""
|
||||
Test that vLLM-specific guided_choice and prompt_logprobs parameters are correctly forwarded
|
||||
via extra_body to the underlying OpenAI client through the InferenceRouter.
|
||||
"""
|
||||
# Set up the vLLM adapter
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
vllm_adapter = VLLMInferenceAdapter(config=config)
|
||||
vllm_adapter.__provider_id__ = "vllm"
|
||||
await vllm_adapter.initialize()
|
||||
|
||||
# Create a mock model store
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
|
||||
mock_model_store.get_model.return_value = mock_model
|
||||
mock_model_store.has_model.return_value = True
|
||||
|
||||
# Create a mock dist_registry
|
||||
mock_dist_registry = MagicMock()
|
||||
mock_dist_registry.get = AsyncMock(return_value=mock_model)
|
||||
mock_dist_registry.set = AsyncMock()
|
||||
|
||||
# Set up the routing table
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"vllm": vllm_adapter},
|
||||
dist_registry=mock_dist_registry,
|
||||
policy=[],
|
||||
)
|
||||
# Inject the model store into the adapter
|
||||
vllm_adapter.model_store = routing_table
|
||||
|
||||
# Create the InferenceRouter
|
||||
router = InferenceRouter(routing_table=routing_table)
|
||||
|
||||
# Patch the OpenAI client
|
||||
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
|
||||
mock_client = MagicMock()
|
||||
mock_client.completions.create = AsyncMock(
|
||||
return_value=OpenAICompletion(
|
||||
id="cmpl-abc123",
|
||||
created=1,
|
||||
model="mock-model",
|
||||
choices=[
|
||||
OpenAICompletionChoice(
|
||||
text="joy",
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
mock_client_property.return_value = mock_client
|
||||
|
||||
# Test with guided_choice and prompt_logprobs as extra fields
|
||||
params = OpenAICompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
prompt="I am feeling happy",
|
||||
stream=False,
|
||||
guided_choice=["joy", "sadness"],
|
||||
prompt_logprobs=5,
|
||||
)
|
||||
await router.openai_completion(params)
|
||||
|
||||
# Verify that the client was called with extra_body containing both parameters
|
||||
mock_client.completions.create.assert_called_once()
|
||||
call_kwargs = mock_client.completions.create.call_args.kwargs
|
||||
assert "extra_body" in call_kwargs
|
||||
assert "guided_choice" in call_kwargs["extra_body"]
|
||||
assert call_kwargs["extra_body"]["guided_choice"] == ["joy", "sadness"]
|
||||
assert "prompt_logprobs" in call_kwargs["extra_body"]
|
||||
assert call_kwargs["extra_body"]["prompt_logprobs"] == 5
|
||||
|
||||
|
||||
async def test_vllm_chat_completion_extra_body():
|
||||
"""
|
||||
Test that vLLM-specific parameters (e.g., chat_template_kwargs) are correctly forwarded
|
||||
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
|
||||
"""
|
||||
# Set up the vLLM adapter
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
vllm_adapter = VLLMInferenceAdapter(config=config)
|
||||
vllm_adapter.__provider_id__ = "vllm"
|
||||
await vllm_adapter.initialize()
|
||||
|
||||
# Create a mock model store
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
|
||||
mock_model_store.get_model.return_value = mock_model
|
||||
mock_model_store.has_model.return_value = True
|
||||
|
||||
# Create a mock dist_registry
|
||||
mock_dist_registry = MagicMock()
|
||||
mock_dist_registry.get = AsyncMock(return_value=mock_model)
|
||||
mock_dist_registry.set = AsyncMock()
|
||||
|
||||
# Set up the routing table
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"vllm": vllm_adapter},
|
||||
dist_registry=mock_dist_registry,
|
||||
policy=[],
|
||||
)
|
||||
# Inject the model store into the adapter
|
||||
vllm_adapter.model_store = routing_table
|
||||
|
||||
# Create the InferenceRouter
|
||||
router = InferenceRouter(routing_table=routing_table)
|
||||
|
||||
# Patch the OpenAI client
|
||||
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=OpenAIChatCompletion(
|
||||
id="chatcmpl-abc123",
|
||||
created=1,
|
||||
model="mock-model",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(
|
||||
content="test response",
|
||||
),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
mock_client_property.return_value = mock_client
|
||||
|
||||
# Test with chat_template_kwargs as extra field
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
stream=False,
|
||||
chat_template_kwargs={"thinking": True},
|
||||
)
|
||||
await router.openai_chat_completion(params)
|
||||
|
||||
# Verify that the client was called with extra_body containing chat_template_kwargs
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_client.chat.completions.create.call_args.kwargs
|
||||
assert "extra_body" in call_kwargs
|
||||
assert "chat_template_kwargs" in call_kwargs["extra_body"]
|
||||
assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,45 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
|
||||
convert_tooldef_to_chat_tool,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.types import ChatCompletionContext
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
safety_api = AsyncMock()
|
||||
# Mock the routing table and shields list for guardrails lookup
|
||||
safety_api.routing_table = AsyncMock()
|
||||
shield = AsyncMock()
|
||||
shield.identifier = "llama-guard"
|
||||
shield.provider_resource_id = "llama-guard-model"
|
||||
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
|
||||
# Mock run_moderation to return non-flagged result by default
|
||||
safety_api.run_moderation.return_value = AsyncMock(flagged=False)
|
||||
return safety_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inference_api():
|
||||
inference_api = AsyncMock()
|
||||
return inference_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
context = AsyncMock(spec=ChatCompletionContext)
|
||||
# Add required attributes that StreamingResponseOrchestrator expects
|
||||
context.tool_context = AsyncMock()
|
||||
context.tool_context.previous_tools = {}
|
||||
context.messages = []
|
||||
return context
|
||||
|
||||
|
||||
def test_convert_tooldef_to_chat_tool_preserves_items_field():
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
|
|||
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
|
||||
|
||||
|
||||
class TestNVIDIASafetyAdapter(NVIDIASafetyAdapter):
|
||||
class FakeNVIDIASafetyAdapter(NVIDIASafetyAdapter):
|
||||
"""Test implementation that provides the required shield_store."""
|
||||
|
||||
def __init__(self, config: NVIDIASafetyConfig, shield_store):
|
||||
|
|
@ -41,7 +41,7 @@ def nvidia_adapter():
|
|||
shield_store = AsyncMock()
|
||||
shield_store.get_shield = AsyncMock()
|
||||
|
||||
adapter = TestNVIDIASafetyAdapter(config=config, shield_store=shield_store)
|
||||
adapter = FakeNVIDIASafetyAdapter(config=config, shield_store=shield_store)
|
||||
|
||||
return adapter
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
|||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
|
||||
from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
|
@ -23,10 +23,10 @@ class OpenAIMixinImpl(OpenAIMixin):
|
|||
__provider_id__: str = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
return "test-api-key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
return "http://test-base-url"
|
||||
|
||||
|
||||
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
||||
|
|
@ -38,6 +38,28 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
|||
}
|
||||
|
||||
|
||||
class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl):
|
||||
"""Test implementation that uses construct_model_from_identifier to add rerank models"""
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
}
|
||||
|
||||
# Adds rerank models via construct_model_from_identifier
|
||||
rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"}
|
||||
|
||||
def construct_model_from_identifier(self, identifier: str) -> Model:
|
||||
if identifier in self.rerank_model_ids:
|
||||
return Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=identifier,
|
||||
identifier=identifier,
|
||||
model_type=ModelType.rerank,
|
||||
)
|
||||
return super().construct_model_from_identifier(identifier)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin():
|
||||
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
||||
|
|
@ -62,6 +84,13 @@ def mixin_with_embeddings():
|
|||
return OpenAIMixinWithEmbeddingsImpl(config=config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_custom_model_construction():
|
||||
"""Create a test instance using custom construct_model_from_identifier"""
|
||||
config = RemoteInferenceProviderConfig()
|
||||
return OpenAIMixinWithCustomModelConstruction(config=config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_models():
|
||||
"""Create multiple mock OpenAI model objects"""
|
||||
|
|
@ -113,6 +142,19 @@ def mock_client_context():
|
|||
return _mock_client_context
|
||||
|
||||
|
||||
def _assert_models_match_expected(actual_models, expected_models):
|
||||
"""Verify the models match expected attributes.
|
||||
|
||||
Args:
|
||||
actual_models: List of models to verify
|
||||
expected_models: Mapping of model identifier to expected attribute values
|
||||
"""
|
||||
for identifier, expected_attrs in expected_models.items():
|
||||
model = next(m for m in actual_models if m.identifier == identifier)
|
||||
for attr_name, expected_value in expected_attrs.items():
|
||||
assert getattr(model, attr_name) == expected_value
|
||||
|
||||
|
||||
class TestOpenAIMixinListModels:
|
||||
"""Test cases for the list_models method"""
|
||||
|
||||
|
|
@ -205,7 +247,7 @@ class TestOpenAIMixinCheckModelAvailability:
|
|||
assert await mixin.check_model_availability("pre-registered-model")
|
||||
# Should not call the provider's list_models since model was found in store
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
mock_model_store.has_model.assert_called_once_with("pre-registered-model")
|
||||
mock_model_store.has_model.assert_called_once_with("test-provider/pre-registered-model")
|
||||
|
||||
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
|
||||
self, mixin, mock_client_with_models, mock_client_context
|
||||
|
|
@ -222,7 +264,7 @@ class TestOpenAIMixinCheckModelAvailability:
|
|||
assert await mixin.check_model_availability("some-mock-model-id")
|
||||
# Should call the provider's list_models since model was not found in store
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
|
||||
mock_model_store.has_model.assert_called_once_with("test-provider/some-mock-model-id")
|
||||
|
||||
|
||||
class TestOpenAIMixinCacheBehavior:
|
||||
|
|
@ -271,7 +313,8 @@ class TestOpenAIMixinImagePreprocessing:
|
|||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
mock_localize.return_value = (b"fake_image_data", "jpeg")
|
||||
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
|
||||
await mixin.openai_chat_completion(params)
|
||||
|
||||
mock_localize.assert_called_once_with("http://example.com/image.jpg")
|
||||
|
||||
|
|
@ -303,7 +346,8 @@ class TestOpenAIMixinImagePreprocessing:
|
|||
|
||||
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
|
||||
await mixin.openai_chat_completion(params)
|
||||
|
||||
mock_localize.assert_not_called()
|
||||
|
||||
|
|
@ -340,21 +384,71 @@ class TestOpenAIMixinEmbeddingModelMetadata:
|
|||
assert result is not None
|
||||
assert len(result) == 2
|
||||
|
||||
# Find the models in the result
|
||||
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
|
||||
llm_model = next(m for m in result if m.identifier == "gpt-4")
|
||||
expected_models = {
|
||||
"text-embedding-3-small": {
|
||||
"model_type": ModelType.embedding,
|
||||
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "text-embedding-3-small",
|
||||
},
|
||||
"gpt-4": {
|
||||
"model_type": ModelType.llm,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
# Check embedding model
|
||||
assert embedding_model.model_type == ModelType.embedding
|
||||
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
|
||||
assert embedding_model.provider_id == "test-provider"
|
||||
assert embedding_model.provider_resource_id == "text-embedding-3-small"
|
||||
_assert_models_match_expected(result, expected_models)
|
||||
|
||||
# Check LLM model
|
||||
assert llm_model.model_type == ModelType.llm
|
||||
assert llm_model.metadata == {} # No metadata for LLMs
|
||||
assert llm_model.provider_id == "test-provider"
|
||||
assert llm_model.provider_resource_id == "gpt-4"
|
||||
|
||||
class TestOpenAIMixinCustomModelConstruction:
|
||||
"""Test cases for mixed model types (LLM, embedding, rerank) through construct_model_from_identifier"""
|
||||
|
||||
async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context):
|
||||
"""Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata"""
|
||||
# Create mock models: 1 embedding, 1 rerank, 1 LLM
|
||||
mock_embedding_model = MagicMock(id="text-embedding-3-small")
|
||||
mock_rerank_model = MagicMock(id="rerank-model-1")
|
||||
mock_llm_model = MagicMock(id="gpt-4")
|
||||
mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model]
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
async def mock_models_list():
|
||||
for model in mock_models:
|
||||
yield model
|
||||
|
||||
mock_client.models.list.return_value = mock_models_list()
|
||||
|
||||
with mock_client_context(mixin_with_custom_model_construction, mock_client):
|
||||
result = await mixin_with_custom_model_construction.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 3
|
||||
|
||||
expected_models = {
|
||||
"text-embedding-3-small": {
|
||||
"model_type": ModelType.embedding,
|
||||
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "text-embedding-3-small",
|
||||
},
|
||||
"rerank-model-1": {
|
||||
"model_type": ModelType.rerank,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "rerank-model-1",
|
||||
},
|
||||
"gpt-4": {
|
||||
"model_type": ModelType.llm,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
_assert_models_match_expected(result, expected_models)
|
||||
|
||||
|
||||
class TestOpenAIMixinAllowedModels:
|
||||
|
|
@ -720,7 +814,7 @@ class TestOpenAIMixinProviderDataApiKey:
|
|||
):
|
||||
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||
with pytest.raises(ValueError, match="API key is not set"):
|
||||
with pytest.raises(ValueError, match="API key not provided"):
|
||||
_ = mixin_with_provider_data_field_and_none_api_key.client
|
||||
|
||||
def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.stack import replace_env_vars
|
||||
from llama_stack.providers.remote.inference.anthropic import AnthropicConfig
|
||||
from llama_stack.providers.remote.inference.azure import AzureConfig
|
||||
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
|
||||
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
|
||||
from llama_stack.providers.remote.inference.databricks import DatabricksImplConfig
|
||||
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
|
||||
from llama_stack.providers.remote.inference.gemini import GeminiConfig
|
||||
from llama_stack.providers.remote.inference.groq import GroqConfig
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat import LlamaCompatConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
from llama_stack.providers.remote.inference.openai import OpenAIConfig
|
||||
from llama_stack.providers.remote.inference.runpod import RunpodImplConfig
|
||||
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
|
||||
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
|
||||
from llama_stack.providers.remote.inference.together import TogetherImplConfig
|
||||
from llama_stack.providers.remote.inference.vertexai import VertexAIConfig
|
||||
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||
|
||||
|
||||
class TestRemoteInferenceProviderConfig:
|
||||
@pytest.mark.parametrize(
|
||||
"config_cls,alias_name,env_name,extra_config",
|
||||
[
|
||||
(AnthropicConfig, "api_key", "ANTHROPIC_API_KEY", {}),
|
||||
(AzureConfig, "api_key", "AZURE_API_KEY", {"api_base": "HTTP://FAKE"}),
|
||||
(BedrockConfig, None, None, {}),
|
||||
(CerebrasImplConfig, "api_key", "CEREBRAS_API_KEY", {}),
|
||||
(DatabricksImplConfig, "api_token", "DATABRICKS_TOKEN", {}),
|
||||
(FireworksImplConfig, "api_key", "FIREWORKS_API_KEY", {}),
|
||||
(GeminiConfig, "api_key", "GEMINI_API_KEY", {}),
|
||||
(GroqConfig, "api_key", "GROQ_API_KEY", {}),
|
||||
(LlamaCompatConfig, "api_key", "LLAMA_API_KEY", {}),
|
||||
(NVIDIAConfig, "api_key", "NVIDIA_API_KEY", {}),
|
||||
(OllamaImplConfig, None, None, {}),
|
||||
(OpenAIConfig, "api_key", "OPENAI_API_KEY", {}),
|
||||
(RunpodImplConfig, "api_token", "RUNPOD_API_TOKEN", {}),
|
||||
(SambaNovaImplConfig, "api_key", "SAMBANOVA_API_KEY", {}),
|
||||
(TGIImplConfig, None, None, {"url": "FAKE"}),
|
||||
(TogetherImplConfig, "api_key", "TOGETHER_API_KEY", {}),
|
||||
(VertexAIConfig, None, None, {"project": "FAKE", "location": "FAKE"}),
|
||||
(VLLMInferenceAdapterConfig, "api_token", "VLLM_API_TOKEN", {}),
|
||||
(WatsonXConfig, "api_key", "WATSONX_API_KEY", {}),
|
||||
],
|
||||
)
|
||||
def test_provider_config_auth_credentials(self, monkeypatch, config_cls, alias_name, env_name, extra_config):
|
||||
"""Test that the config class correctly maps the alias to auth_credential."""
|
||||
secret_value = config_cls.__name__
|
||||
|
||||
if alias_name is None:
|
||||
pytest.skip("No alias name provided for this config class.")
|
||||
|
||||
config = config_cls(**{alias_name: secret_value, **extra_config})
|
||||
assert config.auth_credential is not None
|
||||
assert config.auth_credential.get_secret_value() == secret_value
|
||||
|
||||
schema = config_cls.model_json_schema()
|
||||
assert alias_name in schema["properties"]
|
||||
assert "auth_credential" not in schema["properties"]
|
||||
|
||||
if env_name:
|
||||
monkeypatch.setenv(env_name, secret_value)
|
||||
sample_config = config_cls.sample_run_config()
|
||||
expanded_config = replace_env_vars(sample_config)
|
||||
config_from_sample = config_cls(**{**expanded_config, **extra_config})
|
||||
assert config_from_sample.auth_credential is not None
|
||||
assert config_from_sample.auth_credential.get_secret_value() == secret_value
|
||||
|
|
@ -9,38 +9,29 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from chromadb import PersistentClient
|
||||
from pymilvus import MilvusClient, connections
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
EMBEDDING_DIMENSION = 384
|
||||
EMBEDDING_DIMENSION = 768
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
MILVUS_ALIAS = "test_milvus"
|
||||
|
||||
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
|
||||
@pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"])
|
||||
def vector_provider(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_db_id() -> str:
|
||||
def vector_store_id() -> str:
|
||||
return f"test-vector-db-{random.randint(1, 100)}"
|
||||
|
||||
|
||||
|
|
@ -122,8 +113,9 @@ async def unique_kvstore_config(tmp_path_factory):
|
|||
unique_id = f"test_kv_{np.random.randint(1e6)}"
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
return SqliteKVStoreConfig(db_path=db_path)
|
||||
backend_name = f"kv_vector_{unique_id}"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
return KVStoreReference(backend=backend_name, namespace=f"vector_io::{unique_id}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
@ -148,7 +140,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
|
|||
async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = SQLiteVectorIOConfig(
|
||||
db_path=sqlite_vec_db_path,
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
adapter = SQLiteVecVectorIOAdapter(
|
||||
config=config,
|
||||
|
|
@ -157,8 +149,8 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
|
|||
)
|
||||
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
await adapter.register_vector_store(
|
||||
VectorStore(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
|
|
@ -170,46 +162,6 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def milvus_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
|
||||
client = MilvusClient(milvus_vec_db_path)
|
||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
|
||||
index = MilvusIndex(client, name, consistency_level="Strong")
|
||||
index.db_path = milvus_vec_db_path
|
||||
yield index
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api):
|
||||
config = MilvusVectorIOConfig(
|
||||
db_path=milvus_vec_db_path,
|
||||
kvstore=unique_kvstore_config,
|
||||
)
|
||||
adapter = MilvusVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=adapter.metadata_collection_name,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=128,
|
||||
)
|
||||
)
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def faiss_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
|
||||
|
|
@ -226,7 +178,7 @@ async def faiss_vec_index(embedding_dimension):
|
|||
@pytest.fixture
|
||||
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = FaissVectorIOConfig(
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
adapter = FaissVectorIOAdapter(
|
||||
config=config,
|
||||
|
|
@ -234,8 +186,8 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
|
|||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
await adapter.register_vector_store(
|
||||
VectorStore(
|
||||
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
|
|
@ -246,98 +198,6 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chroma_vec_db_path(tmp_path_factory):
|
||||
persist_dir = tmp_path_factory.mktemp(f"chroma_{np.random.randint(1e6)}")
|
||||
return str(persist_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
|
||||
client = PersistentClient(path=chroma_vec_db_path)
|
||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||
collection = await maybe_await(client.get_or_create_collection(name))
|
||||
index = ChromaIndex(client=client, collection=collection)
|
||||
await index.initialize()
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def chroma_vec_adapter(chroma_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = ChromaVectorIOConfig(
|
||||
db_path=chroma_vec_db_path,
|
||||
kvstore=unique_kvstore_config,
|
||||
)
|
||||
adapter = ChromaVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=f"chroma_test_collection_{random.randint(1, 1_000_000)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_vec_db_path(tmp_path_factory):
|
||||
import uuid
|
||||
|
||||
db_path = str(tmp_path_factory.getbasetemp() / f"test_qdrant_{uuid.uuid4()}.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_adapter(qdrant_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
import uuid
|
||||
|
||||
config = QdrantVectorIOConfig(
|
||||
db_path=qdrant_vec_db_path,
|
||||
kvstore=unique_kvstore_config,
|
||||
)
|
||||
adapter = QdrantVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
adapter.test_collection_id = collection_id
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
|
||||
import uuid
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex
|
||||
|
||||
client = AsyncQdrantClient(path=qdrant_vec_db_path)
|
||||
collection_name = f"qdrant_test_collection_{uuid.uuid4()}"
|
||||
index = QdrantIndex(client, collection_name)
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_psycopg2_connection():
|
||||
connection = MagicMock()
|
||||
|
|
@ -355,7 +215,7 @@ def mock_psycopg2_connection():
|
|||
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
vector_store = VectorStore(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
|
|
@ -365,7 +225,7 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
|
|||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
|
||||
index = PGVectorIndex(vector_store, embedding_dimension, connection, distance_metric="COSINE")
|
||||
index._test_chunks = []
|
||||
original_add_chunks = index.add_chunks
|
||||
|
||||
|
|
@ -393,7 +253,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
|
|||
db="test_db",
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
|
||||
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
|
||||
|
|
@ -421,110 +281,41 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
|
|||
await adapter.initialize()
|
||||
adapter.conn = mock_conn
|
||||
|
||||
async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
|
||||
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
|
||||
async def mock_insert_chunks(vector_store_id, chunks, ttl_seconds=None):
|
||||
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found")
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
adapter.insert_chunks = mock_insert_chunks
|
||||
|
||||
async def mock_query_chunks(vector_db_id, query, params=None):
|
||||
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
|
||||
async def mock_query_chunks(vector_store_id, query, params=None):
|
||||
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found")
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
adapter.query_chunks = mock_query_chunks
|
||||
|
||||
test_vector_db = VectorDB(
|
||||
test_vector_store = VectorStore(
|
||||
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
await adapter.register_vector_db(test_vector_db)
|
||||
adapter.test_collection_id = test_vector_db.identifier
|
||||
await adapter.register_vector_store(test_vector_store)
|
||||
adapter.test_collection_id = test_vector_store.identifier
|
||||
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def weaviate_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def weaviate_vec_index(weaviate_vec_db_path):
|
||||
import pytest_socket
|
||||
import weaviate
|
||||
|
||||
pytest_socket.enable_socket()
|
||||
client = weaviate.connect_to_embedded(
|
||||
hostname="localhost",
|
||||
port=8080,
|
||||
grpc_port=50051,
|
||||
persistence_data_path=weaviate_vec_db_path,
|
||||
)
|
||||
index = WeaviateIndex(client=client, collection_name="Testcollection")
|
||||
await index.initialize()
|
||||
yield index
|
||||
await index.delete()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
import pytest_socket
|
||||
import weaviate
|
||||
|
||||
pytest_socket.enable_socket()
|
||||
|
||||
client = weaviate.connect_to_embedded(
|
||||
hostname="localhost",
|
||||
port=8080,
|
||||
grpc_port=50051,
|
||||
persistence_data_path=weaviate_vec_db_path,
|
||||
)
|
||||
|
||||
config = WeaviateVectorIOConfig(
|
||||
weaviate_cluster_url="localhost:8080",
|
||||
weaviate_api_key=None,
|
||||
kvstore=unique_kvstore_config,
|
||||
)
|
||||
adapter = WeaviateVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
adapter.test_collection_id = collection_id
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_io_adapter(vector_provider, request):
|
||||
vector_provider_dict = {
|
||||
"milvus": "milvus_vec_adapter",
|
||||
"faiss": "faiss_vec_adapter",
|
||||
"sqlite_vec": "sqlite_vec_adapter",
|
||||
"chroma": "chroma_vec_adapter",
|
||||
"qdrant": "qdrant_vec_adapter",
|
||||
"pgvector": "pgvector_vec_adapter",
|
||||
"weaviate": "weaviate_vec_adapter",
|
||||
}
|
||||
return request.getfixturevalue(vector_provider_dict[vector_provider])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,326 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
|
||||
# Mock the entire pymilvus module
|
||||
pymilvus_mock = MagicMock()
|
||||
pymilvus_mock.DataType = MagicMock()
|
||||
pymilvus_mock.MilvusClient = MagicMock
|
||||
pymilvus_mock.RRFRanker = MagicMock
|
||||
pymilvus_mock.WeightedRanker = MagicMock
|
||||
pymilvus_mock.AnnSearchRequest = MagicMock
|
||||
|
||||
# Apply the mock before importing MilvusIndex
|
||||
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
||||
|
||||
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
# tests/integration/vector_io/
|
||||
#
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest tests/unit/providers/vector_io/test_milvus.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
MILVUS_PROVIDER = "milvus"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_milvus_client() -> MagicMock:
|
||||
"""Create a mock Milvus client with common method behaviors."""
|
||||
client = MagicMock()
|
||||
|
||||
# Mock collection operations
|
||||
client.has_collection.return_value = False # Initially no collection
|
||||
client.create_collection.return_value = None
|
||||
client.drop_collection.return_value = None
|
||||
|
||||
# Mock insert operation
|
||||
client.insert.return_value = {"insert_count": 10}
|
||||
|
||||
# Mock search operation - return mock results (data should be dict, not JSON string)
|
||||
client.search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Mock query operation for keyword search (data should be dict, not JSON string)
|
||||
client.query.return_value = [
|
||||
{
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||
"score": 0.9,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
|
||||
"score": 0.8,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk3",
|
||||
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
|
||||
"score": 0.7,
|
||||
},
|
||||
]
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_index(mock_milvus_client):
|
||||
"""Create a MilvusIndex with mocked client."""
|
||||
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
|
||||
yield index
|
||||
# No real cleanup needed since we're using mocks
|
||||
|
||||
|
||||
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
# Setup: collection doesn't exist initially, then exists after creation
|
||||
mock_milvus_client.has_collection.side_effect = [False, True]
|
||||
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Verify collection was created and data was inserted
|
||||
mock_milvus_client.create_collection.assert_called_once()
|
||||
mock_milvus_client.insert.assert_called_once()
|
||||
|
||||
# Verify the insert call had the right number of chunks
|
||||
insert_call = mock_milvus_client.insert.call_args
|
||||
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
||||
|
||||
|
||||
async def test_query_chunks_vector(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
# Setup: Add chunks first
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test vector search
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
mock_milvus_client.search.assert_called_once()
|
||||
|
||||
|
||||
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test keyword search
|
||||
query_string = "Sentence 5"
|
||||
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
|
||||
|
||||
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Force BM25 search to fail
|
||||
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
|
||||
|
||||
# Mock simple text search results
|
||||
mock_milvus_client.query.return_value = [
|
||||
{
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
|
||||
},
|
||||
]
|
||||
|
||||
# Test keyword search that should fall back to simple text search
|
||||
query_string = "Python"
|
||||
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
||||
|
||||
# Verify response structure
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) > 0, "Fallback search should return results"
|
||||
|
||||
# Verify that simple text search was used (query method called instead of search)
|
||||
mock_milvus_client.query.assert_called_once()
|
||||
mock_milvus_client.search.assert_called_once() # Called once but failed
|
||||
|
||||
# Verify the query uses parameterized filter with filter_params
|
||||
query_call_args = mock_milvus_client.query.call_args
|
||||
assert "filter" in query_call_args[1], "Query should include filter for text search"
|
||||
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
|
||||
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
|
||||
|
||||
# Verify all returned chunks have score 1.0 (simple binary scoring)
|
||||
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
|
||||
|
||||
|
||||
async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||
# Test collection deletion
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
|
||||
await milvus_index.delete()
|
||||
|
||||
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
||||
|
||||
|
||||
async def test_query_hybrid_search_rrf(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with RRF reranker."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with RRF reranker
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=2,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
assert len(response.scores) == 2
|
||||
|
||||
# Verify hybrid search was called with correct parameters
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
|
||||
# Check that the request contains both vector and BM25 search requests
|
||||
reqs = call_args[1]["reqs"]
|
||||
assert len(reqs) == 2
|
||||
assert reqs[0].anns_field == "vector"
|
||||
assert reqs[1].anns_field == "sparse"
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
||||
|
||||
async def test_query_hybrid_search_weighted(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with weighted reranker."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with weighted reranker
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=2,
|
||||
score_threshold=0.0,
|
||||
reranker_type="weighted",
|
||||
reranker_params={"alpha": 0.7},
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
assert len(response.scores) == 2
|
||||
|
||||
# Verify hybrid search was called with correct parameters
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
||||
|
||||
async def test_query_hybrid_search_default_rrf(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with default reranker (should be RRF)
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="unknown_type", # Should default to RRF
|
||||
reranker_params=None, # Should use default impact_factor
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 1
|
||||
|
||||
# Verify hybrid search was called with RRF reranker
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex
|
||||
|
||||
PGVECTOR_PROVIDER = "pgvector"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_dimension():
|
||||
"""Default embedding dimension for tests."""
|
||||
return 384
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def pgvector_index(embedding_dimension, mock_psycopg2_connection):
|
||||
"""Create a PGVectorIndex instance with mocked database connection."""
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
# Use explicit COSINE distance metric for consistent testing
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
|
||||
|
||||
return index, cursor
|
||||
|
||||
|
||||
class TestPGVectorIndex:
|
||||
def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2")
|
||||
assert index.distance_metric == "L2"
|
||||
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
|
||||
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID")
|
||||
|
||||
def test_get_pgvector_search_function(self, pgvector_index):
|
||||
index, cursor = pgvector_index
|
||||
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
|
||||
|
||||
for metric, function in supported_metrics.items():
|
||||
index.distance_metric = metric
|
||||
assert index.get_pgvector_search_function() == function
|
||||
|
||||
def test_check_distance_metric_availability(self, pgvector_index):
|
||||
index, cursor = pgvector_index
|
||||
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
|
||||
|
||||
for metric in supported_metrics:
|
||||
index.check_distance_metric_availability(metric)
|
||||
|
||||
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
|
||||
index.check_distance_metric_availability("INVALID")
|
||||
|
||||
def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"):
|
||||
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC")
|
||||
|
||||
with pytest.raises(ValueError, match="Supported metrics are:"):
|
||||
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN")
|
||||
|
||||
try:
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
|
||||
assert index.distance_metric == "COSINE"
|
||||
except ValueError:
|
||||
pytest.fail("Valid distance metric 'COSINE' should not raise ValueError")
|
||||
|
||||
def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"]
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
for metric in supported_metrics:
|
||||
try:
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric)
|
||||
assert index.distance_metric == metric
|
||||
|
||||
expected_operators = {
|
||||
"L2": "<->",
|
||||
"L1": "<+>",
|
||||
"COSINE": "<=>",
|
||||
"INNER_PRODUCT": "<#>",
|
||||
"HAMMING": "<~>",
|
||||
"JACCARD": "<%>",
|
||||
}
|
||||
assert index.get_pgvector_search_function() == expected_operators[metric]
|
||||
except Exception as e:
|
||||
pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}")
|
||||
|
|
@ -11,8 +11,8 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import (
|
||||
|
|
@ -39,12 +39,12 @@ def loop():
|
|||
|
||||
@pytest.fixture
|
||||
def embedding_dimension():
|
||||
return 384
|
||||
return 768
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_db_id():
|
||||
return "test_vector_db"
|
||||
def vector_store_id():
|
||||
return "test_vector_store"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -61,12 +61,12 @@ def sample_embeddings(embedding_dimension):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
|
||||
mock_vector_db = MagicMock(spec=VectorDB)
|
||||
mock_vector_db.embedding_model = "mock_embedding_model"
|
||||
mock_vector_db.identifier = vector_db_id
|
||||
mock_vector_db.embedding_dimension = embedding_dimension
|
||||
return mock_vector_db
|
||||
def mock_vector_store(vector_store_id, embedding_dimension) -> MagicMock:
|
||||
mock_vector_store = MagicMock(spec=VectorStore)
|
||||
mock_vector_store.embedding_model = "mock_embedding_model"
|
||||
mock_vector_store.identifier = vector_store_id
|
||||
mock_vector_store.embedding_dimension = embedding_dimension
|
||||
return mock_vector_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -1,147 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingData, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage
|
||||
from llama_stack.apis.vector_io import (
|
||||
QueryChunksResponse,
|
||||
VectorDB,
|
||||
VectorDBStore,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.qdrant.config import (
|
||||
QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
|
||||
QdrantVectorIOAdapter,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
# tests/integration/vector_io/
|
||||
#
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest tests/unit/providers/vector_io/test_qdrant.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
|
||||
kvstore_config = SqliteKVStoreConfig(db_name=os.path.join(tmp_path, "test_kvstore.db"))
|
||||
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"), kvstore=kvstore_config)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db(vector_db_id) -> MagicMock:
|
||||
mock_vector_db = MagicMock(spec=VectorDB)
|
||||
mock_vector_db.embedding_model = "embedding_model"
|
||||
mock_vector_db.identifier = vector_db_id
|
||||
mock_vector_db.embedding_dimension = 384
|
||||
mock_vector_db.model_dump_json.return_value = (
|
||||
'{"identifier": "'
|
||||
+ vector_db_id
|
||||
+ '", "provider_id": "qdrant", "embedding_model": "embedding_model", "embedding_dimension": 384}'
|
||||
)
|
||||
return mock_vector_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db_store(mock_vector_db) -> MagicMock:
|
||||
mock_store = MagicMock(spec=VectorDBStore)
|
||||
mock_store.get_vector_db = AsyncMock(return_value=mock_vector_db)
|
||||
return mock_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_service(sample_embeddings):
|
||||
mock_api_service = MagicMock(spec=Inference)
|
||||
mock_api_service.openai_embeddings = AsyncMock(
|
||||
return_value=OpenAIEmbeddingsResponse(
|
||||
model="mock-embedding-model",
|
||||
data=[OpenAIEmbeddingData(embedding=sample, index=i) for i, sample in enumerate(sample_embeddings)],
|
||||
usage=OpenAIEmbeddingUsage(prompt_tokens=10, total_tokens=10),
|
||||
)
|
||||
)
|
||||
return mock_api_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
|
||||
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
|
||||
adapter.vector_db_store = mock_vector_db_store
|
||||
await adapter.initialize()
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
__QUERY = "Sample query"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
|
||||
async def test_qdrant_adapter_returns_expected_chunks(
|
||||
qdrant_adapter: QdrantVectorIOAdapter,
|
||||
vector_db_id,
|
||||
sample_chunks,
|
||||
sample_embeddings,
|
||||
max_query_chunks,
|
||||
expected_chunks,
|
||||
) -> None:
|
||||
assert qdrant_adapter is not None
|
||||
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
|
||||
|
||||
index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
|
||||
assert index is not None
|
||||
|
||||
response = await qdrant_adapter.query_chunks(
|
||||
query=__QUERY,
|
||||
vector_db_id=vector_db_id,
|
||||
params={"max_chunks": max_query_chunks, "mode": "vector"},
|
||||
)
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == expected_chunks
|
||||
|
||||
|
||||
# To by-pass attempt to convert a Mock to JSON
|
||||
def _prepare_for_json(value: Any) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
|
||||
async def test_qdrant_register_and_unregister_vector_db(
|
||||
qdrant_adapter: QdrantVectorIOAdapter,
|
||||
mock_vector_db,
|
||||
sample_chunks,
|
||||
) -> None:
|
||||
# Initially, no collections
|
||||
vector_db_id = mock_vector_db.identifier
|
||||
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
|
||||
|
||||
# Register does not create a collection
|
||||
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
||||
await qdrant_adapter.register_vector_db(mock_vector_db)
|
||||
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
||||
|
||||
# First insert creates the collection
|
||||
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
|
||||
assert await qdrant_adapter.client.collection_exists(vector_db_id)
|
||||
|
||||
# Unregister deletes the collection
|
||||
await qdrant_adapter.unregister_vector_db(vector_db_id)
|
||||
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
||||
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
|
||||
|
|
@ -12,14 +12,16 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||
OpenAICreateVectorStoreRequestWithExtraBody,
|
||||
QueryChunksResponse,
|
||||
VectorStoreChunkingStrategyAuto,
|
||||
VectorStoreFileObject,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
|
||||
|
||||
# This test is a unit test for the inline VectorIO providers. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
|
|
@ -69,7 +71,7 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
|
|||
|
||||
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
||||
key = f"{VECTOR_DBS_PREFIX}db1"
|
||||
dummy = VectorDB(
|
||||
dummy = VectorStore(
|
||||
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||
)
|
||||
await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
|
||||
|
|
@ -79,10 +81,10 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
|||
|
||||
async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
||||
await vector_io_adapter.initialize()
|
||||
dummy = VectorDB(
|
||||
dummy = VectorStore(
|
||||
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||
)
|
||||
await vector_io_adapter.register_vector_db(dummy)
|
||||
await vector_io_adapter.register_vector_store(dummy)
|
||||
await vector_io_adapter.shutdown()
|
||||
|
||||
await vector_io_adapter.initialize()
|
||||
|
|
@ -90,26 +92,22 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
|||
await vector_io_adapter.shutdown()
|
||||
|
||||
|
||||
async def test_register_and_unregister_vector_db(vector_io_adapter):
|
||||
async def test_register_and_unregister_vector_store(vector_io_adapter):
|
||||
unique_id = f"foo_db_{np.random.randint(1e6)}"
|
||||
dummy = VectorDB(
|
||||
dummy = VectorStore(
|
||||
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||
)
|
||||
|
||||
await vector_io_adapter.register_vector_db(dummy)
|
||||
await vector_io_adapter.register_vector_store(dummy)
|
||||
assert dummy.identifier in vector_io_adapter.cache
|
||||
await vector_io_adapter.unregister_vector_db(dummy.identifier)
|
||||
await vector_io_adapter.unregister_vector_store(dummy.identifier)
|
||||
assert dummy.identifier not in vector_io_adapter.cache
|
||||
|
||||
|
||||
async def test_query_unregistered_raises(vector_io_adapter, vector_provider):
|
||||
fake_emb = np.zeros(8, dtype=np.float32)
|
||||
if vector_provider == "chroma":
|
||||
with pytest.raises(AttributeError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
else:
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
|
||||
|
||||
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
||||
|
|
@ -123,12 +121,43 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
|||
|
||||
|
||||
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
|
||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.insert_chunks("db_not_exist", [])
|
||||
|
||||
|
||||
async def test_insert_chunks_with_missing_document_id(vector_io_adapter):
|
||||
"""Ensure no KeyError when document_id is missing or in different places."""
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
|
||||
|
||||
fake_index = AsyncMock()
|
||||
vector_io_adapter.cache["db1"] = fake_index
|
||||
|
||||
# Various document_id scenarios that shouldn't crash
|
||||
chunks = [
|
||||
Chunk(content="has doc_id in metadata", metadata={"document_id": "doc-1"}),
|
||||
Chunk(content="no doc_id anywhere", metadata={"source": "test"}),
|
||||
Chunk(content="doc_id in chunk_metadata", chunk_metadata=ChunkMetadata(document_id="doc-3")),
|
||||
]
|
||||
|
||||
# Should work without KeyError
|
||||
await vector_io_adapter.insert_chunks("db1", chunks)
|
||||
fake_index.insert_chunks.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_document_id_with_invalid_type_raises_error():
|
||||
"""Ensure TypeError is raised when document_id is not a string."""
|
||||
from llama_stack.apis.vector_io import Chunk
|
||||
|
||||
# Integer document_id should raise TypeError
|
||||
chunk = Chunk(content="test", metadata={"document_id": 12345})
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
_ = chunk.document_id
|
||||
assert "metadata['document_id'] must be a string" in str(exc_info.value)
|
||||
assert "got int" in str(exc_info.value)
|
||||
|
||||
|
||||
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
|
||||
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
|
||||
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
|
||||
|
|
@ -141,7 +170,7 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
|
|||
|
||||
|
||||
async def test_query_chunks_missing_db_raises(vector_io_adapter):
|
||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.query_chunks("db_missing", "q", None)
|
||||
|
|
@ -153,7 +182,7 @@ async def test_save_openai_vector_store(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -169,7 +198,7 @@ async def test_update_openai_vector_store(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -185,7 +214,7 @@ async def test_delete_openai_vector_store(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -200,7 +229,7 @@ async def test_load_openai_vector_stores(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -330,8 +359,7 @@ async def test_create_vector_store_file_batch(vector_io_adapter):
|
|||
vector_io_adapter._process_file_batch_async = AsyncMock()
|
||||
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
assert batch.vector_store_id == store_id
|
||||
|
|
@ -358,8 +386,7 @@ async def test_retrieve_vector_store_file_batch(vector_io_adapter):
|
|||
|
||||
# Create batch first
|
||||
created_batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Retrieve batch
|
||||
|
|
@ -392,8 +419,7 @@ async def test_cancel_vector_store_file_batch(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Cancel batch
|
||||
|
|
@ -438,8 +464,7 @@ async def test_list_files_in_vector_store_file_batch(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# List files
|
||||
|
|
@ -459,7 +484,7 @@ async def test_file_batch_validation_errors(vector_io_adapter):
|
|||
with pytest.raises(VectorStoreNotFoundError):
|
||||
await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id="nonexistent",
|
||||
file_ids=["file_1"],
|
||||
params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"]),
|
||||
)
|
||||
|
||||
# Setup store for remaining tests
|
||||
|
|
@ -476,8 +501,7 @@ async def test_file_batch_validation_errors(vector_io_adapter):
|
|||
# Test wrong vector store for batch
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=["file_1"],
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
|
||||
)
|
||||
|
||||
# Create wrong_store so it exists but the batch doesn't belong to it
|
||||
|
|
@ -524,8 +548,7 @@ async def test_file_batch_pagination(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Test pagination with limit
|
||||
|
|
@ -597,8 +620,7 @@ async def test_file_batch_status_filtering(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Test filtering by completed status
|
||||
|
|
@ -640,8 +662,7 @@ async def test_cancel_completed_batch_fails(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Manually update status to completed
|
||||
|
|
@ -675,8 +696,7 @@ async def test_file_batch_persistence_across_restarts(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
batch_id = batch.id
|
||||
|
||||
|
|
@ -731,8 +751,7 @@ async def test_cancelled_batch_persists_in_storage(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
batch_id = batch.id
|
||||
|
||||
|
|
@ -779,10 +798,10 @@ async def test_only_in_progress_batches_resumed(vector_io_adapter):
|
|||
|
||||
# Create multiple batches
|
||||
batch1 = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, file_ids=["file_1"]
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
|
||||
)
|
||||
batch2 = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, file_ids=["file_2"]
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_2"])
|
||||
)
|
||||
|
||||
# Complete one batch (should persist with completed status)
|
||||
|
|
@ -795,7 +814,7 @@ async def test_only_in_progress_batches_resumed(vector_io_adapter):
|
|||
|
||||
# Create a third batch that stays in progress
|
||||
batch3 = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, file_ids=["file_3"]
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_3"])
|
||||
)
|
||||
|
||||
# Simulate restart - clear memory and reload from persistence
|
||||
|
|
@ -956,8 +975,7 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
|
|||
file_ids = [f"file_{i}" for i in range(8)] # 8 files, but limit should be 5
|
||||
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Give time for the semaphore logic to start processing files
|
||||
|
|
@ -975,3 +993,130 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
|
|||
assert batch.status == "in_progress"
|
||||
assert batch.file_counts.total == 8
|
||||
assert batch.file_counts.in_progress == 8
|
||||
|
||||
|
||||
async def test_embedding_config_from_metadata(vector_io_adapter):
|
||||
"""Test that embedding configuration is correctly extracted from metadata."""
|
||||
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
# Test with embedding config in metadata
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name="test_store",
|
||||
metadata={
|
||||
"embedding_model": "test-embedding-model",
|
||||
"embedding_dimension": "512",
|
||||
},
|
||||
model_extra={},
|
||||
)
|
||||
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Verify VectorStore was registered with correct embedding config from metadata
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "test-embedding-model"
|
||||
assert call_args.embedding_dimension == 512
|
||||
|
||||
|
||||
async def test_embedding_config_from_extra_body(vector_io_adapter):
|
||||
"""Test that embedding configuration is correctly extracted from extra_body when metadata is empty."""
|
||||
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
# Test with embedding config in extra_body only (metadata has no embedding_model)
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name="test_store",
|
||||
metadata={}, # Empty metadata to ensure extra_body is used
|
||||
**{
|
||||
"embedding_model": "extra-body-model",
|
||||
"embedding_dimension": 1024,
|
||||
},
|
||||
)
|
||||
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Verify VectorStore was registered with correct embedding config from extra_body
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "extra-body-model"
|
||||
assert call_args.embedding_dimension == 1024
|
||||
|
||||
|
||||
async def test_embedding_config_consistency_check_passes(vector_io_adapter):
|
||||
"""Test that consistent embedding config in both metadata and extra_body passes validation."""
|
||||
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
# Test with consistent embedding config in both metadata and extra_body
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name="test_store",
|
||||
metadata={
|
||||
"embedding_model": "consistent-model",
|
||||
"embedding_dimension": "768",
|
||||
},
|
||||
**{
|
||||
"embedding_model": "consistent-model",
|
||||
"embedding_dimension": 768,
|
||||
},
|
||||
)
|
||||
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Should not raise any error and use metadata config
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "consistent-model"
|
||||
assert call_args.embedding_dimension == 768
|
||||
|
||||
|
||||
async def test_embedding_config_defaults_when_missing(vector_io_adapter):
|
||||
"""Test that embedding dimension defaults to 768 when not provided."""
|
||||
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
# Test with only embedding model, no dimension (metadata empty to use extra_body)
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name="test_store",
|
||||
metadata={}, # Empty metadata to ensure extra_body is used
|
||||
**{
|
||||
"embedding_model": "model-without-dimension",
|
||||
},
|
||||
)
|
||||
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Should default to 768 dimensions
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "model-without-dimension"
|
||||
assert call_args.embedding_dimension == 768
|
||||
|
||||
|
||||
async def test_embedding_config_required_model_missing(vector_io_adapter):
|
||||
"""Test that missing embedding model raises error."""
|
||||
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
# Mock the default model lookup to return None (no default model available)
|
||||
vector_io_adapter._get_default_embedding_model_and_dimension = AsyncMock(return_value=None)
|
||||
|
||||
# Test with no embedding model provided
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test_store", metadata={})
|
||||
|
||||
with pytest.raises(ValueError, match="embedding_model is required"):
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
|
|||
|
||||
|
||||
class TestRagQuery:
|
||||
async def test_query_raises_on_empty_vector_db_ids(self):
|
||||
async def test_query_raises_on_empty_vector_store_ids(self):
|
||||
rag_tool = MemoryToolRuntimeImpl(
|
||||
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||
)
|
||||
|
|
@ -82,7 +82,7 @@ class TestRagQuery:
|
|||
with pytest.raises(ValueError):
|
||||
RAGQueryConfig(mode="wrong_mode")
|
||||
|
||||
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
|
||||
async def test_query_adds_vector_store_id_to_chunk_metadata(self):
|
||||
rag_tool = MemoryToolRuntimeImpl(
|
||||
config=MagicMock(),
|
||||
vector_io_api=MagicMock(),
|
||||
|
|
|
|||
|
|
@ -13,12 +13,15 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingData
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
)
|
||||
from llama_stack.apis.tools import RAGDocument
|
||||
from llama_stack.apis.vector_io import Chunk
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
URL,
|
||||
VectorDBWithIndex,
|
||||
VectorStoreWithIndex,
|
||||
_validate_embedding,
|
||||
content_from_doc,
|
||||
make_overlapped_chunks,
|
||||
|
|
@ -203,15 +206,15 @@ class TestVectorStore:
|
|||
assert str(excinfo.value.__cause__) == "Cannot convert to string"
|
||||
|
||||
|
||||
class TestVectorDBWithIndex:
|
||||
class TestVectorStoreWithIndex:
|
||||
async def test_insert_chunks_without_embeddings(self):
|
||||
mock_vector_db = MagicMock()
|
||||
mock_vector_db.embedding_model = "test-model without embeddings"
|
||||
mock_vector_store = MagicMock()
|
||||
mock_vector_store.embedding_model = "test-model without embeddings"
|
||||
mock_index = AsyncMock()
|
||||
mock_inference_api = AsyncMock()
|
||||
|
||||
vector_db_with_index = VectorDBWithIndex(
|
||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||
)
|
||||
|
||||
chunks = [
|
||||
|
|
@ -224,25 +227,30 @@ class TestVectorDBWithIndex:
|
|||
OpenAIEmbeddingData(embedding=[0.4, 0.5, 0.6], index=1),
|
||||
]
|
||||
|
||||
await vector_db_with_index.insert_chunks(chunks)
|
||||
await vector_store_with_index.insert_chunks(chunks)
|
||||
|
||||
mock_inference_api.openai_embeddings.assert_called_once_with(
|
||||
"test-model without embeddings", ["Test 1", "Test 2"]
|
||||
)
|
||||
# Verify openai_embeddings was called with correct params
|
||||
mock_inference_api.openai_embeddings.assert_called_once()
|
||||
call_args = mock_inference_api.openai_embeddings.call_args[0]
|
||||
assert len(call_args) == 1
|
||||
params = call_args[0]
|
||||
assert isinstance(params, OpenAIEmbeddingsRequestWithExtraBody)
|
||||
assert params.model == "test-model without embeddings"
|
||||
assert params.input == ["Test 1", "Test 2"]
|
||||
mock_index.add_chunks.assert_called_once()
|
||||
args = mock_index.add_chunks.call_args[0]
|
||||
assert args[0] == chunks
|
||||
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
||||
|
||||
async def test_insert_chunks_with_valid_embeddings(self):
|
||||
mock_vector_db = MagicMock()
|
||||
mock_vector_db.embedding_model = "test-model with embeddings"
|
||||
mock_vector_db.embedding_dimension = 3
|
||||
mock_vector_store = MagicMock()
|
||||
mock_vector_store.embedding_model = "test-model with embeddings"
|
||||
mock_vector_store.embedding_dimension = 3
|
||||
mock_index = AsyncMock()
|
||||
mock_inference_api = AsyncMock()
|
||||
|
||||
vector_db_with_index = VectorDBWithIndex(
|
||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||
)
|
||||
|
||||
chunks = [
|
||||
|
|
@ -250,7 +258,7 @@ class TestVectorDBWithIndex:
|
|||
Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
|
||||
]
|
||||
|
||||
await vector_db_with_index.insert_chunks(chunks)
|
||||
await vector_store_with_index.insert_chunks(chunks)
|
||||
|
||||
mock_inference_api.openai_embeddings.assert_not_called()
|
||||
mock_index.add_chunks.assert_called_once()
|
||||
|
|
@ -259,14 +267,14 @@ class TestVectorDBWithIndex:
|
|||
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
||||
|
||||
async def test_insert_chunks_with_invalid_embeddings(self):
|
||||
mock_vector_db = MagicMock()
|
||||
mock_vector_db.embedding_dimension = 3
|
||||
mock_vector_db.embedding_model = "test-model with invalid embeddings"
|
||||
mock_vector_store = MagicMock()
|
||||
mock_vector_store.embedding_dimension = 3
|
||||
mock_vector_store.embedding_model = "test-model with invalid embeddings"
|
||||
mock_index = AsyncMock()
|
||||
mock_inference_api = AsyncMock()
|
||||
|
||||
vector_db_with_index = VectorDBWithIndex(
|
||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||
)
|
||||
|
||||
# Verify Chunk raises ValueError for invalid embedding type
|
||||
|
|
@ -275,7 +283,7 @@ class TestVectorDBWithIndex:
|
|||
|
||||
# Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
|
||||
with pytest.raises(ValueError, match="Input should be a valid list"):
|
||||
await vector_db_with_index.insert_chunks(
|
||||
await vector_store_with_index.insert_chunks(
|
||||
[
|
||||
Chunk(content="Test 1", embedding=None, metadata={}),
|
||||
Chunk(content="Test 2", embedding="invalid_type", metadata={}),
|
||||
|
|
@ -284,7 +292,7 @@ class TestVectorDBWithIndex:
|
|||
|
||||
# Verify Chunk raises ValueError for invalid embedding element type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
|
||||
with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "):
|
||||
await vector_db_with_index.insert_chunks(
|
||||
await vector_store_with_index.insert_chunks(
|
||||
Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})
|
||||
)
|
||||
|
||||
|
|
@ -292,20 +300,20 @@ class TestVectorDBWithIndex:
|
|||
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
|
||||
]
|
||||
with pytest.raises(ValueError, match="has dimension 4, expected 3"):
|
||||
await vector_db_with_index.insert_chunks(chunks_wrong_dim)
|
||||
await vector_store_with_index.insert_chunks(chunks_wrong_dim)
|
||||
|
||||
mock_inference_api.openai_embeddings.assert_not_called()
|
||||
mock_index.add_chunks.assert_not_called()
|
||||
|
||||
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
|
||||
mock_vector_db = MagicMock()
|
||||
mock_vector_db.embedding_model = "test-model with partial embeddings"
|
||||
mock_vector_db.embedding_dimension = 3
|
||||
mock_vector_store = MagicMock()
|
||||
mock_vector_store.embedding_model = "test-model with partial embeddings"
|
||||
mock_vector_store.embedding_dimension = 3
|
||||
mock_index = AsyncMock()
|
||||
mock_inference_api = AsyncMock()
|
||||
|
||||
vector_db_with_index = VectorDBWithIndex(
|
||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||
)
|
||||
|
||||
chunks = [
|
||||
|
|
@ -319,11 +327,16 @@ class TestVectorDBWithIndex:
|
|||
OpenAIEmbeddingData(embedding=[0.3, 0.3, 0.3], index=1),
|
||||
]
|
||||
|
||||
await vector_db_with_index.insert_chunks(chunks)
|
||||
await vector_store_with_index.insert_chunks(chunks)
|
||||
|
||||
mock_inference_api.openai_embeddings.assert_called_once_with(
|
||||
"test-model with partial embeddings", ["Test 1", "Test 3"]
|
||||
)
|
||||
# Verify openai_embeddings was called with correct params
|
||||
mock_inference_api.openai_embeddings.assert_called_once()
|
||||
call_args = mock_inference_api.openai_embeddings.call_args[0]
|
||||
assert len(call_args) == 1
|
||||
params = call_args[0]
|
||||
assert isinstance(params, OpenAIEmbeddingsRequestWithExtraBody)
|
||||
assert params.model == "test-model with partial embeddings"
|
||||
assert params.input == ["Test 1", "Test 3"]
|
||||
mock_index.add_chunks.assert_called_once()
|
||||
args = mock_index.add_chunks.call_args[0]
|
||||
assert len(args[0]) == 3
|
||||
|
|
|
|||
|
|
@ -8,23 +8,24 @@
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.core.datatypes import VectorStoreWithOwner
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.core.store.registry import (
|
||||
KEY_FORMAT,
|
||||
CachedDiskDistributionRegistry,
|
||||
DiskDistributionRegistry,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_vector_db():
|
||||
return VectorDB(
|
||||
identifier="test_vector_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
def sample_vector_store():
|
||||
return VectorStore(
|
||||
identifier="test_vector_store",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
|
|
@ -44,17 +45,17 @@ async def test_registry_initialization(disk_dist_registry):
|
|||
assert result is None
|
||||
|
||||
|
||||
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
|
||||
print(f"Registering {sample_vector_db}")
|
||||
await disk_dist_registry.register(sample_vector_db)
|
||||
async def test_basic_registration(disk_dist_registry, sample_vector_store, sample_model):
|
||||
print(f"Registering {sample_vector_store}")
|
||||
await disk_dist_registry.register(sample_vector_store)
|
||||
print(f"Registering {sample_model}")
|
||||
await disk_dist_registry.register(sample_model)
|
||||
print("Getting vector_db")
|
||||
result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||
print("Getting vector_store")
|
||||
result_vector_store = await disk_dist_registry.get("vector_store", "test_vector_store")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == sample_vector_store.identifier
|
||||
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
|
||||
assert result_vector_store.provider_id == sample_vector_store.provider_id
|
||||
|
||||
result_model = await disk_dist_registry.get("model", "test_model")
|
||||
assert result_model is not None
|
||||
|
|
@ -62,133 +63,137 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m
|
|||
assert result_model.provider_id == sample_model.provider_id
|
||||
|
||||
|
||||
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
|
||||
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_store, sample_model):
|
||||
# First populate the disk registry
|
||||
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||
await disk_registry.initialize()
|
||||
await disk_registry.register(sample_vector_db)
|
||||
await disk_registry.register(sample_vector_store)
|
||||
await disk_registry.register(sample_model)
|
||||
|
||||
# Test cached version loads from disk
|
||||
db_path = sqlite_kvstore.db_path
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
|
||||
backend_name = "kv_cached_test"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
cached_registry = CachedDiskDistributionRegistry(
|
||||
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
|
||||
)
|
||||
await cached_registry.initialize()
|
||||
|
||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
||||
assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
|
||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||
result_vector_store = await cached_registry.get("vector_store", "test_vector_store")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == sample_vector_store.identifier
|
||||
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
|
||||
assert result_vector_store.embedding_dimension == sample_vector_store.embedding_dimension
|
||||
assert result_vector_store.provider_id == sample_vector_store.provider_id
|
||||
|
||||
|
||||
async def test_cached_registry_updates(cached_disk_dist_registry):
|
||||
new_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
new_vector_store = VectorStore(
|
||||
identifier="test_vector_store_2",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_store_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_disk_dist_registry.register(new_vector_db)
|
||||
await cached_disk_dist_registry.register(new_vector_store)
|
||||
|
||||
# Verify in cache
|
||||
result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == new_vector_db.identifier
|
||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||
result_vector_store = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == new_vector_store.identifier
|
||||
assert result_vector_store.provider_id == new_vector_store.provider_id
|
||||
|
||||
# Verify persisted to disk
|
||||
db_path = cached_disk_dist_registry.kvstore.db_path
|
||||
new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
|
||||
backend_name = "kv_cached_new"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
new_registry = DiskDistributionRegistry(
|
||||
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
|
||||
)
|
||||
await new_registry.initialize()
|
||||
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == new_vector_db.identifier
|
||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||
result_vector_store = await new_registry.get("vector_store", "test_vector_store_2")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == new_vector_store.identifier
|
||||
assert result_vector_store.provider_id == new_vector_store.provider_id
|
||||
|
||||
|
||||
async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
||||
original_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
original_vector_store = VectorStore(
|
||||
identifier="test_vector_store_2",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_store_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_disk_dist_registry.register(original_vector_db)
|
||||
assert await cached_disk_dist_registry.register(original_vector_store)
|
||||
|
||||
duplicate_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
duplicate_vector_store = VectorStore(
|
||||
identifier="test_vector_store_2",
|
||||
embedding_model="different-model",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_store_2",
|
||||
provider_id="baz", # Same provider_id
|
||||
)
|
||||
|
||||
# Now we expect a ValueError to be raised for duplicate registration
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Provider 'baz' is already registered.*Unregister the existing provider first before registering it again.",
|
||||
ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store_2' already exists"
|
||||
):
|
||||
await cached_disk_dist_registry.register(duplicate_vector_db)
|
||||
await cached_disk_dist_registry.register(duplicate_vector_store)
|
||||
|
||||
# Verify the original registration is still intact
|
||||
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||
result = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
|
||||
assert result is not None
|
||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||
assert result.embedding_model == original_vector_store.embedding_model # Original values preserved
|
||||
|
||||
|
||||
async def test_get_all_objects(cached_disk_dist_registry):
|
||||
# Create multiple test banks
|
||||
# Create multiple test banks
|
||||
test_vector_dbs = [
|
||||
VectorDB(
|
||||
identifier=f"test_vector_db_{i}",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id=f"test_vector_db_{i}",
|
||||
test_vector_stores = [
|
||||
VectorStore(
|
||||
identifier=f"test_vector_store_{i}",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id=f"test_vector_store_{i}",
|
||||
provider_id=f"provider_{i}",
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
# Register all vector_dbs
|
||||
for vector_db in test_vector_dbs:
|
||||
await cached_disk_dist_registry.register(vector_db)
|
||||
# Register all vector_stores
|
||||
for vector_store in test_vector_stores:
|
||||
await cached_disk_dist_registry.register(vector_store)
|
||||
|
||||
# Test get_all retrieval
|
||||
all_results = await cached_disk_dist_registry.get_all()
|
||||
assert len(all_results) == 3
|
||||
|
||||
# Verify each vector_db was stored correctly
|
||||
for original_vector_db in test_vector_dbs:
|
||||
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
|
||||
assert len(matching_vector_dbs) == 1
|
||||
stored_vector_db = matching_vector_dbs[0]
|
||||
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
|
||||
assert stored_vector_db.provider_id == original_vector_db.provider_id
|
||||
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
||||
# Verify each vector_store was stored correctly
|
||||
for original_vector_store in test_vector_stores:
|
||||
matching_vector_stores = [v for v in all_results if v.identifier == original_vector_store.identifier]
|
||||
assert len(matching_vector_stores) == 1
|
||||
stored_vector_store = matching_vector_stores[0]
|
||||
assert stored_vector_store.embedding_model == original_vector_store.embedding_model
|
||||
assert stored_vector_store.provider_id == original_vector_store.provider_id
|
||||
assert stored_vector_store.embedding_dimension == original_vector_store.embedding_dimension
|
||||
|
||||
|
||||
async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
||||
valid_db = VectorDB(
|
||||
identifier="valid_vector_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="valid_vector_db",
|
||||
valid_db = VectorStore(
|
||||
identifier="valid_vector_store",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="valid_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
|
||||
KEY_FORMAT.format(type="vector_store", identifier="valid_vector_store"), valid_db.model_dump_json()
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
||||
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_store", identifier="corrupted_json"), "{not valid json")
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
||||
'{"type": "vector_db", "identifier": "missing_fields"}',
|
||||
KEY_FORMAT.format(type="vector_store", identifier="missing_fields"),
|
||||
'{"type": "vector_store", "identifier": "missing_fields"}',
|
||||
)
|
||||
|
||||
test_registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||
|
|
@ -199,32 +204,32 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
|||
|
||||
# Should have filtered out the invalid entries
|
||||
assert len(all_objects) == 1
|
||||
assert all_objects[0].identifier == "valid_vector_db"
|
||||
assert all_objects[0].identifier == "valid_vector_store"
|
||||
|
||||
# Check that the get method also handles errors correctly
|
||||
invalid_obj = await test_registry.get("vector_db", "corrupted_json")
|
||||
invalid_obj = await test_registry.get("vector_store", "corrupted_json")
|
||||
assert invalid_obj is None
|
||||
|
||||
invalid_obj = await test_registry.get("vector_db", "missing_fields")
|
||||
invalid_obj = await test_registry.get("vector_store", "missing_fields")
|
||||
assert invalid_obj is None
|
||||
|
||||
|
||||
async def test_cached_registry_error_handling(sqlite_kvstore):
|
||||
valid_db = VectorDB(
|
||||
valid_db = VectorStore(
|
||||
identifier="valid_cached_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="valid_cached_db",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
|
||||
KEY_FORMAT.format(type="vector_store", identifier="valid_cached_db"), valid_db.model_dump_json()
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
|
||||
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
|
||||
KEY_FORMAT.format(type="vector_store", identifier="invalid_cached_db"),
|
||||
'{"type": "vector_store", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
|
||||
)
|
||||
|
||||
cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
|
||||
|
|
@ -234,5 +239,102 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
|||
assert len(all_objects) == 1
|
||||
assert all_objects[0].identifier == "valid_cached_db"
|
||||
|
||||
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
|
||||
invalid_obj = await cached_registry.get("vector_store", "invalid_cached_db")
|
||||
assert invalid_obj is None
|
||||
|
||||
|
||||
async def test_double_registration_identical_objects(disk_dist_registry):
|
||||
"""Test that registering identical objects succeeds (idempotent)."""
|
||||
vector_store = VectorStoreWithOwner(
|
||||
identifier="test_vector_store",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
# First registration should succeed
|
||||
result1 = await disk_dist_registry.register(vector_store)
|
||||
assert result1 is True
|
||||
|
||||
# Second registration of identical object should also succeed (idempotent)
|
||||
result2 = await disk_dist_registry.register(vector_store)
|
||||
assert result2 is True
|
||||
|
||||
# Verify object exists and is unchanged
|
||||
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
|
||||
assert retrieved is not None
|
||||
assert retrieved.identifier == vector_store.identifier
|
||||
assert retrieved.embedding_model == vector_store.embedding_model
|
||||
|
||||
|
||||
async def test_double_registration_different_objects(disk_dist_registry):
|
||||
"""Test that registering different objects with same identifier fails."""
|
||||
vector_store1 = VectorStoreWithOwner(
|
||||
identifier="test_vector_store",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
vector_store2 = VectorStoreWithOwner(
|
||||
identifier="test_vector_store", # Same identifier
|
||||
embedding_model="different-model", # Different embedding model
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
# First registration should succeed
|
||||
result1 = await disk_dist_registry.register(vector_store1)
|
||||
assert result1 is True
|
||||
|
||||
# Second registration with different data should fail
|
||||
with pytest.raises(
|
||||
ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store' already exists"
|
||||
):
|
||||
await disk_dist_registry.register(vector_store2)
|
||||
|
||||
# Verify original object is unchanged
|
||||
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
|
||||
assert retrieved is not None
|
||||
assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value
|
||||
|
||||
|
||||
async def test_double_registration_with_cache(cached_disk_dist_registry):
|
||||
"""Test double registration behavior with caching enabled."""
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import ModelWithOwner
|
||||
|
||||
model1 = ModelWithOwner(
|
||||
identifier="test_model",
|
||||
provider_resource_id="test_model",
|
||||
provider_id="test-provider",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
model2 = ModelWithOwner(
|
||||
identifier="test_model", # Same identifier
|
||||
provider_resource_id="test_model",
|
||||
provider_id="test-provider",
|
||||
model_type=ModelType.embedding, # Different type
|
||||
)
|
||||
|
||||
# First registration should succeed and populate cache
|
||||
result1 = await cached_disk_dist_registry.register(model1)
|
||||
assert result1 is True
|
||||
|
||||
# Verify in cache
|
||||
cached_model = cached_disk_dist_registry.get_cached("model", "test_model")
|
||||
assert cached_model is not None
|
||||
assert cached_model.model_type == ModelType.llm
|
||||
|
||||
# Second registration with different data should fail
|
||||
with pytest.raises(ValueError, match="Object of type 'model' and identifier 'test_model' already exists"):
|
||||
await cached_disk_dist_registry.register(model2)
|
||||
|
||||
# Cache should still contain original model
|
||||
cached_model_after = cached_disk_dist_registry.get_cached("model", "test_model")
|
||||
assert cached_model_after is not None
|
||||
assert cached_model_after.model_type == ModelType.llm
|
||||
|
|
|
|||
|
|
@ -256,12 +256,12 @@ async def test_setup_with_access_policy(cached_disk_dist_registry):
|
|||
- permit:
|
||||
principal: user-2
|
||||
actions: [read]
|
||||
resource: model::model-1
|
||||
resource: model::test_provider/model-1
|
||||
description: user-2 has read access to model-1 only
|
||||
- permit:
|
||||
principal: user-3
|
||||
actions: [read]
|
||||
resource: model::model-2
|
||||
resource: model::test_provider/model-2
|
||||
description: user-3 has read access to model-2 only
|
||||
- forbid:
|
||||
actions: [create, read, delete]
|
||||
|
|
@ -285,21 +285,15 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
|||
"projects": ["foo", "bar"],
|
||||
},
|
||||
)
|
||||
await routing_table.register_model(
|
||||
"model-1", provider_model_id="test_provider/model-1", provider_id="test_provider"
|
||||
)
|
||||
await routing_table.register_model(
|
||||
"model-2", provider_model_id="test_provider/model-2", provider_id="test_provider"
|
||||
)
|
||||
await routing_table.register_model(
|
||||
"model-3", provider_model_id="test_provider/model-3", provider_id="test_provider"
|
||||
)
|
||||
model = await routing_table.get_model("model-1")
|
||||
assert model.identifier == "model-1"
|
||||
model = await routing_table.get_model("model-2")
|
||||
assert model.identifier == "model-2"
|
||||
model = await routing_table.get_model("model-3")
|
||||
assert model.identifier == "model-3"
|
||||
await routing_table.register_model("model-1", provider_model_id="model-1", provider_id="test_provider")
|
||||
await routing_table.register_model("model-2", provider_model_id="model-2", provider_id="test_provider")
|
||||
await routing_table.register_model("model-3", provider_model_id="model-3", provider_id="test_provider")
|
||||
model = await routing_table.get_model("test_provider/model-1")
|
||||
assert model.identifier == "test_provider/model-1"
|
||||
model = await routing_table.get_model("test_provider/model-2")
|
||||
assert model.identifier == "test_provider/model-2"
|
||||
model = await routing_table.get_model("test_provider/model-3")
|
||||
assert model.identifier == "test_provider/model-3"
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-2",
|
||||
|
|
@ -308,16 +302,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
|||
"projects": ["foo"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-1")
|
||||
assert model.identifier == "model-1"
|
||||
model = await routing_table.get_model("test_provider/model-1")
|
||||
assert model.identifier == "test_provider/model-1"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-2")
|
||||
await routing_table.get_model("test_provider/model-2")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
await routing_table.get_model("test_provider/model-3")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.register_model("model-4", provider_id="test_provider")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.unregister_model("model-1")
|
||||
await routing_table.unregister_model("test_provider/model-1")
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-3",
|
||||
|
|
@ -326,16 +320,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
|||
"projects": ["bar"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-2")
|
||||
assert model.identifier == "model-2"
|
||||
model = await routing_table.get_model("test_provider/model-2")
|
||||
assert model.identifier == "test_provider/model-2"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-1")
|
||||
await routing_table.get_model("test_provider/model-1")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
await routing_table.get_model("test_provider/model-3")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.register_model("model-5", provider_id="test_provider")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.unregister_model("model-2")
|
||||
await routing_table.unregister_model("test_provider/model-2")
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-1",
|
||||
|
|
@ -344,9 +338,9 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
|||
"projects": ["foo", "bar"],
|
||||
},
|
||||
)
|
||||
await routing_table.unregister_model("model-3")
|
||||
await routing_table.unregister_model("test_provider/model-3")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
await routing_table.get_model("test_provider/model-3")
|
||||
|
||||
|
||||
def test_permit_when():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import json
|
||||
import logging # allow-direct-logging
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
|
@ -26,6 +28,13 @@ from llama_stack.core.server.auth_providers import (
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def suppress_auth_errors(caplog):
|
||||
"""Suppress expected ERROR/WARNING logs for tests that deliberately trigger authentication errors"""
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth")
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth_providers")
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status_code, json_data):
|
||||
self.status_code = status_code
|
||||
|
|
@ -122,7 +131,7 @@ def mock_impls():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def scope_middleware_with_mocks(mock_auth_endpoint):
|
||||
def middleware_with_mocks(mock_auth_endpoint):
|
||||
"""Create AuthenticationMiddleware with mocked route implementations"""
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
|
|
@ -137,18 +146,20 @@ def scope_middleware_with_mocks(mock_auth_endpoint):
|
|||
# Mock the route_impls to simulate finding routes with required scopes
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read")
|
||||
|
||||
public_webmethod = WebMethod(route="/test/public", method="GET")
|
||||
routes = {
|
||||
("POST", "/test/scoped"): WebMethod(route="/test/scoped", method="POST", required_scope="test.read"),
|
||||
("GET", "/test/public"): WebMethod(route="/test/public", method="GET"),
|
||||
("GET", "/health"): WebMethod(route="/health", method="GET", require_authentication=False),
|
||||
("GET", "/version"): WebMethod(route="/version", method="GET", require_authentication=False),
|
||||
("GET", "/models/list"): WebMethod(route="/models/list", method="GET", require_authentication=True),
|
||||
}
|
||||
|
||||
# Mock the route finding logic
|
||||
def mock_find_matching_route(method, path, route_impls):
|
||||
if method == "POST" and path == "/test/scoped":
|
||||
return None, {}, "/test/scoped", scoped_webmethod
|
||||
elif method == "GET" and path == "/test/public":
|
||||
return None, {}, "/test/public", public_webmethod
|
||||
else:
|
||||
raise ValueError("No matching route")
|
||||
webmethod = routes.get((method, path))
|
||||
if webmethod:
|
||||
return None, {}, path, webmethod
|
||||
raise ValueError("No matching route")
|
||||
|
||||
import llama_stack.core.server.auth
|
||||
|
||||
|
|
@ -234,20 +245,20 @@ def test_valid_http_authentication(http_client, valid_api_key):
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_failure)
|
||||
def test_invalid_http_authentication(http_client, invalid_api_key):
|
||||
def test_invalid_http_authentication(http_client, invalid_api_key, suppress_auth_errors):
|
||||
response = http_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Authentication failed" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_exception)
|
||||
def test_http_auth_service_error(http_client, valid_api_key):
|
||||
def test_http_auth_service_error(http_client, valid_api_key, suppress_auth_errors):
|
||||
response = http_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Authentication service error" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoint):
|
||||
def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoint, suppress_auth_errors):
|
||||
with patch("httpx.AsyncClient.post") as mock_post:
|
||||
mock_response = MockResponse(200, {"message": "Authentication successful"})
|
||||
mock_post.return_value = mock_response
|
||||
|
|
@ -372,7 +383,7 @@ async def mock_jwks_response(*args, **kwargs):
|
|||
|
||||
@pytest.fixture
|
||||
def jwt_token_valid():
|
||||
from jose import jwt
|
||||
import jwt
|
||||
|
||||
return jwt.encode(
|
||||
{
|
||||
|
|
@ -387,15 +398,37 @@ def jwt_token_valid():
|
|||
)
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
|
||||
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
|
||||
@pytest.fixture
|
||||
def mock_jwks_urlopen():
|
||||
"""Mock urllib.request.urlopen for PyJWKClient JWKS requests."""
|
||||
with patch("urllib.request.urlopen") as mock_urlopen:
|
||||
# Mock the JWKS response for PyJWKClient
|
||||
mock_response = Mock()
|
||||
mock_response.read.return_value = json.dumps(
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kid": "1234567890",
|
||||
"kty": "oct",
|
||||
"alg": "HS256",
|
||||
"use": "sig",
|
||||
"k": base64.b64encode(b"foobarbaz").decode(),
|
||||
}
|
||||
]
|
||||
}
|
||||
).encode()
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
yield mock_urlopen
|
||||
|
||||
|
||||
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
|
||||
def test_invalid_oauth2_authentication(oauth2_client, invalid_token):
|
||||
def test_invalid_oauth2_authentication(oauth2_client, invalid_token, suppress_auth_errors):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid JWT token" in response.json()["error"]["message"]
|
||||
|
|
@ -440,13 +473,12 @@ def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token):
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
|
||||
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
|
||||
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid, suppress_auth_errors):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
|
||||
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
|
||||
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid, mock_jwks_urlopen):
|
||||
response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
|
@ -492,6 +524,82 @@ def test_get_attributes_from_claims():
|
|||
assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
|
||||
assert attributes["namespaces"] == ["my-tenant"]
|
||||
|
||||
# Test nested claims with dot notation (e.g., Keycloak resource_access structure)
|
||||
claims = {
|
||||
"sub": "user123",
|
||||
"resource_access": {"llamastack": {"roles": ["inference_max", "admin"]}, "other-client": {"roles": ["viewer"]}},
|
||||
"realm_access": {"roles": ["offline_access", "uma_authorization"]},
|
||||
}
|
||||
attributes = get_attributes_from_claims(
|
||||
claims, {"resource_access.llamastack.roles": "roles", "realm_access.roles": "realm_roles"}
|
||||
)
|
||||
assert set(attributes["roles"]) == {"inference_max", "admin"}
|
||||
assert set(attributes["realm_roles"]) == {"offline_access", "uma_authorization"}
|
||||
|
||||
# Test that dot notation takes precedence over literal keys with dots
|
||||
claims = {
|
||||
"my.dotted.key": "literal-value",
|
||||
"my": {"dotted": {"key": "nested-value"}},
|
||||
}
|
||||
attributes = get_attributes_from_claims(claims, {"my.dotted.key": "test"})
|
||||
assert attributes["test"] == ["nested-value"]
|
||||
|
||||
# Test that literal key works when nested traversal doesn't exist
|
||||
claims = {
|
||||
"my.dotted.key": "literal-value",
|
||||
}
|
||||
attributes = get_attributes_from_claims(claims, {"my.dotted.key": "test"})
|
||||
assert attributes["test"] == ["literal-value"]
|
||||
|
||||
# Test missing nested paths are handled gracefully
|
||||
claims = {
|
||||
"sub": "user123",
|
||||
"resource_access": {"other-client": {"roles": ["viewer"]}},
|
||||
}
|
||||
attributes = get_attributes_from_claims(
|
||||
claims,
|
||||
{
|
||||
"resource_access.llamastack.roles": "roles", # Missing nested path
|
||||
"resource_access.missing.key": "missing_attr", # Missing nested path
|
||||
"completely.missing.path": "another_missing", # Completely missing
|
||||
"sub": "username", # Existing path
|
||||
},
|
||||
)
|
||||
# Only the existing claim should be in attributes
|
||||
assert attributes["username"] == ["user123"]
|
||||
assert "roles" not in attributes
|
||||
assert "missing_attr" not in attributes
|
||||
assert "another_missing" not in attributes
|
||||
|
||||
# Test mixture of flat and nested claims paths
|
||||
claims = {
|
||||
"sub": "user456",
|
||||
"flat_key": "flat-value",
|
||||
"scope": "read write admin",
|
||||
"resource_access": {"app1": {"roles": ["role1", "role2"]}, "app2": {"roles": ["role3"]}},
|
||||
"groups": ["group1", "group2"],
|
||||
"metadata": {"tenant": "tenant1", "region": "us-west"},
|
||||
}
|
||||
attributes = get_attributes_from_claims(
|
||||
claims,
|
||||
{
|
||||
"sub": "user_id", # Flat string
|
||||
"scope": "permissions", # Flat string with spaces
|
||||
"groups": "teams", # Flat list
|
||||
"resource_access.app1.roles": "app1_roles", # Nested list
|
||||
"resource_access.app2.roles": "app2_roles", # Nested list
|
||||
"metadata.tenant": "tenant", # Nested string
|
||||
"metadata.region": "region", # Nested string
|
||||
},
|
||||
)
|
||||
assert attributes["user_id"] == ["user456"]
|
||||
assert set(attributes["permissions"]) == {"read", "write", "admin"}
|
||||
assert set(attributes["teams"]) == {"group1", "group2"}
|
||||
assert set(attributes["app1_roles"]) == {"role1", "role2"}
|
||||
assert attributes["app2_roles"] == ["role3"]
|
||||
assert attributes["tenant"] == ["tenant1"]
|
||||
assert attributes["region"] == ["us-west"]
|
||||
|
||||
|
||||
# TODO: add more tests for oauth2 token provider
|
||||
|
||||
|
|
@ -626,21 +734,21 @@ def test_valid_introspection_authentication(introspection_client, valid_api_key)
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_introspection_inactive)
|
||||
def test_inactive_introspection_authentication(introspection_client, invalid_api_key):
|
||||
def test_inactive_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
|
||||
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Token not active" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_introspection_invalid)
|
||||
def test_invalid_introspection_authentication(introspection_client, invalid_api_key):
|
||||
def test_invalid_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
|
||||
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Not JSON" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_introspection_failed)
|
||||
def test_failed_introspection_authentication(introspection_client, invalid_api_key):
|
||||
def test_failed_introspection_authentication(introspection_client, invalid_api_key, suppress_auth_errors):
|
||||
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Token introspection failed: 500" in response.json()["error"]["message"]
|
||||
|
|
@ -659,9 +767,9 @@ def test_valid_introspection_with_custom_mapping_authentication(
|
|||
|
||||
# Scope-based authorization tests
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope)
|
||||
async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key):
|
||||
async def test_scope_authorization_success(middleware_with_mocks, valid_api_key):
|
||||
"""Test that user with required scope can access protected endpoint"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
|
@ -680,9 +788,9 @@ async def test_scope_authorization_success(scope_middleware_with_mocks, valid_ap
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||
async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key):
|
||||
async def test_scope_authorization_denied(middleware_with_mocks, valid_api_key):
|
||||
"""Test that user without required scope gets 403 access denied"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
|
@ -710,9 +818,9 @@ async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||
async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key):
|
||||
async def test_public_endpoint_no_scope_required(middleware_with_mocks, valid_api_key):
|
||||
"""Test that public endpoints work without specific scopes"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
|
@ -730,9 +838,9 @@ async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, va
|
|||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks):
|
||||
async def test_scope_authorization_no_auth_disabled(middleware_with_mocks):
|
||||
"""Test that when auth is disabled (no user), scope checks are bypassed"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
|
@ -857,20 +965,22 @@ def test_valid_kubernetes_auth_authentication(kubernetes_auth_client, valid_toke
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_failure)
|
||||
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token):
|
||||
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token, suppress_auth_errors):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_http_error)
|
||||
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token):
|
||||
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token, suppress_auth_errors):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Token validation failed" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mock_kubernetes_api_server):
|
||||
def test_kubernetes_auth_request_payload(
|
||||
kubernetes_auth_client, valid_token, mock_kubernetes_api_server, suppress_auth_errors
|
||||
):
|
||||
with patch("httpx.AsyncClient.post") as mock_post:
|
||||
mock_response = MockResponse(
|
||||
200,
|
||||
|
|
@ -907,3 +1017,41 @@ def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mo
|
|||
request_body = call_args[1]["json"]
|
||||
assert request_body["apiVersion"] == "authentication.k8s.io/v1"
|
||||
assert request_body["kind"] == "SelfSubjectReview"
|
||||
|
||||
|
||||
async def test_unauthenticated_endpoint_access_health(middleware_with_mocks):
|
||||
"""Test that /health endpoints can be accessed without authentication"""
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
|
||||
# Test request to /health without auth header (level prefix v1 is added by router)
|
||||
scope = {"type": "http", "path": "/health", "headers": [], "method": "GET"}
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
# Should allow the request to proceed without authentication
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Verify that the request was passed to the app
|
||||
mock_app.assert_called_once_with(scope, receive, send)
|
||||
|
||||
# Verify that no error response was sent
|
||||
assert not any(call[0][0].get("status") == 401 for call in send.call_args_list)
|
||||
|
||||
|
||||
async def test_unauthenticated_endpoint_denied_for_other_paths(middleware_with_mocks):
|
||||
"""Test that endpoints other than /health and /version require authentication"""
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
|
||||
# Test request to /models/list without auth header
|
||||
scope = {"type": "http", "path": "/models/list", "headers": [], "method": "GET"}
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
# Should return 401 error
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Verify that the app was NOT called
|
||||
mock_app.assert_not_called()
|
||||
|
||||
# Verify that a 401 error response was sent
|
||||
assert any(call[0][0].get("status") == 401 for call in send.call_args_list)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging # allow-direct-logging
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
|
|
@ -15,6 +16,13 @@ from llama_stack.core.datatypes import AuthenticationConfig, AuthProviderType, G
|
|||
from llama_stack.core.server.auth import AuthenticationMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def suppress_auth_errors(caplog):
|
||||
"""Suppress expected ERROR logs for tests that deliberately trigger authentication errors"""
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth")
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.auth_providers")
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status_code, json_data):
|
||||
self.status_code = status_code
|
||||
|
|
@ -119,7 +127,7 @@ def test_authenticated_endpoint_with_valid_github_token(mock_client_class, githu
|
|||
|
||||
|
||||
@patch("llama_stack.core.server.auth_providers.httpx.AsyncClient")
|
||||
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
|
||||
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client, suppress_auth_errors):
|
||||
"""Test accessing protected endpoint with invalid GitHub token"""
|
||||
# Mock the GitHub API to return 401 Unauthorized
|
||||
mock_client = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging # allow-direct-logging
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -11,7 +14,14 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||
|
||||
from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod
|
||||
from llama_stack.core.server.quota import QuotaMiddleware
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def suppress_quota_warnings(caplog):
|
||||
"""Suppress expected WARNING logs for SQLite backend and quota exceeded"""
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.quota")
|
||||
|
||||
|
||||
class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
||||
|
|
@ -29,8 +39,10 @@ class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
|
||||
def build_quota_config(db_path) -> QuotaConfig:
|
||||
backend_name = f"kv_quota_{uuid4().hex}"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=str(db_path))})
|
||||
return QuotaConfig(
|
||||
kvstore=SqliteKVStoreConfig(db_path=str(db_path)),
|
||||
kvstore=KVStoreReference(backend=backend_name, namespace="quota"),
|
||||
anonymous_max_requests=1,
|
||||
authenticated_max_requests=2,
|
||||
period=QuotaPeriod.DAY,
|
||||
|
|
@ -65,13 +77,13 @@ def auth_app(tmp_path, request):
|
|||
return app
|
||||
|
||||
|
||||
def test_authenticated_quota_allows_up_to_limit(auth_app):
|
||||
def test_authenticated_quota_allows_up_to_limit(auth_app, suppress_quota_warnings):
|
||||
client = TestClient(auth_app)
|
||||
assert client.get("/test").status_code == 200
|
||||
assert client.get("/test").status_code == 200
|
||||
|
||||
|
||||
def test_authenticated_quota_blocks_after_limit(auth_app):
|
||||
def test_authenticated_quota_blocks_after_limit(auth_app, suppress_quota_warnings):
|
||||
client = TestClient(auth_app)
|
||||
client.get("/test")
|
||||
client.get("/test")
|
||||
|
|
@ -80,7 +92,7 @@ def test_authenticated_quota_blocks_after_limit(auth_app):
|
|||
assert resp.json()["error"]["message"] == "Quota exceeded"
|
||||
|
||||
|
||||
def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
||||
def test_anonymous_quota_allows_up_to_limit(tmp_path, request, suppress_quota_warnings):
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
|
|
@ -102,7 +114,7 @@ def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
|||
assert client.get("/test").status_code == 200
|
||||
|
||||
|
||||
def test_anonymous_quota_blocks_after_limit(tmp_path, request):
|
||||
def test_anonymous_quota_blocks_after_limit(tmp_path, request, suppress_quota_warnings):
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
|
|
|
|||
|
|
@ -12,15 +12,22 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.core.datatypes import (
|
||||
Api,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.resolver import resolve_impls
|
||||
from llama_stack.core.routers.inference import InferenceRouter
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
|
||||
|
|
@ -65,6 +72,35 @@ class SampleImpl:
|
|||
pass
|
||||
|
||||
|
||||
def make_run_config(**overrides) -> StackRunConfig:
|
||||
storage = overrides.pop(
|
||||
"storage",
|
||||
StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path=":memory:"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
|
||||
conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
|
||||
),
|
||||
),
|
||||
)
|
||||
register_kvstore_backends({name: cfg for name, cfg in storage.backends.items() if cfg.type.value.startswith("kv_")})
|
||||
register_sqlstore_backends(
|
||||
{name: cfg for name, cfg in storage.backends.items() if cfg.type.value.startswith("sql_")}
|
||||
)
|
||||
defaults = dict(
|
||||
image_name="test_image",
|
||||
apis=[],
|
||||
providers={},
|
||||
storage=storage,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return StackRunConfig(**defaults)
|
||||
|
||||
|
||||
async def test_resolve_impls_basic():
|
||||
# Create a real provider spec
|
||||
provider_spec = InlineProviderSpec(
|
||||
|
|
@ -78,7 +114,7 @@ async def test_resolve_impls_basic():
|
|||
# Create provider registry with our provider
|
||||
provider_registry = {Api.inference: {provider_spec.provider_type: provider_spec}}
|
||||
|
||||
run_config = StackRunConfig(
|
||||
run_config = make_run_config(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class TestTranslateException:
|
|||
self.identifier = identifier
|
||||
self.owner = owner
|
||||
|
||||
resource = MockResource("vector_db", "test-db")
|
||||
resource = MockResource("vector_store", "test-db")
|
||||
|
||||
exc = AccessDeniedError("create", resource, user)
|
||||
result = translate_exception(exc)
|
||||
|
|
@ -49,7 +49,7 @@ class TestTranslateException:
|
|||
assert isinstance(result, HTTPException)
|
||||
assert result.status_code == 403
|
||||
assert "test-user" in result.detail
|
||||
assert "vector_db::test-db" in result.detail
|
||||
assert "vector_store::test-db" in result.detail
|
||||
assert "create" in result.detail
|
||||
assert "roles=['user']" in result.detail
|
||||
assert "teams=['dev']" in result.detail
|
||||
|
|
|
|||
|
|
@ -5,12 +5,21 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging # allow-direct-logging
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.core.server.server import create_dynamic_typed_route, create_sse_event, sse_generator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def suppress_sse_errors(caplog):
|
||||
"""Suppress expected ERROR logs for tests that deliberately trigger SSE errors"""
|
||||
caplog.set_level(logging.CRITICAL, logger="llama_stack.core.server.server")
|
||||
|
||||
|
||||
async def test_sse_generator_basic():
|
||||
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
||||
async def async_event_gen():
|
||||
|
|
@ -70,7 +79,7 @@ async def test_sse_generator_client_disconnected_before_response_starts():
|
|||
assert len(seen_events) == 0
|
||||
|
||||
|
||||
async def test_sse_generator_error_before_response_starts():
|
||||
async def test_sse_generator_error_before_response_starts(suppress_sse_errors):
|
||||
# Raise an error before the response starts
|
||||
async def async_event_gen():
|
||||
raise Exception("Test error")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -16,8 +15,16 @@ from llama_stack.apis.inference import (
|
|||
OpenAIUserMessageParam,
|
||||
Order,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_backends(tmp_path):
|
||||
"""Register SQL store backends for testing."""
|
||||
db_path = str(tmp_path / "test.db")
|
||||
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=db_path)})
|
||||
|
||||
|
||||
def create_test_chat_completion(
|
||||
|
|
@ -44,167 +51,162 @@ def create_test_chat_completion(
|
|||
|
||||
async def test_inference_store_pagination_basic():
|
||||
"""Test basic pagination functionality."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different timestamps
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("zebra-task", base_time + 1),
|
||||
("apple-job", base_time + 2),
|
||||
("moon-work", base_time + 3),
|
||||
("banana-run", base_time + 4),
|
||||
("car-exec", base_time + 5),
|
||||
]
|
||||
# Create test data with different timestamps
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("zebra-task", base_time + 1),
|
||||
("apple-job", base_time + 2),
|
||||
("moon-work", base_time + 3),
|
||||
("banana-run", base_time + 4),
|
||||
("car-exec", base_time + 5),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test 1: First page with limit=2, descending order (default)
|
||||
result = await store.list_chat_completions(limit=2, order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "car-exec" # Most recent first
|
||||
assert result.data[1].id == "banana-run"
|
||||
assert result.has_more is True
|
||||
assert result.last_id == "banana-run"
|
||||
# Test 1: First page with limit=2, descending order (default)
|
||||
result = await store.list_chat_completions(limit=2, order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "car-exec" # Most recent first
|
||||
assert result.data[1].id == "banana-run"
|
||||
assert result.has_more is True
|
||||
assert result.last_id == "banana-run"
|
||||
|
||||
# Test 2: Second page using 'after' parameter
|
||||
result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
|
||||
assert len(result2.data) == 2
|
||||
assert result2.data[0].id == "moon-work"
|
||||
assert result2.data[1].id == "apple-job"
|
||||
assert result2.has_more is True
|
||||
# Test 2: Second page using 'after' parameter
|
||||
result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
|
||||
assert len(result2.data) == 2
|
||||
assert result2.data[0].id == "moon-work"
|
||||
assert result2.data[1].id == "apple-job"
|
||||
assert result2.has_more is True
|
||||
|
||||
# Test 3: Final page
|
||||
result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
|
||||
assert len(result3.data) == 1
|
||||
assert result3.data[0].id == "zebra-task"
|
||||
assert result3.has_more is False
|
||||
# Test 3: Final page
|
||||
result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
|
||||
assert len(result3.data) == 1
|
||||
assert result3.data[0].id == "zebra-task"
|
||||
assert result3.has_more is False
|
||||
|
||||
|
||||
async def test_inference_store_pagination_ascending():
|
||||
"""Test pagination with ascending order."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("delta-item", base_time + 1),
|
||||
("charlie-task", base_time + 2),
|
||||
("alpha-work", base_time + 3),
|
||||
]
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("delta-item", base_time + 1),
|
||||
("charlie-task", base_time + 2),
|
||||
("alpha-work", base_time + 3),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test ascending order pagination
|
||||
result = await store.list_chat_completions(limit=1, order=Order.asc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "delta-item" # Oldest first
|
||||
assert result.has_more is True
|
||||
# Test ascending order pagination
|
||||
result = await store.list_chat_completions(limit=1, order=Order.asc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "delta-item" # Oldest first
|
||||
assert result.has_more is True
|
||||
|
||||
# Second page with ascending order
|
||||
result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "charlie-task"
|
||||
assert result2.has_more is True
|
||||
# Second page with ascending order
|
||||
result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "charlie-task"
|
||||
assert result2.has_more is True
|
||||
|
||||
|
||||
async def test_inference_store_pagination_with_model_filter():
|
||||
"""Test pagination combined with model filtering."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different models
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("xyz-task", base_time + 1, "model-a"),
|
||||
("def-work", base_time + 2, "model-b"),
|
||||
("pqr-job", base_time + 3, "model-a"),
|
||||
("abc-run", base_time + 4, "model-b"),
|
||||
]
|
||||
# Create test data with different models
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("xyz-task", base_time + 1, "model-a"),
|
||||
("def-work", base_time + 2, "model-b"),
|
||||
("pqr-job", base_time + 3, "model-a"),
|
||||
("abc-run", base_time + 4, "model-b"),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp, model in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp, model)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp, model in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp, model)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test pagination with model filter
|
||||
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "pqr-job" # Most recent model-a
|
||||
assert result.data[0].model == "model-a"
|
||||
assert result.has_more is True
|
||||
# Test pagination with model filter
|
||||
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "pqr-job" # Most recent model-a
|
||||
assert result.data[0].model == "model-a"
|
||||
assert result.has_more is True
|
||||
|
||||
# Second page with model filter
|
||||
result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "xyz-task"
|
||||
assert result2.data[0].model == "model-a"
|
||||
assert result2.has_more is False
|
||||
# Second page with model filter
|
||||
result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "xyz-task"
|
||||
assert result2.data[0].model == "model-a"
|
||||
assert result2.has_more is False
|
||||
|
||||
|
||||
async def test_inference_store_pagination_invalid_after():
|
||||
"""Test error handling for invalid 'after' parameter."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Try to paginate with non-existent ID
|
||||
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
|
||||
await store.list_chat_completions(after="non-existent", limit=2)
|
||||
# Try to paginate with non-existent ID
|
||||
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
|
||||
await store.list_chat_completions(after="non-existent", limit=2)
|
||||
|
||||
|
||||
async def test_inference_store_pagination_no_limit():
|
||||
"""Test pagination behavior when no limit is specified."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
|
||||
store = InferenceStore(reference, policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("omega-first", base_time + 1),
|
||||
("beta-second", base_time + 2),
|
||||
]
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("omega-first", base_time + 1),
|
||||
("beta-second", base_time + 2),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test without limit
|
||||
result = await store.list_chat_completions(order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "beta-second" # Most recent first
|
||||
assert result.data[1].id == "omega-first"
|
||||
assert result.has_more is False
|
||||
# Test without limit
|
||||
result = await store.list_chat_completions(order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "beta-second" # Most recent first
|
||||
assert result.data[1].id == "omega-first"
|
||||
assert result.has_more is False
|
||||
|
|
|
|||
30
tests/unit/utils/kvstore/test_sqlite_memory.py
Normal file
30
tests/unit/utils/kvstore/test_sqlite_memory.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
|
||||
|
||||
|
||||
async def test_memory_kvstore_persistence_behavior():
|
||||
"""Test that :memory: database doesn't persist across instances."""
|
||||
config = SqliteKVStoreConfig(db_path=":memory:")
|
||||
|
||||
# First instance
|
||||
store1 = SqliteKVStoreImpl(config)
|
||||
await store1.initialize()
|
||||
await store1.set("persist_test", "should_not_persist")
|
||||
await store1.shutdown()
|
||||
|
||||
# Second instance with same config
|
||||
store2 = SqliteKVStoreImpl(config)
|
||||
await store2.initialize()
|
||||
|
||||
# Data should not be present
|
||||
result = await store2.get("persist_test")
|
||||
assert result is None
|
||||
|
||||
await store2.shutdown()
|
||||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import time
|
||||
from tempfile import TemporaryDirectory
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -15,8 +16,18 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObject,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
def build_store(db_path: str, policy: list | None = None) -> ResponsesStore:
|
||||
backend_name = f"sql_responses_{uuid4().hex}"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path)})
|
||||
return ResponsesStore(
|
||||
ResponsesStoreReference(backend=backend_name, table_name="responses"),
|
||||
policy=policy or [],
|
||||
)
|
||||
|
||||
|
||||
def create_test_response_object(
|
||||
|
|
@ -54,7 +65,7 @@ async def test_responses_store_pagination_basic():
|
|||
"""Test basic pagination functionality for responses store."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different timestamps
|
||||
|
|
@ -103,7 +114,7 @@ async def test_responses_store_pagination_ascending():
|
|||
"""Test pagination with ascending order."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
|
@ -141,7 +152,7 @@ async def test_responses_store_pagination_with_model_filter():
|
|||
"""Test pagination combined with model filtering."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different models
|
||||
|
|
@ -182,7 +193,7 @@ async def test_responses_store_pagination_invalid_after():
|
|||
"""Test error handling for invalid 'after' parameter."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Try to paginate with non-existent ID
|
||||
|
|
@ -194,7 +205,7 @@ async def test_responses_store_pagination_no_limit():
|
|||
"""Test pagination behavior when no limit is specified."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
|
@ -226,7 +237,7 @@ async def test_responses_store_get_response_object():
|
|||
"""Test retrieving a single response object."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response
|
||||
|
|
@ -254,7 +265,7 @@ async def test_responses_store_input_items_pagination():
|
|||
"""Test pagination functionality for input items."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response with many inputs with explicit IDs
|
||||
|
|
@ -335,7 +346,7 @@ async def test_responses_store_input_items_before_pagination():
|
|||
"""Test before pagination functionality for input items."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
store = build_store(db_path)
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response with many inputs with explicit IDs
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue