Merge remote-tracking branch 'origin/main' into stores

Ashwin Bharambe 2025-10-13 11:07:11 -07:00
commit b72154ce5e
1161 changed files with 609896 additions and 42960 deletions


@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, Mock
import pytest
from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody
from llama_stack.core.routers.vector_io import VectorIORouter
async def test_single_provider_auto_selection():
# provider_id automatically selected during vector store create() when only one provider available
mock_routing_table = Mock()
mock_routing_table.impls_by_provider_id = {"inline::faiss": "mock_provider"}
mock_routing_table.get_all_with_type = AsyncMock(
return_value=[
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
]
)
mock_routing_table.register_vector_db = AsyncMock(
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
)
mock_routing_table.get_provider_impl = AsyncMock(
return_value=Mock(openai_create_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
)
router = VectorIORouter(mock_routing_table)
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
)
result = await router.openai_create_vector_store(request)
assert result.id == "vs_123"
async def test_create_vector_stores_multiple_providers_missing_provider_id_error():
# if multiple providers are available, vector store create will error without provider_id
mock_routing_table = Mock()
mock_routing_table.impls_by_provider_id = {
"inline::faiss": "mock_provider_1",
"inline::sqlite-vec": "mock_provider_2",
}
mock_routing_table.get_all_with_type = AsyncMock(
return_value=[
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
]
)
router = VectorIORouter(mock_routing_table)
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
)
with pytest.raises(ValueError, match="Multiple vector_io providers available"):
await router.openai_create_vector_store(request)
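# Illustrative sketch (hypothetical helper, not the router's actual code): the two
# tests above exercise a rule of roughly "auto-select the sole registered vector_io
# provider, otherwise require an explicit provider_id".
def _select_provider_id_sketch(impls_by_provider_id: dict) -> str:
    if len(impls_by_provider_id) == 1:
        # Exactly one provider registered: use it without asking the caller.
        return next(iter(impls_by_provider_id))
    raise ValueError(f"Multiple vector_io providers available: {sorted(impls_by_provider_id)}")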


@@ -17,7 +17,6 @@ from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.core.datatypes import RegistryEntrySource
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable
@@ -25,7 +24,6 @@ from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.core.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.core.routing_tables.shields import ShieldsRoutingTable
from llama_stack.core.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
class Impl:
@@ -146,31 +144,6 @@ class ToolGroupsImpl(Impl):
)
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def openai_create_vector_store(self, **kwargs):
import time
import uuid
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
return VectorStoreObject(
id=vector_store_id,
name=kwargs.get("name", vector_store_id),
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@@ -201,6 +174,12 @@ async def test_models_routing_table(cached_disk_dist_registry):
non_existent = await table.get_object_by_identifier("model", "non-existent-model")
assert non_existent is None
# Test has_model
assert await table.has_model("test_provider/test-model")
assert await table.has_model("test_provider/test-model-2")
assert not await table.has_model("non-existent-model")
assert not await table.has_model("test_provider/non-existent-model")
await table.unregister_model(model_id="test_provider/test-model")
await table.unregister_model(model_id="test_provider/test-model-2")
@@ -257,40 +236,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
await table.unregister_shield(identifier="non-existent")
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert vdb1.identifier in vector_db_ids
assert vdb2.identifier in vector_db_ids
# Verify they have UUID-based identifiers
assert vdb1.identifier.startswith("vs_")
assert vdb2.identifier.startswith("vs_")
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@@ -348,6 +293,111 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
assert len(scoring_functions_list_after_deletion.data) == 0
async def test_double_registration_models_positive(cached_disk_dist_registry):
"""Test that registering the same model twice with identical data succeeds."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register a model
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
# Register the exact same model again - should succeed (idempotent)
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
# Verify only one model exists
models = await table.list_models()
assert len(models.data) == 1
assert models.data[0].identifier == "test_provider/test-model"
async def test_double_registration_models_negative(cached_disk_dist_registry):
"""Test that registering the same model with different data fails."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register a model with specific metadata
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
# Try to register the same model with different metadata - should fail
with pytest.raises(
ValueError, match="Object of type 'model' and identifier 'test_provider/test-model' already exists"
):
await table.register_model(
model_id="test-model", provider_id="test_provider", metadata={"param1": "different_value"}
)
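# Illustrative sketch (hypothetical, condensing the registration behavior asserted
# above and in the scoring-function tests below): re-registration is idempotent only
# when the new object is identical to the stored one; any difference raises ValueError.
def _register_sketch(registry: dict, identifier: str, obj: dict) -> None:
    existing = registry.get(identifier)
    if existing is not None and existing != obj:
        raise ValueError(f"Object with identifier '{identifier}' already exists")
    registry[identifier] = obj  # first registration, or an identical re-registration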
async def test_double_registration_scoring_functions_positive(cached_disk_dist_registry):
"""Test that registering the same scoring function twice with identical data succeeds."""
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register a scoring function
await table.register_scoring_function(
scoring_fn_id="test-scoring-fn",
provider_id="test_provider",
description="Test scoring function",
return_type=NumberType(),
)
# Register the exact same scoring function again - should succeed (idempotent)
await table.register_scoring_function(
scoring_fn_id="test-scoring-fn",
provider_id="test_provider",
description="Test scoring function",
return_type=NumberType(),
)
# Verify only one scoring function exists
scoring_functions = await table.list_scoring_functions()
assert len(scoring_functions.data) == 1
assert scoring_functions.data[0].identifier == "test-scoring-fn"
async def test_double_registration_scoring_functions_negative(cached_disk_dist_registry):
"""Test that registering the same scoring function with different data fails."""
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register a scoring function
await table.register_scoring_function(
scoring_fn_id="test-scoring-fn",
provider_id="test_provider",
description="Test scoring function",
return_type=NumberType(),
)
# Try to register the same scoring function with different description - should fail
with pytest.raises(
ValueError, match="Object of type 'scoring_function' and identifier 'test-scoring-fn' already exists"
):
await table.register_scoring_function(
scoring_fn_id="test-scoring-fn",
provider_id="test_provider",
description="Different description",
return_type=NumberType(),
)
async def test_double_registration_different_providers(cached_disk_dist_registry):
"""Test that registering objects with same ID but different providers succeeds."""
impl1 = InferenceImpl()
impl2 = InferenceImpl()
table = ModelsRoutingTable({"provider1": impl1, "provider2": impl2}, cached_disk_dist_registry, {})
await table.initialize()
# Register same model ID with different providers - should succeed
await table.register_model(model_id="shared-model", provider_id="provider1")
await table.register_model(model_id="shared-model", provider_id="provider2")
# Verify both models exist with different identifiers
models = await table.list_models()
assert len(models.data) == 2
model_ids = {m.identifier for m in models.data}
assert "provider1/shared-model" in model_ids
assert "provider2/shared-model" in model_ids
async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
await table.initialize()


@@ -1,381 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Unit tests for the vector_dbs routing table
import time
import uuid
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.apis.vector_io.vector_io import (
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileCounts,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.access_control.datatypes import AccessRule, Scope
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
self.vector_stores = {}
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def openai_retrieve_vector_store(self, vector_store_id):
return VectorStoreObject(
id=vector_store_id,
name="Test Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_update_vector_store(self, vector_store_id, **kwargs):
return VectorStoreObject(
id=vector_store_id,
name="Updated Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_delete_vector_store(self, vector_store_id):
return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
return VectorStoreSearchResponsePage(
object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
)
async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
return [
VectorStoreFileObject(
id="1",
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
]
async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
return VectorStoreFileContentsResponse(
file_id=file_id,
filename="Sample File name",
attributes={"key": "value"},
content=[VectorStoreContent(type="text", text="Sample content")],
)
async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
async def openai_create_vector_store(
self,
name=None,
embedding_model=None,
embedding_dimension=None,
provider_id=None,
provider_vector_db_id=None,
**kwargs,
):
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
vector_store = VectorStoreObject(
id=vector_store_id,
name=name or vector_store_id,
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
self.vector_stores[vector_store_id] = vector_store
return vector_store
async def openai_list_vector_stores(self, **kwargs):
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
return VectorStoreListResponse(
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
)
async def test_vectordbs_routing_table(cached_disk_dist_registry):
n = 10
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == len(vdb_dict)
vector_db_ids = {v.identifier for v in vector_dbs.data}
for k in vdb_dict:
assert vdb_dict[k].identifier in vector_db_ids
for k in vdb_dict:
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
n = 10
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
vector_db_ids = {v.identifier for v in vector_dbs.data}
vector_stores = await impl.openai_list_vector_stores()
vector_store_ids = {v.id for v in vector_stores.data}
assert vector_db_ids == vector_store_ids, (
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
)
for vector_store in vector_stores.data:
vector_db = await table.get_vector_db(vector_store.id)
assert vector_store.name == vector_db.vector_db_name, (
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
)
for vector_db_id in vector_db_ids:
await table.unregister_vector_db(vector_db_id)
assert len((await table.list_vector_dbs()).data) == 0
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
user_provided_id = "my-custom-vector-db"
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
vector_stores = await impl.openai_list_vector_stores()
assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0]
assert vector_store.name == user_provided_id
assert vector_store.id.startswith("vs_")
assert vector_store.id != user_provided_id
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 1
assert vector_dbs.data[0].identifier == vector_store.id
await table.unregister_vector_db(vector_store.id)
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
impl = VectorDBImpl()
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
authorized_table = "vs1"
authorized_team = "team1"
unauthorized_team = "team2"
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
with request_provider_data_context({}, authorized_user):
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
authorized_table = registered_vdb.identifier # Use the actual generated ID
# Authorized reader
with request_provider_data_context({}, authorized_user):
res = await table.openai_retrieve_vector_store(authorized_table)
assert res == "OK"
# Authorized updater
impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
assert res == "UPDATED"
# Unauthorized reader
unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_retrieve_vector_store(authorized_table)
# Unauthorized updater
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
# Authorized deleter
impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
assert res == "DELETED"
# Unauthorized deleter
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
impl = VectorDBImpl()
policy = [
AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
]
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
vector_db_id = "vs1"
file_id = "file-1"
admin_user = User(principal="admin", attributes={"roles": ["admin"]})
read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
with request_provider_data_context({}, admin_user):
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
vector_db_id = registered_vdb.identifier # Use the actual generated ID
read_methods = [
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
(table.openai_search_vector_store, (vector_db_id, "query"), {}),
(table.openai_list_files_in_vector_store, (vector_db_id,), {}),
(table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
]
update_methods = [
(table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
(table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
(table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
]
delete_methods = [
(table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_delete_vector_store, (vector_db_id,), {}),
]
for user in [admin_user, read_only_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in read_methods:
result = await method(*args, **kwargs)
assert result is not None, f"Read operation failed with user {user.principal}"
with request_provider_data_context({}, no_access_user):
for method, args, kwargs in read_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)
with request_provider_data_context({}, admin_user):
for method, args, kwargs in update_methods:
result = await method(*args, **kwargs)
assert result is not None, "Update operation failed with admin user"
with request_provider_data_context({}, admin_user):
for method, args, kwargs in delete_methods:
result = await method(*args, **kwargs)
assert result is not None, "Delete operation failed with admin user"
for user in [read_only_user, no_access_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in delete_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)


@@ -0,0 +1,318 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from openai import AsyncOpenAI
# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.testing.api_recorder import (
APIRecordingMode,
ResponseStorage,
api_recording,
normalize_inference_request,
)
@pytest.fixture
def temp_storage_dir():
"""Create a temporary directory for test recordings."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)
@pytest.fixture
def real_openai_chat_response():
"""Real OpenAI chat completion response using proper Pydantic objects."""
return OpenAIChatCompletion(
id="chatcmpl-test123",
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant", content="Hello! I'm doing well, thank you for asking."
),
finish_reason="stop",
)
],
created=1234567890,
model="llama3.2:3b",
)
@pytest.fixture
def real_embeddings_response():
"""Real OpenAI embeddings response using proper Pydantic objects."""
return OpenAIEmbeddingsResponse(
object="list",
data=[
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
],
model="nomic-embed-text",
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
)
class TestInferenceRecording:
"""Test the inference recording system."""
def test_request_normalization(self):
"""Test that request normalization produces consistent hashes."""
# Test basic normalization
hash1 = normalize_inference_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
# Same request should produce same hash
hash2 = normalize_inference_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
assert hash1 == hash2
# Different content should produce different hash
hash3 = normalize_inference_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{
"model": "llama3.2:3b",
"messages": [{"role": "user", "content": "Different message"}],
"temperature": 0.7,
},
)
assert hash1 != hash3
def test_request_normalization_edge_cases(self):
"""Test request normalization is precise about request content."""
# Test that different whitespace produces different hashes (no normalization)
hash1 = normalize_inference_request(
"POST",
"http://test/v1/chat/completions",
{},
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
)
hash2 = normalize_inference_request(
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
)
assert hash1 != hash2 # Different whitespace should produce different hashes
# Test that small float precision differences normalize (round) to the same hash
hash3 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
hash4 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
assert hash3 == hash4 # Small float precision differences should normalize to the same hash
# String-embedded decimals with excessive precision should also normalize.
body_with_precise_scores = {
"messages": [
{
"role": "tool",
"content": "score: 0.7472640164649847",
}
]
}
body_with_precise_scores_variation = {
"messages": [
{
"role": "tool",
"content": "score: 0.74726414959878",
}
]
}
hash5 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, body_with_precise_scores)
hash6 = normalize_inference_request(
"POST", "http://test/v1/chat/completions", {}, body_with_precise_scores_variation
)
assert hash5 == hash6
body_with_close_scores = {
"messages": [
{
"role": "tool",
"content": "score: 0.662477492560699",
}
]
}
body_with_close_scores_variation = {
"messages": [
{
"role": "tool",
"content": "score: 0.6624775971970099",
}
]
}
hash7 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, body_with_close_scores)
hash8 = normalize_inference_request(
"POST", "http://test/v1/chat/completions", {}, body_with_close_scores_variation
)
assert hash7 == hash8
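# Illustrative sketch (hypothetical; the real normalizer lives in
# llama_stack.testing.api_recorder and its exact precision is not shown here): the
# equal-hash assertions above imply that decimals, including ones embedded in message
# strings, are rounded to a fixed precision before hashing.
def _round_floats_sketch(text: str, ndigits: int = 5) -> str:
    """Round every decimal literal in `text` to `ndigits` places (assumed value)."""
    import re
    return re.sub(r"\d+\.\d+", lambda m: f"{round(float(m.group()), ndigits):g}", text)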
def test_response_storage(self, temp_storage_dir):
"""Test the ResponseStorage class."""
temp_storage_dir = temp_storage_dir / "test_response_storage"
storage = ResponseStorage(temp_storage_dir)
# Test storing and retrieving a recording
request_hash = "test_hash_123"
request_data = {
"method": "POST",
"url": "http://localhost:11434/v1/chat/completions",
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b",
}
response_data = {"body": {"content": "test response"}, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
# Verify file storage and retrieval
retrieved = storage.find_recording(request_hash)
assert retrieved is not None
assert retrieved["request"]["model"] == "llama3.2:3b"
assert retrieved["response"]["body"]["content"] == "test response"
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that recording mode captures and stores responses."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
temp_storage_dir = temp_storage_dir / "test_recording_mode"
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Verify the response was returned correctly
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify recording was stored
storage = ResponseStorage(temp_storage_dir)
assert storage._get_test_dir().exists()
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that replay mode returns stored responses without making real calls."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
temp_storage_dir = temp_storage_dir / "test_replay_mode"
# First, record a response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Now test replay mode - should not call the original method
with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Verify we got the recorded response
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify the original method was NOT called
mock_create_patch.assert_not_called()
async def test_replay_missing_recording(self, temp_storage_dir):
"""Test that replay mode fails when no recording is found."""
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
with pytest.raises(RuntimeError, match="Recording not found"):
await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
)
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
"""Test recording and replay of embeddings calls."""
async def mock_create(*args, **kwargs):
return real_embeddings_response
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
# Record
with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)
assert len(response.data) == 2
# Replay
with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)
# Verify we got the recorded response
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
# Verify original method was not called
mock_create_patch.assert_not_called()
async def test_live_mode(self, real_openai_chat_response):
"""Test that live mode passes through to original methods."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.LIVE, storage_dir="foo"):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
)
# Verify the response was returned
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
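# Illustrative usage sketch (assumption, built only from the api_recording calls shown
# above): a suite-level helper could default to REPLAY so CI never reaches a live
# endpoint, while switching the mode to RECORD refreshes the stored responses.
async def _run_with_recordings_sketch(storage_dir: str, mode=APIRecordingMode.REPLAY):
    with api_recording(mode=mode, storage_dir=storage_dir):
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        # Any chat/completions/embeddings call made here is served from (or captured
        # into) the recordings under storage_dir.
        return await client.chat.completions.create(
            model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
        )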


@@ -390,3 +390,467 @@ pip_packages:
assert provider.is_external is True
# config_class is empty string in partial spec
assert provider.config_class == ""
class TestGetExternalProvidersFromModule:
"""Test suite for installing external providers from module."""
def test_stackrunconfig_provider_without_module(self, mock_providers):
"""Test that providers without module attribute are skipped."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
import_module_side_effect = make_import_module_side_effect()
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="no_module",
provider_type="no_module",
config={},
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Should not add anything to registry
assert len(result[Api.inference]) == 0
def test_stackrunconfig_with_version_spec(self, mock_providers):
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
fake_spec = ProviderSpec(
api=Api.inference,
provider_type="versioned_test",
config_class="versioned_test.config.VersionedTestConfig",
module="versioned_test==1.0.0",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)
def import_side_effect(name):
if name == "versioned_test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="versioned",
provider_type="versioned_test",
config={},
module="versioned_test==1.0.0",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
assert "versioned_test" in result[Api.inference]
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
def test_buildconfig_does_not_import_module(self, mock_providers):
"""Test that BuildConfig does not import the module (building=True)."""
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
build_config = BuildConfig(
version=2,
image_type="container",
image_name="test_image",
distribution_spec=DistributionSpec(
description="test",
providers={
"inference": [
BuildProvider(
provider_type="build_test",
module="build_test==1.0.0",
)
]
},
),
)
# Should not call import_module at all when building
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, build_config, building=True)
# Verify module was NOT imported
mock_import.assert_not_called()
# Verify partial spec was created
assert "build_test" in result[Api.inference]
provider = result[Api.inference]["build_test"]
assert provider.module == "build_test==1.0.0"
assert provider.is_external is True
assert provider.config_class == ""
assert provider.api == Api.inference
def test_buildconfig_multiple_providers(self, mock_providers):
"""Test BuildConfig with multiple providers for the same API."""
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
build_config = BuildConfig(
version=2,
image_type="container",
image_name="test_image",
distribution_spec=DistributionSpec(
description="test",
providers={
"inference": [
BuildProvider(provider_type="provider1", module="provider1"),
BuildProvider(provider_type="provider2", module="provider2"),
]
},
),
)
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, build_config, building=True)
mock_import.assert_not_called()
assert "provider1" in result[Api.inference]
assert "provider2" in result[Api.inference]
def test_distributionspec_does_not_import_module(self, mock_providers):
"""Test that DistributionSpec does not import the module (building=True)."""
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
dist_spec = DistributionSpec(
description="test distribution",
providers={
"inference": [
BuildProvider(
provider_type="dist_test",
module="dist_test==2.0.0",
)
]
},
)
# Should not call import_module at all when building
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, dist_spec, building=True)
# Verify module was NOT imported
mock_import.assert_not_called()
# Verify partial spec was created
assert "dist_test" in result[Api.inference]
provider = result[Api.inference]["dist_test"]
assert provider.module == "dist_test==2.0.0"
assert provider.is_external is True
assert provider.config_class == ""
def test_list_return_from_get_provider_spec(self, mock_providers):
"""Test when get_provider_spec returns a list of specs."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
spec1 = ProviderSpec(
api=Api.inference,
provider_type="list_test",
config_class="list_test.config.Config1",
module="list_test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="list_test_remote",
config_class="list_test.config.Config2",
module="list_test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "list_test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="list_test",
provider_type="list_test",
config={},
module="list_test",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Only the matching provider_type should be added
assert "list_test" in result[Api.inference]
assert result[Api.inference]["list_test"].config_class == "list_test.config.Config1"
def test_list_return_filters_by_provider_type(self, mock_providers):
"""Test that list return filters specs by provider_type."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
spec1 = ProviderSpec(
api=Api.inference,
provider_type="wanted",
config_class="test.Config1",
module="test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="unwanted",
config_class="test.Config2",
module="test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="wanted",
provider_type="wanted",
config={},
module="test",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Only the matching provider_type should be added
assert "wanted" in result[Api.inference]
assert "unwanted" not in result[Api.inference]
def test_list_return_adds_multiple_provider_types(self, mock_providers):
"""Test that list return adds multiple different provider_types when config requests them."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
# Module returns both inline and remote variants
spec1 = ProviderSpec(
api=Api.inference,
provider_type="remote::ollama",
config_class="test.RemoteConfig",
module="test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="inline::ollama",
config_class="test.InlineConfig",
module="test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="remote_ollama",
provider_type="remote::ollama",
config={},
module="test",
),
Provider(
provider_id="inline_ollama",
provider_type="inline::ollama",
config={},
module="test",
),
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Both provider types should be added to registry
assert "remote::ollama" in result[Api.inference]
assert "inline::ollama" in result[Api.inference]
assert result[Api.inference]["remote::ollama"].config_class == "test.RemoteConfig"
assert result[Api.inference]["inline::ollama"].config_class == "test.InlineConfig"
def test_module_not_found_raises_value_error(self, mock_providers):
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def import_side_effect(name):
if name == "missing_module.provider":
raise ModuleNotFoundError(name)
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="missing",
provider_type="missing",
config={},
module="missing_module",
)
]
},
)
registry = {Api.inference: {}}
with pytest.raises(ValueError) as exc_info:
get_external_providers_from_module(registry, config, building=False)
assert "get_provider_spec not found" in str(exc_info.value)
def test_generic_exception_is_raised(self, mock_providers):
"""Test that generic exceptions are properly raised."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def bad_spec():
raise RuntimeError("Something went wrong")
fake_module = SimpleNamespace(get_provider_spec=bad_spec)
def import_side_effect(name):
if name == "error_module.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="error",
provider_type="error",
config={},
module="error_module",
)
]
},
)
registry = {Api.inference: {}}
with pytest.raises(RuntimeError) as exc_info:
get_external_providers_from_module(registry, config, building=False)
assert "Something went wrong" in str(exc_info.value)
def test_empty_provider_list(self, mock_providers):
"""Test with empty provider list."""
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
config = StackRunConfig(
image_name="test_image",
providers={},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Should return registry unchanged
assert result == registry
assert len(result[Api.inference]) == 0
def test_multiple_apis_with_providers(self, mock_providers):
"""Test multiple APIs with providers."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
inference_spec = ProviderSpec(
api=Api.inference,
provider_type="inf_test",
config_class="inf.Config",
module="inf_test",
)
safety_spec = ProviderSpec(
api=Api.safety,
provider_type="safe_test",
config_class="safe.Config",
module="safe_test",
)
def import_side_effect(name):
if name == "inf_test.provider":
return SimpleNamespace(get_provider_spec=lambda: inference_spec)
elif name == "safe_test.provider":
return SimpleNamespace(get_provider_spec=lambda: safety_spec)
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="inf",
provider_type="inf_test",
config={},
module="inf_test",
)
],
"safety": [
Provider(
provider_id="safe",
provider_type="safe_test",
config={},
module="safe_test",
)
],
},
)
registry = {Api.inference: {}, Api.safety: {}}
result = get_external_providers_from_module(registry, config, building=False)
assert "inf_test" in result[Api.inference]
assert "safe_test" in result[Api.safety]


@@ -1,382 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch
import pytest
from openai import NOT_GIVEN, AsyncOpenAI
from openai.types.model import Model as OpenAIModel
# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.testing.inference_recorder import (
InferenceMode,
ResponseStorage,
inference_recording,
normalize_request,
)
@pytest.fixture
def temp_storage_dir():
"""Create a temporary directory for test recordings."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)
@pytest.fixture
def real_openai_chat_response():
"""Real OpenAI chat completion response using proper Pydantic objects."""
return OpenAIChatCompletion(
id="chatcmpl-test123",
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant", content="Hello! I'm doing well, thank you for asking."
),
finish_reason="stop",
)
],
created=1234567890,
model="llama3.2:3b",
)
@pytest.fixture
def real_embeddings_response():
"""Real OpenAI embeddings response using proper Pydantic objects."""
return OpenAIEmbeddingsResponse(
object="list",
data=[
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
],
model="nomic-embed-text",
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
)
class TestInferenceRecording:
"""Test the inference recording system."""
def test_request_normalization(self):
"""Test that request normalization produces consistent hashes."""
# Test basic normalization
hash1 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
# Same request should produce same hash
hash2 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
assert hash1 == hash2
# Different content should produce different hash
hash3 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{
"model": "llama3.2:3b",
"messages": [{"role": "user", "content": "Different message"}],
"temperature": 0.7,
},
)
assert hash1 != hash3
def test_request_normalization_edge_cases(self):
"""Test request normalization is precise about request content."""
# Test that different whitespace produces different hashes (no normalization)
hash1 = normalize_request(
"POST",
"http://test/v1/chat/completions",
{},
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
)
hash2 = normalize_request(
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
)
assert hash1 != hash2 # Different whitespace should produce different hashes
# Test that different float precision produces different hashes (no rounding)
hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
assert hash3 != hash4 # Different precision should produce different hashes
def test_response_storage(self, temp_storage_dir):
"""Test the ResponseStorage class."""
temp_storage_dir = temp_storage_dir / "test_response_storage"
storage = ResponseStorage(temp_storage_dir)
# Test storing and retrieving a recording
request_hash = "test_hash_123"
request_data = {
"method": "POST",
"url": "http://localhost:11434/v1/chat/completions",
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b",
}
response_data = {"body": {"content": "test response"}, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
# Verify file storage and retrieval
retrieved = storage.find_recording(request_hash)
assert retrieved is not None
assert retrieved["request"]["model"] == "llama3.2:3b"
assert retrieved["response"]["body"]["content"] == "test response"
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that recording mode captures and stores responses."""
temp_storage_dir = temp_storage_dir / "test_recording_mode"
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
# Verify the response was returned correctly
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
client.chat.completions._post.assert_called_once()
# Verify recording was stored
storage = ResponseStorage(temp_storage_dir)
dir = storage._get_test_dir()
assert dir.exists()
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that replay mode returns stored responses without making real calls."""
temp_storage_dir = temp_storage_dir / "test_replay_mode"
# First, record a response
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
client.chat.completions._post.assert_called_once()
# Now test replay mode - should not call the original method
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Verify we got the recorded response
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify the original method was NOT called
client.chat.completions._post.assert_not_called()
async def test_replay_mode_models(self, temp_storage_dir):
"""Test that replay mode returns stored responses without making real model listing calls."""
async def _async_iterator(models):
for model in models:
yield model
models = [
OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
]
expected_ids = {m.id for m in models}
temp_storage_dir = temp_storage_dir / "test_replay_mode_models"
# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_called_once()
# record the call
with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_called_once()
# replay the call
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_not_called()
async def test_replay_missing_recording(self, temp_storage_dir):
"""Test that replay mode fails when no recording is found."""
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
with pytest.raises(RuntimeError, match="No recorded response found"):
await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
)
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
"""Test recording and replay of embeddings calls."""
# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
encoding_format=NOT_GIVEN,
)
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
client.embeddings._post.assert_called_once()
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
# Record
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
encoding_format=NOT_GIVEN,
dimensions=NOT_GIVEN,
user=NOT_GIVEN,
)
assert len(response.data) == 2
# Replay
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
)
# Verify we got the recorded response
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
# Verify original method was not called
client.embeddings._post.assert_not_called()
async def test_completions_recording(self, temp_storage_dir):
real_completions_response = OpenAICompletion(
id="test_completion",
object="text_completion",
created=1234567890,
model="llama3.2:3b",
choices=[
{
"text": "Hello! I'm doing well, thank you for asking.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
],
)
temp_storage_dir = temp_storage_dir / "test_completions_recording"
# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)
response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_called_once()
# Record
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)
response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_called_once()
# Replay
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)
response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
)
assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_not_called()
async def test_live_mode(self, real_openai_chat_response):
"""Test that live mode passes through to original methods."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
)
# Verify the response was returned
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
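These recording tests all rely on a temp_storage_dir fixture defined elsewhere in the test module. A minimal sketch of such a fixture, assuming it only needs to hand each test an isolated directory (the real fixture may scope or clean up differently), could be:

import pytest

@pytest.fixture
def temp_storage_dir(tmp_path):
    # Hypothetical stand-in: give each test a fresh directory for its
    # recorded inference responses under pytest's built-in tmp_path.
    storage = tmp_path / "inference_recordings"
    storage.mkdir()
    return storage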

View file

@ -15,6 +15,7 @@ from llama_stack.apis.agents import (
AgentCreateResponse,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
@ -33,6 +34,7 @@ def mock_apis():
"safety_api": AsyncMock(spec=Safety),
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
"tool_groups_api": AsyncMock(spec=ToolGroups),
"conversations_api": AsyncMock(spec=Conversations),
}
@ -59,7 +61,8 @@ async def agents_impl(config, mock_apis):
mock_apis["safety_api"],
mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"],
{},
mock_apis["conversations_api"],
[],
)
await impl.initialize()
yield impl

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, patch
import pytest
from openai.types.chat.chat_completion_chunk import (
@ -20,6 +20,7 @@ from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
@ -32,13 +33,14 @@ from llama_stack.apis.agents.openai_responses import (
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIJSONSchema,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.datatypes import ResponsesStoreConfig
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
@ -82,9 +84,21 @@ def mock_vector_io_api():
return vector_io_api
@pytest.fixture
def mock_conversations_api():
"""Mock conversations API for testing."""
mock_api = AsyncMock()
return mock_api
@pytest.fixture
def openai_responses_impl(
mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
mock_inference_api,
mock_tool_groups_api,
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_conversations_api,
):
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
@ -92,6 +106,7 @@ def openai_responses_impl(
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api,
)
@ -147,18 +162,24 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
chunks = [chunk async for chunk in result]
mock_inference_api.openai_chat_completion.assert_called_once_with(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
response_format=None,
tools=None,
stream=True,
temperature=0.1,
OpenAIChatCompletionRequestWithExtraBody(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
response_format=None,
tools=None,
stream=True,
temperature=0.1,
stream_options={
"include_usage": True,
},
)
)
# Should have content part events for text streaming
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
assert len(chunks) >= 4
# Expected: response.created, response.in_progress, content_part.added, output_text.delta, content_part.done, response.completed
assert len(chunks) >= 5
assert chunks[0].type == "response.created"
assert any(chunk.type == "response.in_progress" for chunk in chunks)
# Check for content part events
content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
@ -169,6 +190,14 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
assert len(text_delta_events) >= 1, "Should have text delta events"
added_event = content_part_added_events[0]
done_event = content_part_done_events[0]
assert added_event.content_index == 0
assert done_event.content_index == 0
assert added_event.output_index == done_event.output_index == 0
assert added_event.item_id == done_event.item_id
assert added_event.response_id == done_event.response_id
# Verify final event is completion
assert chunks[-1].type == "response.completed"
@ -177,6 +206,8 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
assert final_response.model == model
assert len(final_response.output) == 1
assert isinstance(final_response.output[0], OpenAIResponseMessage)
assert final_response.output[0].id == added_event.item_id
assert final_response.id == added_event.response_id
openai_responses_impl.responses_store.store_response_object.assert_called_once()
assert final_response.output[0].content[0].text == "Dublin"
@ -228,13 +259,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
# Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?"
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
first_params = first_call.args[0]
assert first_params.messages[0].content == "What is the capital of Ireland?"
assert first_params.tools is not None
assert first_params.temperature == 0.1
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
assert second_call.kwargs["messages"][-1].content == "Dublin"
assert second_call.kwargs["temperature"] == 0.1
second_params = second_call.args[0]
assert second_params.messages[-1].content == "Dublin"
assert second_params.temperature == 0.1
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
@ -303,36 +336,42 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
chunks = [chunk async for chunk in result]
# Verify event types
# Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 6
# Should have: response.created, response.in_progress, output_item.added,
# function_call_arguments.delta, function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 7
event_types = [chunk.type for chunk in chunks]
assert event_types == [
"response.created",
"response.in_progress",
"response.output_item.added",
"response.function_call_arguments.delta",
"response.function_call_arguments.done",
"response.output_item.done",
"response.completed",
]
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.tools is not None
assert first_params.temperature == 0.1
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check streaming events
assert chunks[1].type == "response.output_item.added"
assert chunks[2].type == "response.function_call_arguments.delta"
assert chunks[3].type == "response.function_call_arguments.done"
assert chunks[4].type == "response.output_item.done"
# Check response.completed event (should have the tool call)
assert chunks[5].type == "response.completed"
assert len(chunks[5].response.output) == 1
assert chunks[5].response.output[0].type == "function_call"
assert chunks[5].response.output[0].name == "get_weather"
completed_chunk = chunks[-1]
assert completed_chunk.type == "response.completed"
assert len(completed_chunk.response.output) == 1
assert completed_chunk.response.output[0].type == "function_call"
assert completed_chunk.response.output[0].name == "get_weather"
async def test_create_openai_response_with_tool_call_function_arguments_none(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a tool call response that has a function that does not accept arguments, or arguments set to None when they are not mandatory."""
# Setup
"""Test creating an OpenAI response with tool calls that omit arguments."""
input_text = "What is the time right now?"
model = "meta-llama/Llama-3.1-8B-Instruct"
@ -359,9 +398,22 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
object="chat.completion.chunk",
)
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
def assert_common_expectations(chunks) -> None:
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.tools is not None
assert first_params.temperature == 0.1
assert len(chunks[0].response.output) == 0
completed_chunk = chunks[-1]
assert completed_chunk.type == "response.completed"
assert len(completed_chunk.response.output) == 1
assert completed_chunk.response.output[0].type == "function_call"
assert completed_chunk.response.output[0].name == "get_current_time"
assert completed_chunk.response.output[0].arguments == "{}"
# Function does not accept arguments
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
@ -369,46 +421,23 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
temperature=0.1,
tools=[
OpenAIResponseInputToolFunction(
name="get_current_time",
description="Get current time for system's timezone",
parameters={},
name="get_current_time", description="Get current time for system's timezone", parameters={}
)
],
)
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
# Verify event types
# Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 5
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check streaming events
assert chunks[1].type == "response.output_item.added"
assert chunks[2].type == "response.function_call_arguments.done"
assert chunks[3].type == "response.output_item.done"
# Check response.completed event (should have the tool call with arguments set to "{}")
assert chunks[4].type == "response.completed"
assert len(chunks[4].response.output) == 1
assert chunks[4].response.output[0].type == "function_call"
assert chunks[4].response.output[0].name == "get_current_time"
assert chunks[4].response.output[0].arguments == "{}"
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
assert [chunk.type for chunk in chunks] == [
"response.created",
"response.in_progress",
"response.output_item.added",
"response.function_call_arguments.done",
"response.output_item.done",
"response.completed",
]
assert_common_expectations(chunks)
# Function accepts optional arguments
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
@ -418,42 +447,47 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
OpenAIResponseInputToolFunction(
name="get_current_time",
description="Get current time for system's timezone",
parameters={
"timezone": "string",
},
parameters={"timezone": "string"},
)
],
)
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
assert [chunk.type for chunk in chunks] == [
"response.created",
"response.in_progress",
"response.output_item.added",
"response.function_call_arguments.done",
"response.output_item.done",
"response.completed",
]
assert_common_expectations(chunks)
# Verify event types
# Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 5
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check streaming events
assert chunks[1].type == "response.output_item.added"
assert chunks[2].type == "response.function_call_arguments.done"
assert chunks[3].type == "response.output_item.done"
# Check response.completed event (should have the tool call with arguments set to "{}")
assert chunks[4].type == "response.completed"
assert len(chunks[4].response.output) == 1
assert chunks[4].response.output[0].type == "function_call"
assert chunks[4].response.output[0].name == "get_current_time"
assert chunks[4].response.output[0].arguments == "{}"
# Function accepts optional arguments with additional optional fields
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
stream=True,
temperature=0.1,
tools=[
OpenAIResponseInputToolFunction(
name="get_current_time",
description="Get current time for system's timezone",
parameters={"timezone": "string", "location": "string"},
)
],
)
chunks = [chunk async for chunk in result]
assert [chunk.type for chunk in chunks] == [
"response.created",
"response.in_progress",
"response.output_item.added",
"response.function_call_arguments.done",
"response.output_item.done",
"response.completed",
]
assert_common_expectations(chunks)
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
@ -485,7 +519,9 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
# Verify that the correct messages were sent to the inference API, i.e.
# all of the response messages were converted to chat completion message objects
inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"]
call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
params = call_args.args[0]
inference_messages = params.messages
for i, m in enumerate(input_messages):
if isinstance(m.content, str):
assert inference_messages[i].content == m.content
@ -653,7 +689,8 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 2
@ -691,7 +728,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 4 # 1 system + 3 input messages
@ -751,7 +789,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 4, sent_messages
@ -953,6 +992,58 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
assert result.status == "completed"
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
async def test_reuse_mcp_tool_list(
mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
):
"""Test that mcp_list_tools can be reused where appropriate."""
mock_inference_api.openai_chat_completion.return_value = fake_stream()
mock_list_mcp_tools.return_value = ListToolDefsResponse(
data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})]
)
res1 = await openai_responses_impl.create_openai_response(
input="What is 2+2?",
model="meta-llama/Llama-3.1-8B-Instruct",
store=True,
tools=[
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
],
)
args = mock_responses_store.store_response_object.call_args
data = args.kwargs["response_object"].model_dump()
data["input"] = [input_item.model_dump() for input_item in args.kwargs["input"]]
data["messages"] = [msg.model_dump() for msg in args.kwargs["messages"]]
stored = _OpenAIResponseObjectWithInputAndMessages(**data)
mock_responses_store.get_response_object.return_value = stored
res2 = await openai_responses_impl.create_openai_response(
previous_response_id=res1.id,
input="Now what is 3+3?",
model="meta-llama/Llama-3.1-8B-Instruct",
store=True,
tools=[
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
],
)
assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
second_params = second_call.args[0]
tools_seen = second_params.tools
assert len(tools_seen) == 1
assert tools_seen[0]["function"]["name"] == "test_tool"
assert tools_seen[0]["function"]["description"] == "a test tool"
assert mock_list_mcp_tools.call_count == 1
listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"]
assert len(listings) == 1
assert listings[0].server_label == "alabel"
assert len(listings[0].tools) == 1
assert listings[0].tools[0].name == "test_tool"
@pytest.mark.parametrize(
"text_format, response_format",
[
@ -987,8 +1078,9 @@ async def test_create_openai_response_with_text_format(
# Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["response_format"] == response_format
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.response_format == response_format
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):

View file

@ -0,0 +1,331 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseOutputMessageContentOutputText,
)
from llama_stack.apis.common.errors import (
ConversationNotFoundError,
InvalidConversationIdError,
)
from llama_stack.apis.conversations.conversations import (
ConversationItemList,
)
# Import existing fixtures from the main responses test file
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
@pytest.fixture
def responses_impl_with_conversations(
mock_inference_api,
mock_tool_groups_api,
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_conversations_api,
):
"""Create OpenAIResponsesImpl instance with conversations API."""
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
tool_groups_api=mock_tool_groups_api,
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api,
)
class TestConversationValidation:
"""Test conversation ID validation logic."""
async def test_nonexistent_conversation_raises_error(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test that ConversationNotFoundError is raised for non-existent conversation."""
conv_id = "conv_nonexistent"
# Mock conversation not found
mock_conversations_api.list.side_effect = ConversationNotFoundError("conv_nonexistent")
with pytest.raises(ConversationNotFoundError):
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation=conv_id, stream=False
)
class TestConversationContextLoading:
"""Test conversation context loading functionality."""
async def test_load_conversation_context_simple_input(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test loading conversation context with simple string input."""
conv_id = "conv_test123"
input_text = "Hello, how are you?"
# mock items in chronological order (a consequence of order="asc")
mock_conversation_items = ConversationItemList(
data=[
OpenAIResponseMessage(
id="msg_1",
content=[{"type": "input_text", "text": "Previous user message"}],
role="user",
status="completed",
type="message",
),
OpenAIResponseMessage(
id="msg_2",
content=[{"type": "output_text", "text": "Previous assistant response"}],
role="assistant",
status="completed",
type="message",
),
],
first_id="msg_1",
has_more=False,
last_id="msg_2",
object="list",
)
mock_conversations_api.list.return_value = mock_conversation_items
result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
# should have conversation history + new input
assert len(result) == 3
assert isinstance(result[0], OpenAIResponseMessage)
assert result[0].role == "user"
assert isinstance(result[1], OpenAIResponseMessage)
assert result[1].role == "assistant"
assert isinstance(result[2], OpenAIResponseMessage)
assert result[2].role == "user"
assert result[2].content == input_text
async def test_load_conversation_context_api_error(self, responses_impl_with_conversations, mock_conversations_api):
"""Test loading conversation context when API call fails."""
conv_id = "conv_test123"
input_text = "Hello"
mock_conversations_api.list.side_effect = Exception("API Error")
with pytest.raises(Exception, match="API Error"):
await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
async def test_load_conversation_context_with_list_input(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test loading conversation context with list input."""
conv_id = "conv_test123"
input_messages = [
OpenAIResponseMessage(role="user", content="First message"),
OpenAIResponseMessage(role="user", content="Second message"),
]
mock_conversations_api.list.return_value = ConversationItemList(
data=[], first_id=None, has_more=False, last_id=None, object="list"
)
result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_messages)
assert len(result) == 2
assert result == input_messages
async def test_load_conversation_context_empty_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test loading context from empty conversation."""
conv_id = "conv_empty"
input_text = "Hello"
mock_conversations_api.list.return_value = ConversationItemList(
data=[], first_id=None, has_more=False, last_id=None, object="list"
)
result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
assert len(result) == 1
assert result[0].role == "user"
assert result[0].content == input_text
class TestMessageSyncing:
"""Test message syncing to conversations."""
async def test_sync_response_to_conversation_simple(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test syncing simple response to conversation."""
conv_id = "conv_test123"
input_text = "What are the 5 Ds of dodgeball?"
# mock response
mock_response = OpenAIResponseObject(
id="resp_123",
created_at=1234567890,
model="test-model",
object="response",
output=[
OpenAIResponseMessage(
id="msg_response",
content=[
OpenAIResponseOutputMessageContentOutputText(
text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
)
],
role="assistant",
status="completed",
type="message",
)
],
status="completed",
)
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, mock_response)
# should call add_items with user input and assistant response
mock_conversations_api.add_items.assert_called_once()
call_args = mock_conversations_api.add_items.call_args
assert call_args[0][0] == conv_id # conversation_id
items = call_args[0][1] # conversation_items
assert len(items) == 2
# User message
assert items[0].type == "message"
assert items[0].role == "user"
assert items[0].content[0].type == "input_text"
assert items[0].content[0].text == input_text
# Assistant message
assert items[1].type == "message"
assert items[1].role == "assistant"
async def test_sync_response_to_conversation_api_error(
self, responses_impl_with_conversations, mock_conversations_api
):
mock_conversations_api.add_items.side_effect = Exception("API Error")
mock_response = OpenAIResponseObject(
id="resp_123", created_at=1234567890, model="test-model", object="response", output=[], status="completed"
)
# matching the behavior of OpenAI here
with pytest.raises(Exception, match="API Error"):
await responses_impl_with_conversations._sync_response_to_conversation(
"conv_test123", "Hello", mock_response
)
async def test_sync_unsupported_types(self, responses_impl_with_conversations):
mock_response = OpenAIResponseObject(
id="resp_123", created_at=1234567890, model="test-model", object="response", output=[], status="completed"
)
with pytest.raises(NotImplementedError, match="Unsupported input item type"):
await responses_impl_with_conversations._sync_response_to_conversation(
"conv_123", [{"not": "message"}], mock_response
)
with pytest.raises(NotImplementedError, match="Unsupported message role: system"):
await responses_impl_with_conversations._sync_response_to_conversation(
"conv_123", [OpenAIResponseMessage(role="system", content="test")], mock_response
)
class TestIntegrationWorkflow:
"""Integration tests for the full conversation workflow."""
async def test_create_response_with_valid_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test creating a response with a valid conversation parameter."""
mock_conversations_api.list.return_value = ConversationItemList(
data=[], first_id=None, has_more=False, last_id=None, object="list"
)
async def mock_streaming_response(*args, **kwargs):
mock_response = OpenAIResponseObject(
id="resp_test123",
created_at=1234567890,
model="test-model",
object="response",
output=[
OpenAIResponseMessage(
id="msg_response",
content=[
OpenAIResponseOutputMessageContentOutputText(
text="Test response", type="output_text", annotations=[]
)
],
role="assistant",
status="completed",
type="message",
)
],
status="completed",
)
yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")
responses_impl_with_conversations._create_streaming_response = mock_streaming_response
input_text = "Hello, how are you?"
conversation_id = "conv_test123"
response = await responses_impl_with_conversations.create_openai_response(
input=input_text, model="test-model", conversation=conversation_id, stream=False
)
assert response is not None
assert response.id == "resp_test123"
mock_conversations_api.list.assert_called_once_with(conversation_id, order="asc")
# Note: conversation sync happens in the streaming response flow,
# which is complex to mock fully in this unit test
async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
"""Test creating a response with an invalid conversation ID."""
with pytest.raises(InvalidConversationIdError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation="invalid_id", stream=False
)
assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)
async def test_create_response_with_nonexistent_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test creating a response with a non-existent conversation."""
mock_conversations_api.list.side_effect = ConversationNotFoundError("conv_nonexistent")
with pytest.raises(ConversationNotFoundError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
)
assert "not found" in str(exc_info.value)
async def test_conversation_and_previous_response_id(
self, responses_impl_with_conversations, mock_conversations_api, mock_responses_store
):
with pytest.raises(ValueError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="test", model="test", conversation="conv_123", previous_response_id="resp_123"
)
assert "Mutually exclusive parameters" in str(exc_info.value)
assert "previous_response_id" in str(exc_info.value)
assert "conversation" in str(exc_info.value)

View file

@ -8,6 +8,7 @@
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
@ -35,6 +36,7 @@ from llama_stack.apis.inference import (
OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
_extract_citations_from_text,
convert_chat_choice_to_response_message,
convert_response_content_to_chat_content,
convert_response_input_to_chat_messages,
@ -340,3 +342,26 @@ class TestIsFunctionToolCall:
result = is_function_tool_call(tool_call, tools)
assert result is False
class TestExtractCitationsFromText:
def test_extract_citations_and_annotations(self):
text = "Start [not-a-file]. New source <|file-abc123|>. "
text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}
annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)
expected_annotations = [
OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
]
expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."
assert cleaned_text == expected_clean_text
assert annotations == expected_annotations
# OpenAI cites at the end of the sentence
assert cleaned_text[expected_annotations[0].index] == "."
assert cleaned_text[expected_annotations[1].index] == "?"
assert cleaned_text[expected_annotations[2].index] == "!"
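A regex-based sketch that would satisfy these expectations, assuming citations are encoded as <|file-...|> markers preceded by optional whitespace; it returns plain tuples here rather than the annotation objects the real helper builds:

import re

_CITATION = re.compile(r"\s*<\|(file-[^|>]+)\|>")

def extract_citations_sketch(text, file_mapping):
    annotations, parts, cleaned_len, last_end = [], [], 0, 0
    for match in _CITATION.finditer(text):
        parts.append(text[last_end:match.start()])
        cleaned_len += match.start() - last_end
        file_id = match.group(1)
        if file_id in file_mapping:
            # the index is the position in the cleaned text where the marker
            # was removed, which lands on the sentence-ending punctuation
            annotations.append((file_id, file_mapping[file_id], cleaned_len))
        last_end = match.end()
    parts.append(text[last_end:])
    return annotations, "".join(parts)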

View file

@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseObject,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseToolMCP,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
class TestToolContext:
def test_no_tools(self):
tools = []
context = ToolContext(tools)
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 0
assert len(context.previous_tools) == 0
assert len(context.previous_tool_listings) == 0
def test_no_previous_tools(self):
tools = [
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
OpenAIResponseInputToolMCP(server_label="label", server_url="url"),
]
context = ToolContext(tools)
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 2
assert len(context.previous_tools) == 0
assert len(context.previous_tool_listings) == 0
def test_reusable_server(self):
tools = [
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
)
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
OpenAIResponseToolMCP(server_label="alabel"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 1
assert context.tools_to_process[0].type == "file_search"
assert len(context.previous_tools) == 1
assert context.previous_tools["test_tool"].server_label == "alabel"
assert context.previous_tools["test_tool"].server_url == "aurl"
assert len(context.previous_tool_listings) == 1
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "alabel"
def test_multiple_reusable_servers(self):
tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
),
OpenAIResponseOutputMessageMCPListTools(
id="test2",
server_label="anotherlabel",
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
),
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 2
assert context.tools_to_process[0].type == "function"
assert context.tools_to_process[1].type == "web_search"
assert len(context.previous_tools) == 2
assert context.previous_tools["test_tool"].server_label == "alabel"
assert context.previous_tools["test_tool"].server_url == "aurl"
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
assert len(context.previous_tool_listings) == 2
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "alabel"
assert len(context.previous_tool_listings[1].tools) == 1
assert context.previous_tool_listings[1].server_label == "anotherlabel"
def test_multiple_servers_only_one_reusable(self):
tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test2",
server_label="anotherlabel",
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
)
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 3
assert context.tools_to_process[0].type == "function"
assert context.tools_to_process[1].type == "web_search"
assert context.tools_to_process[2].type == "mcp"
assert len(context.previous_tools) == 1
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
assert len(context.previous_tool_listings) == 1
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "anotherlabel"
def test_mismatched_allowed_tools(self):
tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl", allowed_tools=["test_tool_2"]),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool_1", input_schema={})]
),
OpenAIResponseOutputMessageMCPListTools(
id="test2",
server_label="anotherlabel",
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
),
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 3
assert context.tools_to_process[0].type == "function"
assert context.tools_to_process[1].type == "web_search"
assert context.tools_to_process[2].type == "mcp"
assert len(context.previous_tools) == 1
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
assert len(context.previous_tool_listings) == 1
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "anotherlabel"
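Read together, these cases suggest the reuse rule ToolContext applies: an MCP listing from the previous response is reused only when the current request still names that server and its allowed_tools filter is unchanged; everything else is queued for fresh processing. A self-contained sketch of that rule (field names taken from the tests, attribute access hedged with getattr since the exact models may differ):

def recover_tools_sketch(current_tools, previous_response):
    prev_mcp_by_label = {
        t.server_label: t
        for t in (getattr(previous_response, "tools", None) or [])
        if getattr(t, "server_label", None) is not None
    }
    current_mcp_by_label = {t.server_label: t for t in current_tools if getattr(t, "type", None) == "mcp"}
    previous_tools, previous_tool_listings, reused_labels = {}, [], set()
    for item in previous_response.output:
        if getattr(item, "type", None) != "mcp_list_tools":
            continue
        current = current_mcp_by_label.get(item.server_label)
        previous = prev_mcp_by_label.get(item.server_label)
        if current is None or previous is None:
            continue
        if getattr(current, "allowed_tools", None) != getattr(previous, "allowed_tools", None):
            continue  # filter changed, so the old listing cannot be trusted
        previous_tool_listings.append(item)
        reused_labels.add(item.server_label)
        for tool in item.tools:
            previous_tools[tool.name] = current
    tools_to_process = [
        t for t in current_tools if not (getattr(t, "type", None) == "mcp" and t.server_label in reused_labels)
    ]
    return tools_to_process, previous_tools, previous_tool_listings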

View file

@ -213,7 +213,6 @@ class TestReferenceBatchesImpl:
@pytest.mark.parametrize(
"endpoint",
[
"/v1/embeddings",
"/v1/invalid/endpoint",
"",
],
@ -765,3 +764,12 @@ class TestReferenceBatchesImpl:
await asyncio.sleep(0.042) # let tasks start
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
async def test_create_batch_embeddings_endpoint(self, provider):
"""Test that batch creation succeeds with embeddings endpoint."""
batch = await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/embeddings",
completion_window="24h",
)
assert batch.endpoint == "/v1/embeddings"

View file

@ -7,6 +7,8 @@
import json
from unittest.mock import MagicMock
import pytest
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
@ -16,74 +18,71 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
def test_groq_provider_openai_client_caching():
"""Ensure the Groq provider does not cache api keys across client requests"""
config = GroqConfig()
inference_adapter = GroqInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
assert inference_adapter.client.api_key == api_key
def test_openai_provider_openai_client_caching():
@pytest.mark.parametrize(
"config_cls,adapter_cls,provider_data_validator",
[
(
GroqConfig,
GroqInferenceAdapter,
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
),
(
OpenAIConfig,
OpenAIInferenceAdapter,
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
),
(
TogetherImplConfig,
TogetherInferenceAdapter,
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
),
(
LlamaCompatConfig,
LlamaCompatInferenceAdapter,
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
),
],
)
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
"""Ensure the OpenAI provider does not cache api keys across client requests"""
config = OpenAIConfig()
inference_adapter = OpenAIInferenceAdapter(config)
inference_adapter = adapter_cls(config=config_cls())
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator"
)
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter.client
assert openai_client.api_key == api_key
def test_together_provider_openai_client_caching():
"""Ensure the Together provider does not cache api keys across client requests"""
config = TogetherImplConfig()
inference_adapter = TogetherInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"together_api_key": api_key})}):
together_client = inference_adapter._get_client()
assert together_client.client.api_key == api_key
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
def test_llama_compat_provider_openai_client_caching():
"""Ensure the LlamaCompat provider does not cache api keys across client requests"""
config = LlamaCompatConfig()
inference_adapter = LlamaCompatInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
assert inference_adapter.client.api_key == api_key
@pytest.mark.parametrize(
"config_cls,adapter_cls,provider_data_validator",
[
(
WatsonXConfig,
WatsonXInferenceAdapter,
"llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
),
],
)
def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
assumption that there is an OpenAI-compatible client object."""
inference_adapter = adapter_cls(config=config_cls())
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
assert inference_adapter.get_api_key() == api_key
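Both parametrized tests feed keys in through the same mechanism: a JSON blob under the x-llamastack-provider-data header, keyed by the adapter's provider_data_api_key_field. A minimal sketch of how a caller might build that header (the field name shown is illustrative and varies per provider):

import json

def provider_data_headers(api_key_field: str, api_key: str) -> dict[str, str]:
    # e.g. api_key_field might be "together_api_key" or "llama_api_key";
    # use whatever the adapter exposes as provider_data_api_key_field.
    return {"x-llamastack-provider-data": json.dumps({api_key_field: api_key})}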

View file

@ -18,7 +18,7 @@ class TestOpenAIBaseURLConfig:
def test_default_base_url_without_env_var(self):
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
config = OpenAIConfig(api_key="test-key")
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == "https://api.openai.com/v1"
@ -27,7 +27,7 @@ class TestOpenAIBaseURLConfig:
"""Test that the adapter uses a custom base URL when provided in config."""
custom_url = "https://custom.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == custom_url
@ -39,7 +39,7 @@ class TestOpenAIBaseURLConfig:
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == "https://env.openai.com/v1"
@ -49,7 +49,7 @@ class TestOpenAIBaseURLConfig:
"""Test that explicit config value overrides environment variable."""
custom_url = "https://config.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Config should take precedence over environment variable
@ -60,7 +60,7 @@ class TestOpenAIBaseURLConfig:
"""Test that the OpenAI client is initialized with the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
@ -80,7 +80,7 @@ class TestOpenAIBaseURLConfig:
"""Test that check_model_availability uses the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method
@ -122,7 +122,7 @@ class TestOpenAIBaseURLConfig:
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method

View file

@ -5,45 +5,27 @@
# the root directory of this source tree.
import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChoiceChunk,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
)
from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChoice,
OpenAICompletion,
OpenAICompletionChoice,
OpenAICompletionRequestWithExtraBody,
ToolChoice,
UserMessage,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
VLLMInferenceAdapter,
_process_vllm_chat_completion_stream_response,
)
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
# These are unit test for the remote vllm provider
# implementation. This should only contain tests which are specific to
@ -56,37 +38,15 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
# -v -s --tb=short --disable-warnings
@pytest.fixture(scope="module")
def mock_openai_models_list():
with patch("openai.resources.models.AsyncModels.list") as mock_list:
yield mock_list
@pytest.fixture(scope="function")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config)
inference_adapter = VLLMInferenceAdapter(config=config)
inference_adapter.model_store = AsyncMock()
# Mock the __provider_spec__ attribute that would normally be set by the resolver
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_type = "vllm-inference"
inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
await inference_adapter.initialize()
return inference_adapter
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
async def mock_openai_models():
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
mock_openai_models_list.return_value = mock_openai_models()
foo_model = Model(identifier="foo", provider_resource_id="foo", provider_id="vllm-inference")
await vllm_inference_adapter.register_model(foo_model)
mock_openai_models_list.assert_called()
async def test_old_vllm_tool_choice(vllm_inference_adapter):
"""
Test that we set tool_choice to none when no tools are in use
@ -102,416 +62,20 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
mock_client_property.return_value = mock_client
# No tools but auto tool choice
await vllm_inference_adapter.openai_chat_completion(
"mock-model",
[],
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "test"}],
stream=False,
tools=None,
tool_choice=ToolChoice.auto.value,
)
await vllm_inference_adapter.openai_chat_completion(params)
mock_client.chat.completions.create.assert_called()
call_args = mock_client.chat.completions.create.call_args
# Ensure tool_choice gets converted to none for older vLLM versions
assert call_args.kwargs["tool_choice"] == ToolChoice.none.value
async def test_tool_call_delta_empty_tool_call_buf():
"""
Test that we don't generate extra chunks when processing a
tool call response that didn't call any tools. Previously we would
emit chunks with spurious ToolCallParseStatus.succeeded or
ToolCallParseStatus.failed when processing chunks that didn't
actually make any tool calls.
"""
async def mock_stream():
delta = OpenAIChoiceDelta(content="", tool_calls=None)
choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "complete"
assert chunks[1].event.stop_reason == StopReason.end_of_turn
async def test_tool_call_delta_streaming_arguments_dict():
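# Argument deltas that arrive across several chunks for the same tool-call index should be
# accumulated and emitted as a single succeeded tool_call progress chunk once the stream
# reports finish_reason="tool_calls".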
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments="",
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "complete"
async def test_multiple_tool_calls():
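# Two tool calls streamed under different indices should each surface as their own succeeded
# tool_call progress chunk, in order, before the final complete chunk.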
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=2,
function=OpenAIChoiceDeltaToolCallFunction(
name="multiple",
arguments='{"first_number": 4, "second_number": 7}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 4
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "progress"
assert chunks[2].event.delta.type == "tool_call"
assert chunks[2].event.delta.parse_status.value == "succeeded"
assert chunks[2].event.delta.tool_call.arguments == '{"first_number": 4, "second_number": 7}'
assert chunks[3].event.event_type.value == "complete"
async def test_process_vllm_chat_completion_stream_response_no_choices():
"""
Test that we don't error out when vLLM returns no choices for a
completion request. This can happen, for example, when an error is
thrown inside vLLM.
"""
async def mock_stream():
choices = []
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 1
assert chunks[0].event.event_type.value == "start"
async def test_get_params_empty_tools(vllm_inference_adapter):
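# An empty tools list should be omitted from the request params sent to vLLM rather than
# forwarded as tools=[].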
request = ChatCompletionRequest(
tools=[],
model="test_model",
messages=[UserMessage(content="test")],
)
params = await vllm_inference_adapter._get_params(request)
assert "tools" not in params
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
"""
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
contains the finish reason (i.e., the last one).
We want to make sure the tool call is executed in this case, and the parameters are passed correctly.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": None,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {
"name": None,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": "tool_calls",
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
"""
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
finish reason.
We want to make sure that this case is recognized and handled correctly, i.e., as a valid end of message.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
"""
Tests the edge case where no arguments are provided for the tool call.
Tool calls with no arguments should be treated as regular tool calls; previously they were not handled that way.
"""
mock_tool_name = "mock_tool"
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": "",
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == "{}"
async def test_health_status_success(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when the connection is successful.
@@ -614,9 +178,12 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
)
async def do_inference():
await vllm_inference_adapter.openai_chat_completion(
"mock-model", messages=["one fish", "two fish"], stream=False
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "one fish two fish"}],
stream=False,
)
await vllm_inference_adapter.openai_chat_completion(params)
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
@@ -631,105 +198,146 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
async def test_should_refresh_models():
async def test_vllm_completion_extra_body():
"""
Test the should_refresh_models method with different refresh_models configurations.
This test verifies that:
1. When refresh_models is True, should_refresh_models returns True regardless of api_token
2. When refresh_models is False, should_refresh_models returns False regardless of api_token
Test that vLLM-specific guided_choice and prompt_logprobs parameters are correctly forwarded
via extra_body to the underlying OpenAI client through the InferenceRouter.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()
# Test case 1: refresh_models is True, api_token is None
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
adapter1 = VLLMInferenceAdapter(config1)
result1 = await adapter1.should_refresh_models()
assert result1 is True, "should_refresh_models should return True when refresh_models is True"
# Create a mock model store
mock_model_store = AsyncMock()
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
mock_model_store.get_model.return_value = mock_model
mock_model_store.has_model.return_value = True
# Test case 2: refresh_models is True, api_token is empty string
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
adapter2 = VLLMInferenceAdapter(config2)
result2 = await adapter2.should_refresh_models()
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
# Create a mock dist_registry
mock_dist_registry = MagicMock()
mock_dist_registry.get = AsyncMock(return_value=mock_model)
mock_dist_registry.set = AsyncMock()
# Test case 3: refresh_models is True, api_token is "fake" (default)
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
adapter3 = VLLMInferenceAdapter(config3)
result3 = await adapter3.should_refresh_models()
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
# Set up the routing table
routing_table = ModelsRoutingTable(
impls_by_provider_id={"vllm": vllm_adapter},
dist_registry=mock_dist_registry,
policy=[],
)
# Inject the model store into the adapter
vllm_adapter.model_store = routing_table
# Test case 4: refresh_models is True, api_token is real token
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
adapter4 = VLLMInferenceAdapter(config4)
result4 = await adapter4.should_refresh_models()
assert result4 is True, "should_refresh_models should return True when refresh_models is True"
# Create the InferenceRouter
router = InferenceRouter(routing_table=routing_table)
# Test case 5: refresh_models is False, api_token is real token
config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
adapter5 = VLLMInferenceAdapter(config5)
result5 = await adapter5.should_refresh_models()
assert result5 is False, "should_refresh_models should return False when refresh_models is False"
async def test_provider_data_var_context_propagation(vllm_inference_adapter):
"""
Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
This ensures that dynamic provider data (like API tokens) can be passed through context.
Note: The base URL is always taken from config.url, not from provider data.
"""
# Mock the AsyncOpenAI class to capture provider data
with (
patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
):
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock()
mock_openai_class.return_value = mock_client
# Mock provider data to return test data
mock_provider_data = MagicMock()
mock_provider_data.vllm_api_token = "test-token-123"
mock_provider_data.vllm_url = "http://test-server:8000/v1"
mock_get_provider_data.return_value = mock_provider_data
# Mock the model
mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
try:
# Execute chat completion
await vllm_inference_adapter.openai_chat_completion(
model="test-model",
messages=[UserMessage(content="Hello")],
stream=False,
# Patch the OpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.completions.create = AsyncMock(
return_value=OpenAICompletion(
id="cmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAICompletionChoice(
text="joy",
finish_reason="stop",
index=0,
)
],
)
)
mock_client_property.return_value = mock_client
# Verify that ALL client calls were made with the correct parameters
calls = mock_openai_class.call_args_list
incorrect_calls = []
# Test with guided_choice and prompt_logprobs as extra fields
params = OpenAICompletionRequestWithExtraBody(
model="mock-model",
prompt="I am feeling happy",
stream=False,
guided_choice=["joy", "sadness"],
prompt_logprobs=5,
)
await router.openai_completion(params)
for i, call in enumerate(calls):
api_key = call[1]["api_key"]
base_url = call[1]["base_url"]
# Verify that the client was called with extra_body containing both parameters
mock_client.completions.create.assert_called_once()
call_kwargs = mock_client.completions.create.call_args.kwargs
assert "extra_body" in call_kwargs
assert "guided_choice" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["guided_choice"] == ["joy", "sadness"]
assert "prompt_logprobs" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["prompt_logprobs"] == 5
if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
if incorrect_calls:
error_msg = (
f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
)
for incorrect_call in incorrect_calls:
error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
raise AssertionError(error_msg)
async def test_vllm_chat_completion_extra_body():
"""
Test that vLLM-specific parameters (e.g., chat_template_kwargs) are correctly forwarded
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()
# Ensure at least one call was made
assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
# Create a mock model store
mock_model_store = AsyncMock()
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
mock_model_store.get_model.return_value = mock_model
mock_model_store.has_model.return_value = True
# Verify that chat completion was called
mock_client.chat.completions.create.assert_called_once()
# Create a mock dist_registry
mock_dist_registry = MagicMock()
mock_dist_registry.get = AsyncMock(return_value=mock_model)
mock_dist_registry.set = AsyncMock()
finally:
# Clean up context
pass
# Set up the routing table
routing_table = ModelsRoutingTable(
impls_by_provider_id={"vllm": vllm_adapter},
dist_registry=mock_dist_registry,
policy=[],
)
# Inject the model store into the adapter
vllm_adapter.model_store = routing_table
# Create the InferenceRouter
router = InferenceRouter(routing_table=routing_table)
# Patch the OpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
return_value=OpenAIChatCompletion(
id="chatcmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAIChoice(
message=OpenAIAssistantMessageParam(
content="test response",
),
finish_reason="stop",
index=0,
)
],
)
)
mock_client_property.return_value = mock_client
# Test with chat_template_kwargs as extra field
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "test"}],
stream=False,
chat_template_kwargs={"thinking": True},
)
await router.openai_chat_completion(params)
# Verify that the client was called with extra_body containing chat_template_kwargs
mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args.kwargs
assert "extra_body" in call_kwargs
assert "chat_template_kwargs" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}

View file

@@ -5,14 +5,17 @@
# the root directory of this source tree.
import json
from collections.abc import Iterable
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -29,7 +32,7 @@ class OpenAIMixinImpl(OpenAIMixin):
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
"""Test implementation with embedding model metadata"""
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
@@ -38,13 +41,15 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
mixin_instance = OpenAIMixinImpl()
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinImpl(config=config)
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
# Mock model_store with async methods
mock_model_store = AsyncMock()
mock_model = MagicMock()
mock_model.provider_resource_id = "test-provider-resource-id"
mock_model_store.get_model = AsyncMock(return_value=mock_model)
mock_model_store.has_model = AsyncMock(return_value=False) # Default to False, tests can override
mixin_instance.model_store = mock_model_store
return mixin_instance
@@ -53,7 +58,8 @@ def mixin():
@pytest.fixture
def mixin_with_embeddings():
"""Create a test instance of OpenAIMixin with embedding model metadata"""
return OpenAIMixinWithEmbeddingsImpl()
config = RemoteInferenceProviderConfig()
return OpenAIMixinWithEmbeddingsImpl(config=config)
@pytest.fixture
@@ -184,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability:
assert len(mixin._model_cache) == 3
async def test_check_model_availability_with_pre_registered_model(
self, mixin, mock_client_with_models, mock_client_context
):
"""Test that check_model_availability returns True for pre-registered models in model_store"""
# Mock model_store.has_model to return True for a specific model
mock_model_store = AsyncMock()
mock_model_store.has_model = AsyncMock(return_value=True)
mixin.model_store = mock_model_store
# Test that pre-registered model is found without calling the provider's API
with mock_client_context(mixin, mock_client_with_models):
mock_client_with_models.models.list.assert_not_called()
assert await mixin.check_model_availability("pre-registered-model")
# Should not call the provider's list_models since model was found in store
mock_client_with_models.models.list.assert_not_called()
mock_model_store.has_model.assert_called_once_with("pre-registered-model")
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
self, mixin, mock_client_with_models, mock_client_context
):
"""Test that check_model_availability falls back to provider when model not in store"""
# Mock model_store.has_model to return False
mock_model_store = AsyncMock()
mock_model_store.has_model = AsyncMock(return_value=False)
mixin.model_store = mock_model_store
# Test that it falls back to provider's model cache
with mock_client_context(mixin, mock_client_with_models):
mock_client_with_models.models.list.assert_not_called()
assert await mixin.check_model_availability("some-mock-model-id")
# Should call the provider's list_models since model was not found in store
mock_client_with_models.models.list.assert_called_once()
mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
class TestOpenAIMixinCacheBehavior:
"""Test cases for cache behavior and edge cases"""
@@ -231,7 +271,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message])
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_called_once_with("http://example.com/image.jpg")
@@ -263,7 +304,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message])
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_not_called()
@@ -461,10 +503,16 @@ class TestOpenAIMixinModelRegistration:
assert result is None
async def test_should_refresh_models(self, mixin):
"""Test should_refresh_models method (should always return False)"""
"""Test should_refresh_models method returns config value"""
# Default config has refresh_models=False
result = await mixin.should_refresh_models()
assert result is False
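# Flipping refresh_models in the provider config below is expected to flip the result,
# i.e. the mixin defers entirely to the config value.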
config_with_refresh = RemoteInferenceProviderConfig(refresh_models=True)
mixin_with_refresh = OpenAIMixinImpl(config=config_with_refresh)
result_with_refresh = await mixin_with_refresh.should_refresh_models()
assert result_with_refresh is True
async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
"""Test that errors from provider API are properly propagated during registration"""
model = Model(
@@ -498,13 +546,145 @@ class OpenAIMixinWithProviderData(OpenAIMixinImpl):
return "default-base-url"
class CustomListProviderModelIdsImplementation(OpenAIMixinImpl):
"""Test implementation with custom list_provider_model_ids override"""
custom_model_ids: Any
async def list_provider_model_ids(self) -> Iterable[str]:
"""Return custom model IDs list"""
return self.custom_model_ids
class TestOpenAIMixinCustomListProviderModelIds:
"""Test cases for custom list_provider_model_ids() implementation functionality"""
@pytest.fixture
def custom_model_ids_list(self):
"""Create a list of custom model ID strings"""
return ["custom-model-1", "custom-model-2", "custom-embedding"]
@pytest.fixture
def config(self):
"""Create RemoteInferenceProviderConfig instance"""
return RemoteInferenceProviderConfig()
@pytest.fixture
def adapter(self, custom_model_ids_list, config):
"""Create mixin instance with custom list_provider_model_ids implementation"""
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=custom_model_ids_list)
mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}}
return mixin
async def test_is_used(self, adapter, custom_model_ids_list):
"""Test that custom list_provider_model_ids() implementation is used instead of client.models.list()"""
result = await adapter.list_models()
assert result is not None
assert len(result) == 3
assert set(custom_model_ids_list) == {m.identifier for m in result}
async def test_populates_cache(self, adapter, custom_model_ids_list):
"""Test that custom list_provider_model_ids() results are cached"""
assert len(adapter._model_cache) == 0
await adapter.list_models()
assert set(custom_model_ids_list) == set(adapter._model_cache.keys())
async def test_respects_allowed_models(self, config):
"""Test that custom list_provider_model_ids() respects allowed_models filtering"""
mixin = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=["model-1", "model-2", "model-3"]
)
mixin.allowed_models = ["model-1"]
result = await mixin.list_models()
assert result is not None
assert len(result) == 1
assert result[0].identifier == "model-1"
async def test_with_empty_list(self, config):
"""Test that custom list_provider_model_ids() handles empty list correctly"""
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[])
result = await mixin.list_models()
assert result is not None
assert len(result) == 0
assert len(mixin._model_cache) == 0
async def test_wrong_type_raises_error(self, config):
"""Test that list_provider_model_ids() returning unhashable items results in an error"""
mixin = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=["valid-model", ["nested", "list"]]
)
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
mixin = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=[{"key": "value"}, "valid-model"]
)
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=["valid-model", 42.0])
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[None])
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
async def test_non_iterable_raises_error(self, config):
"""Test that list_provider_model_ids() returning non-iterable type raises error"""
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=42)
with pytest.raises(
TypeError,
match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int",
):
await mixin.list_models()
async def test_accepts_various_iterables(self, config):
"""Test that list_provider_model_ids() accepts tuples, sets, generators, etc."""
tuples = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=("model-1", "model-2", "model-3")
)
result = await tuples.list_models()
assert result is not None
assert len(result) == 3
class GeneratorAdapter(OpenAIMixinImpl):
async def list_provider_model_ids(self) -> Iterable[str]:
def gen():
yield "gen-model-1"
yield "gen-model-2"
return gen()
mixin = GeneratorAdapter(config=config)
result = await mixin.list_models()
assert result is not None
assert len(result) == 2
sets = CustomListProviderModelIdsImplementation(config=config, custom_model_ids={"set-model-1", "set-model-2"})
result = await sets.list_models()
assert result is not None
assert len(result) == 2
class TestOpenAIMixinProviderDataApiKey:
"""Test cases for provider_data_api_key_field functionality"""
@pytest.fixture
def mixin_with_provider_data_field(self):
"""Mixin instance with provider_data_api_key_field set"""
mixin_instance = OpenAIMixinWithProviderData()
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinWithProviderData(config=config)
# Mock provider_spec for provider data validation
mock_provider_spec = MagicMock()
@@ -542,7 +722,7 @@ class TestOpenAIMixinProviderDataApiKey:
):
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
with pytest.raises(ValueError, match="API key is not set"):
with pytest.raises(ValueError, match="API key not provided"):
_ = mixin_with_provider_data_field_and_none_api_key.client
def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):

View file

@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.core.stack import replace_env_vars
from llama_stack.providers.remote.inference.anthropic import AnthropicConfig
from llama_stack.providers.remote.inference.azure import AzureConfig
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.databricks import DatabricksImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.gemini import GeminiConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.llama_openai_compat import LlamaCompatConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.openai import OpenAIConfig
from llama_stack.providers.remote.inference.runpod import RunpodImplConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vertexai import VertexAIConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
class TestRemoteInferenceProviderConfig:
@pytest.mark.parametrize(
"config_cls,alias_name,env_name,extra_config",
[
(AnthropicConfig, "api_key", "ANTHROPIC_API_KEY", {}),
(AzureConfig, "api_key", "AZURE_API_KEY", {"api_base": "HTTP://FAKE"}),
(BedrockConfig, None, None, {}),
(CerebrasImplConfig, "api_key", "CEREBRAS_API_KEY", {}),
(DatabricksImplConfig, "api_token", "DATABRICKS_TOKEN", {}),
(FireworksImplConfig, "api_key", "FIREWORKS_API_KEY", {}),
(GeminiConfig, "api_key", "GEMINI_API_KEY", {}),
(GroqConfig, "api_key", "GROQ_API_KEY", {}),
(LlamaCompatConfig, "api_key", "LLAMA_API_KEY", {}),
(NVIDIAConfig, "api_key", "NVIDIA_API_KEY", {}),
(OllamaImplConfig, None, None, {}),
(OpenAIConfig, "api_key", "OPENAI_API_KEY", {}),
(RunpodImplConfig, "api_token", "RUNPOD_API_TOKEN", {}),
(SambaNovaImplConfig, "api_key", "SAMBANOVA_API_KEY", {}),
(TGIImplConfig, None, None, {"url": "FAKE"}),
(TogetherImplConfig, "api_key", "TOGETHER_API_KEY", {}),
(VertexAIConfig, None, None, {"project": "FAKE", "location": "FAKE"}),
(VLLMInferenceAdapterConfig, "api_token", "VLLM_API_TOKEN", {}),
(WatsonXConfig, "api_key", "WATSONX_API_KEY", {}),
],
)
def test_provider_config_auth_credentials(self, monkeypatch, config_cls, alias_name, env_name, extra_config):
"""Test that the config class correctly maps the alias to auth_credential."""
secret_value = config_cls.__name__
if alias_name is None:
pytest.skip("No alias name provided for this config class.")
config = config_cls(**{alias_name: secret_value, **extra_config})
assert config.auth_credential is not None
assert config.auth_credential.get_secret_value() == secret_value
schema = config_cls.model_json_schema()
assert alias_name in schema["properties"]
assert "auth_credential" not in schema["properties"]
if env_name:
monkeypatch.setenv(env_name, secret_value)
sample_config = config_cls.sample_run_config()
expanded_config = replace_env_vars(sample_config)
config_from_sample = config_cls(**{**expanded_config, **extra_config})
assert config_from_sample.auth_credential is not None
assert config_from_sample.auth_credential.get_secret_value() == secret_value

View file

@@ -9,32 +9,22 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from chromadb import PersistentClient
from pymilvus import MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
EMBEDDING_DIMENSION = 384
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
@pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"])
def vector_provider(request):
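# The provider name chosen here is resolved to a concrete adapter fixture by the
# vector_io_adapter fixture at the bottom of this file.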
return request.param
@@ -145,10 +135,10 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
@pytest.fixture
async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
config = SQLiteVectorIOConfig(
db_path=sqlite_vec_db_path,
kvstore=SqliteKVStoreConfig(),
kvstore=unique_kvstore_config,
)
adapter = SQLiteVecVectorIOAdapter(
config=config,
@@ -170,46 +160,6 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_d
await adapter.shutdown()
@pytest.fixture(scope="session")
def milvus_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
return db_path
@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
client = MilvusClient(milvus_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = milvus_vec_db_path
yield index
@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
config = MilvusVectorIOConfig(
db_path=milvus_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = MilvusVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=adapter.metadata_collection_name,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def faiss_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
@@ -246,98 +196,6 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
await adapter.shutdown()
@pytest.fixture
def chroma_vec_db_path(tmp_path_factory):
persist_dir = tmp_path_factory.mktemp(f"chroma_{np.random.randint(1e6)}")
return str(persist_dir)
@pytest.fixture
async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
client = PersistentClient(path=chroma_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
collection = await maybe_await(client.get_or_create_collection(name))
index = ChromaIndex(client=client, collection=collection)
await index.initialize()
yield index
await index.delete()
@pytest.fixture
async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
config = ChromaVectorIOConfig(
db_path=chroma_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = ChromaVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=f"chroma_test_collection_{random.randint(1, 1_000_000)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def qdrant_vec_db_path(tmp_path_factory):
import uuid
db_path = str(tmp_path_factory.getbasetemp() / f"test_qdrant_{uuid.uuid4()}.db")
return db_path
@pytest.fixture
async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
import uuid
config = QdrantVectorIOConfig(
db_path=qdrant_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = QdrantVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
@pytest.fixture
async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
import uuid
from qdrant_client import AsyncQdrantClient
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex
client = AsyncQdrantClient(path=qdrant_vec_db_path)
collection_name = f"qdrant_test_collection_{uuid.uuid4()}"
index = QdrantIndex(client, collection_name)
yield index
await index.delete()
@pytest.fixture
def mock_psycopg2_connection():
connection = MagicMock()
@@ -386,14 +244,14 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
@pytest.fixture
async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
config = PGVectorVectorIOConfig(
host="localhost",
port=5432,
db="test_db",
user="test_user",
password="test_password",
kvstore=SqliteKVStoreConfig(),
kvstore=unique_kvstore_config,
)
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
@@ -450,81 +308,12 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
await adapter.shutdown()
@pytest.fixture(scope="session")
def weaviate_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
return db_path
@pytest.fixture
async def weaviate_vec_index(weaviate_vec_db_path):
import pytest_socket
import weaviate
pytest_socket.enable_socket()
client = weaviate.connect_to_embedded(
hostname="localhost",
port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
)
index = WeaviateIndex(client=client, collection_name="Testcollection")
await index.initialize()
yield index
await index.delete()
client.close()
@pytest.fixture
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
import pytest_socket
import weaviate
pytest_socket.enable_socket()
client = weaviate.connect_to_embedded(
hostname="localhost",
port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
)
config = WeaviateVectorIOConfig(
weaviate_cluster_url="localhost:8080",
weaviate_api_key=None,
kvstore=SqliteKVStoreConfig(),
)
adapter = WeaviateVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
client.close()
@pytest.fixture
def vector_io_adapter(vector_provider, request):
vector_provider_dict = {
"milvus": "milvus_vec_adapter",
"faiss": "faiss_vec_adapter",
"sqlite_vec": "sqlite_vec_adapter",
"chroma": "chroma_vec_adapter",
"qdrant": "qdrant_vec_adapter",
"pgvector": "pgvector_vec_adapter",
"weaviate": "weaviate_vec_adapter",
}
return request.getfixturevalue(vector_provider_dict[vector_provider])

View file

@@ -1,326 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock
pymilvus_mock.RRFRanker = MagicMock
pymilvus_mock.WeightedRanker = MagicMock
pymilvus_mock.AnnSearchRequest = MagicMock
# Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
MILVUS_PROVIDER = "milvus"
@pytest.fixture
async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
# Mock collection operations
client.has_collection.return_value = False # Initially no collection
client.create_collection.return_value = None
client.drop_collection.return_value = None
# Mock insert operation
client.insert.return_value = {"insert_count": 10}
# Mock search operation - return mock results (data should be dict, not JSON string)
client.search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Mock query operation for keyword search (data should be dict, not JSON string)
client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
return client
@pytest.fixture
async def milvus_index(mock_milvus_client):
"""Create a MilvusIndex with mocked client."""
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
yield index
# No real cleanup needed since we're using mocks
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation
mock_milvus_client.has_collection.side_effect = [False, True]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Verify collection was created and data was inserted
mock_milvus_client.create_collection.assert_called_once()
mock_milvus_client.insert.assert_called_once()
# Verify the insert call had the right number of chunks
insert_call = mock_milvus_client.insert.call_args
assert len(insert_call[1]["data"]) == len(sample_chunks)
async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
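# With the mocked client, vector search should surface the two mocked hits as chunks and go
# through client.search exactly once.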
# Setup: Add chunks first
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
mock_milvus_client.search.assert_called_once()
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
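# Keyword search against the mocked client should return k chunks.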
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search
query_string = "Sentence 5"
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
"""Test that when BM25 search fails, the system falls back to simple text search."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Force BM25 search to fail
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
# Mock simple text search results
mock_milvus_client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
},
]
# Test keyword search that should fall back to simple text search
query_string = "Python"
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
# Verify response structure
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) > 0, "Fallback search should return results"
# Verify that simple text search was used (query method called instead of search)
mock_milvus_client.query.assert_called_once()
mock_milvus_client.search.assert_called_once() # Called once but failed
# Verify the query uses parameterized filter with filter_params
query_call_args = mock_milvus_client.query.call_args
assert "filter" in query_call_args[1], "Query should include filter for text search"
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
# Verify all returned chunks have score 1.0 (simple binary scoring)
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion
mock_milvus_client.has_collection.return_value = True
await milvus_index.delete()
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
async def test_query_hybrid_search_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with RRF reranker."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Test hybrid search with RRF reranker
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=2,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
# Verify hybrid search was called with correct parameters
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
# Check that the request contains both vector and BM25 search requests
reqs = call_args[1]["reqs"]
assert len(reqs) == 2
assert reqs[0].anns_field == "vector"
assert reqs[1].anns_field == "sparse"
ranker = call_args[1]["ranker"]
assert ranker is not None
async def test_query_hybrid_search_weighted(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with weighted reranker."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Test hybrid search with weighted reranker
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=2,
score_threshold=0.0,
reranker_type="weighted",
reranker_params={"alpha": 0.7},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
# Verify hybrid search was called with correct parameters
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
ranker = call_args[1]["ranker"]
assert ranker is not None
async def test_query_hybrid_search_default_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
]
]
# Test hybrid search with default reranker (should be RRF)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="unknown_type", # Should default to RRF
reranker_params=None, # Should use default impact_factor
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 1
# Verify hybrid search was called with RRF reranker
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
ranker = call_args[1]["ranker"]
assert ranker is not None

View file

@@ -1,138 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from unittest.mock import patch
import pytest
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex
PGVECTOR_PROVIDER = "pgvector"
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
@pytest.fixture
def embedding_dimension():
"""Default embedding dimension for tests."""
return 384
@pytest.fixture
async def pgvector_index(embedding_dimension, mock_psycopg2_connection):
"""Create a PGVectorIndex instance with mocked database connection."""
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
# Use explicit COSINE distance metric for consistent testing
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
return index, cursor
class TestPGVectorIndex:
def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2")
assert index.distance_metric == "L2"
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID")
def test_get_pgvector_search_function(self, pgvector_index):
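# Every supported distance metric should resolve to its pgvector search operator
# (e.g. COSINE -> "<=>", L2 -> "<->").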
index, cursor = pgvector_index
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
for metric, function in supported_metrics.items():
index.distance_metric = metric
assert index.get_pgvector_search_function() == function
def test_check_distance_metric_availability(self, pgvector_index):
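# All metrics in the supported mapping should pass the availability check; an unknown name
# should raise a ValueError identifying the unsupported metric.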
index, cursor = pgvector_index
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
for metric in supported_metrics:
index.check_distance_metric_availability(metric)
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
index.check_distance_metric_availability("INVALID")
def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC")
with pytest.raises(ValueError, match="Supported metrics are:"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN")
try:
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
assert index.distance_metric == "COSINE"
except ValueError:
pytest.fail("Valid distance metric 'COSINE' should not raise ValueError")
def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"]
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
for metric in supported_metrics:
try:
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric)
assert index.distance_metric == metric
expected_operators = {
"L2": "<->",
"L1": "<+>",
"COSINE": "<=>",
"INNER_PRODUCT": "<#>",
"HAMMING": "<~>",
"JACCARD": "<%>",
}
assert index.get_pgvector_search_function() == expected_operators[metric]
except Exception as e:
pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}")

View file

@@ -1,147 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference.inference import OpenAIEmbeddingData, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage
from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorDB,
VectorDBStore,
)
from llama_stack.providers.inline.vector_io.qdrant.config import (
QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
QdrantVectorIOAdapter,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_qdrant.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
@pytest.fixture
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
kvstore_config = SqliteKVStoreConfig(db_name=os.path.join(tmp_path, "test_kvstore.db"))
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"), kvstore=kvstore_config)
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
@pytest.fixture
def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "embedding_model"
mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = 384
mock_vector_db.model_dump_json.return_value = (
'{"identifier": "'
+ vector_db_id
+ '", "provider_id": "qdrant", "embedding_model": "embedding_model", "embedding_dimension": 384}'
)
return mock_vector_db
@pytest.fixture
def mock_vector_db_store(mock_vector_db) -> MagicMock:
mock_store = MagicMock(spec=VectorDBStore)
mock_store.get_vector_db = AsyncMock(return_value=mock_vector_db)
return mock_store
@pytest.fixture
def mock_api_service(sample_embeddings):
mock_api_service = MagicMock(spec=Inference)
mock_api_service.openai_embeddings = AsyncMock(
return_value=OpenAIEmbeddingsResponse(
model="mock-embedding-model",
data=[OpenAIEmbeddingData(embedding=sample, index=i) for i, sample in enumerate(sample_embeddings)],
usage=OpenAIEmbeddingUsage(prompt_tokens=10, total_tokens=10),
)
)
return mock_api_service
@pytest.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
adapter.vector_db_store = mock_vector_db_store
await adapter.initialize()
yield adapter
await adapter.shutdown()
__QUERY = "Sample query"
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
async def test_qdrant_adapter_returns_expected_chunks(
qdrant_adapter: QdrantVectorIOAdapter,
vector_db_id,
sample_chunks,
sample_embeddings,
max_query_chunks,
expected_chunks,
) -> None:
assert qdrant_adapter is not None
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
assert index is not None
response = await qdrant_adapter.query_chunks(
query=__QUERY,
vector_db_id=vector_db_id,
params={"max_chunks": max_query_chunks, "mode": "vector"},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == expected_chunks
# To bypass the attempt to convert a Mock to JSON
def _prepare_for_json(value: Any) -> str:
return str(value)
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
async def test_qdrant_register_and_unregister_vector_db(
qdrant_adapter: QdrantVectorIOAdapter,
mock_vector_db,
sample_chunks,
) -> None:
# Initially, no collections
vector_db_id = mock_vector_db.identifier
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
# Register does not create a collection
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
await qdrant_adapter.register_vector_db(mock_vector_db)
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
# First insert creates the collection
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
assert await qdrant_adapter.client.collection_exists(vector_db_id)
# Unregister deletes the collection
await qdrant_adapter.unregister_vector_db(vector_db_id)
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
assert len((await qdrant_adapter.client.get_collections()).collections) == 0

View file

@@ -6,16 +6,23 @@
import json
import time
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, patch
import numpy as np
import pytest
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX
from llama_stack.apis.vector_io import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
QueryChunksResponse,
VectorStoreChunkingStrategyAuto,
VectorStoreFileObject,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
# This test is a unit test for the inline VectoerIO providers. This should only contain
# This test is a unit test for the inline VectorIO providers. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
@@ -25,6 +32,16 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREF
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
@pytest.fixture(autouse=True)
def mock_resume_file_batches(request):
"""Mock the resume functionality to prevent stale file batches from being processed during tests."""
with patch(
"llama_stack.providers.utils.memory.openai_vector_store_mixin.OpenAIVectorStoreMixin._resume_incomplete_batches",
new_callable=AsyncMock,
):
yield
async def test_initialize_index(vector_index):
await vector_index.initialize()
@@ -88,12 +105,8 @@ async def test_register_and_unregister_vector_db(vector_io_adapter):
async def test_query_unregistered_raises(vector_io_adapter, vector_provider):
fake_emb = np.zeros(8, dtype=np.float32)
if vector_provider == "chroma":
with pytest.raises(AttributeError):
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
else:
with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
@@ -294,3 +307,657 @@ async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, t
assert loaded_file_info == {}
loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
assert loaded_contents == []
async def test_create_vector_store_file_batch(vector_io_adapter):
"""Test creating a file batch."""
store_id = "vs_1234"
file_ids = ["file_1", "file_2", "file_3"]
# Setup vector store
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": {},
"file_ids": [],
}
# Mock attach method and batch processing to avoid actual processing
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
vector_io_adapter._process_file_batch_async = AsyncMock()
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
assert batch.vector_store_id == store_id
assert batch.status == "in_progress"
assert batch.file_counts.total == len(file_ids)
assert batch.file_counts.in_progress == len(file_ids)
assert batch.id in vector_io_adapter.openai_file_batches
async def test_retrieve_vector_store_file_batch(vector_io_adapter):
"""Test retrieving a file batch."""
store_id = "vs_1234"
file_ids = ["file_1", "file_2"]
# Setup vector store
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": {},
"file_ids": [],
}
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
# Create batch first
created_batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Retrieve batch
retrieved_batch = await vector_io_adapter.openai_retrieve_vector_store_file_batch(
batch_id=created_batch.id,
vector_store_id=store_id,
)
assert retrieved_batch.id == created_batch.id
assert retrieved_batch.vector_store_id == store_id
assert retrieved_batch.status == "in_progress"
async def test_cancel_vector_store_file_batch(vector_io_adapter):
"""Test cancelling a file batch."""
store_id = "vs_1234"
file_ids = ["file_1"]
# Setup vector store
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": {},
"file_ids": [],
}
# Mock both file attachment and batch processing to prevent automatic completion
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
vector_io_adapter._process_file_batch_async = AsyncMock()
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Cancel batch
cancelled_batch = await vector_io_adapter.openai_cancel_vector_store_file_batch(
batch_id=batch.id,
vector_store_id=store_id,
)
assert cancelled_batch.status == "cancelled"
async def test_list_files_in_vector_store_file_batch(vector_io_adapter):
"""Test listing files in a batch."""
store_id = "vs_1234"
file_ids = ["file_1", "file_2"]
# Setup vector store with files
files = {}
for i, file_id in enumerate(file_ids):
files[file_id] = VectorStoreFileObject(
id=file_id,
object="vector_store.file",
usage_bytes=1000,
created_at=int(time.time()) + i,
vector_store_id=store_id,
status="completed",
chunking_strategy=VectorStoreChunkingStrategyAuto(),
)
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": files,
"file_ids": file_ids,
}
# Mock file loading
vector_io_adapter._load_openai_vector_store_file = AsyncMock(
side_effect=lambda vs_id, f_id: files[f_id].model_dump()
)
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# List files
response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
batch_id=batch.id,
vector_store_id=store_id,
)
assert len(response.data) == len(file_ids)
assert response.first_id is not None
assert response.last_id is not None
async def test_file_batch_validation_errors(vector_io_adapter):
"""Test file batch validation errors."""
# Test nonexistent vector store
with pytest.raises(VectorStoreNotFoundError):
await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id="nonexistent",
params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"]),
)
# Setup store for remaining tests
store_id = "vs_test"
vector_io_adapter.openai_vector_stores[store_id] = {"id": store_id, "files": {}, "file_ids": []}
# Test nonexistent batch
with pytest.raises(ValueError, match="File batch .* not found"):
await vector_io_adapter.openai_retrieve_vector_store_file_batch(
batch_id="nonexistent_batch",
vector_store_id=store_id,
)
# Test wrong vector store for batch
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
)
# Create wrong_store so it exists but the batch doesn't belong to it
wrong_store_id = "wrong_store"
vector_io_adapter.openai_vector_stores[wrong_store_id] = {"id": wrong_store_id, "files": {}, "file_ids": []}
with pytest.raises(ValueError, match="does not belong to vector store"):
await vector_io_adapter.openai_retrieve_vector_store_file_batch(
batch_id=batch.id,
vector_store_id=wrong_store_id,
)
async def test_file_batch_pagination(vector_io_adapter):
"""Test file batch pagination."""
store_id = "vs_1234"
file_ids = ["file_1", "file_2", "file_3", "file_4", "file_5"]
# Setup vector store with multiple files
files = {}
for i, file_id in enumerate(file_ids):
files[file_id] = VectorStoreFileObject(
id=file_id,
object="vector_store.file",
usage_bytes=1000,
created_at=int(time.time()) + i,
vector_store_id=store_id,
status="completed",
chunking_strategy=VectorStoreChunkingStrategyAuto(),
)
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": files,
"file_ids": file_ids,
}
# Mock file loading
vector_io_adapter._load_openai_vector_store_file = AsyncMock(
side_effect=lambda vs_id, f_id: files[f_id].model_dump()
)
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Test pagination with limit
response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
batch_id=batch.id,
vector_store_id=store_id,
limit=3,
)
assert len(response.data) == 3
assert response.has_more is True
# Test pagination with after cursor
first_page = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
batch_id=batch.id,
vector_store_id=store_id,
limit=2,
)
second_page = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
batch_id=batch.id,
vector_store_id=store_id,
limit=2,
after=first_page.last_id,
)
assert len(first_page.data) == 2
assert len(second_page.data) == 2
# Ensure no overlap between pages
first_page_ids = {file_obj.id for file_obj in first_page.data}
second_page_ids = {file_obj.id for file_obj in second_page.data}
assert first_page_ids.isdisjoint(second_page_ids)
# Verify we got all expected files across both pages (in desc order: file_5, file_4, file_3, file_2, file_1)
all_returned_ids = first_page_ids | second_page_ids
assert all_returned_ids == {"file_2", "file_3", "file_4", "file_5"}
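    # file_1 is the oldest entry, so with descending order and two pages of two it would only
    # appear on a third page; its absence here is expected.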
async def test_file_batch_status_filtering(vector_io_adapter):
"""Test file batch status filtering."""
store_id = "vs_1234"
file_ids = ["file_1", "file_2", "file_3"]
# Setup vector store with files having different statuses
files = {}
statuses = ["completed", "in_progress", "completed"]
for i, (file_id, status) in enumerate(zip(file_ids, statuses, strict=False)):
files[file_id] = VectorStoreFileObject(
id=file_id,
object="vector_store.file",
usage_bytes=1000,
created_at=int(time.time()) + i,
vector_store_id=store_id,
status=status,
chunking_strategy=VectorStoreChunkingStrategyAuto(),
)
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": files,
"file_ids": file_ids,
}
# Mock file loading
vector_io_adapter._load_openai_vector_store_file = AsyncMock(
side_effect=lambda vs_id, f_id: files[f_id].model_dump()
)
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Test filtering by completed status
response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
batch_id=batch.id,
vector_store_id=store_id,
filter="completed",
)
assert len(response.data) == 2 # Only 2 completed files
for file_obj in response.data:
assert file_obj.status == "completed"
# Test filtering by in_progress status
response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
batch_id=batch.id,
vector_store_id=store_id,
filter="in_progress",
)
assert len(response.data) == 1 # Only 1 in_progress file
assert response.data[0].status == "in_progress"
async def test_cancel_completed_batch_fails(vector_io_adapter):
"""Test that cancelling completed batch fails."""
store_id = "vs_1234"
file_ids = ["file_1"]
# Setup vector store
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": {},
"file_ids": [],
}
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Manually update status to completed
batch_info = vector_io_adapter.openai_file_batches[batch.id]
batch_info["status"] = "completed"
# Try to cancel - should fail
with pytest.raises(ValueError, match="Cannot cancel batch .* with status completed"):
await vector_io_adapter.openai_cancel_vector_store_file_batch(
batch_id=batch.id,
vector_store_id=store_id,
)
async def test_file_batch_persistence_across_restarts(vector_io_adapter):
"""Test that in-progress file batches are persisted and resumed after restart."""
store_id = "vs_1234"
file_ids = ["file_1", "file_2"]
# Setup vector store
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": {},
"file_ids": [],
}
# Mock attach method and batch processing to avoid actual processing
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
vector_io_adapter._process_file_batch_async = AsyncMock()
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
batch_id = batch.id
# Verify batch is saved to persistent storage
assert batch_id in vector_io_adapter.openai_file_batches
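    # The key below assumes the mixin persists batches under the
    # "openai_vector_stores_file_batches:v3" prefix followed by "::<batch_id>"; update it if
    # that key layout changes.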
saved_batch_key = f"openai_vector_stores_file_batches:v3::{batch_id}"
saved_batch = await vector_io_adapter.kvstore.get(saved_batch_key)
assert saved_batch is not None
# Verify the saved batch data contains all necessary information
saved_data = json.loads(saved_batch)
assert saved_data["id"] == batch_id
assert saved_data["status"] == "in_progress"
assert saved_data["file_ids"] == file_ids
# Simulate restart - clear in-memory cache and reload from persistence
vector_io_adapter.openai_file_batches.clear()
# Temporarily restore the real initialize_openai_vector_stores method
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
real_method = OpenAIVectorStoreMixin.initialize_openai_vector_stores
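    # Calling the unbound mixin method with the adapter as `self` re-runs the real
    # initialization path (including reloading persisted file batches) despite the mocks above.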
await real_method(vector_io_adapter)
# Re-mock the processing method to prevent any resumed batches from processing
vector_io_adapter._process_file_batch_async = AsyncMock()
# Verify batch was restored
assert batch_id in vector_io_adapter.openai_file_batches
restored_batch = vector_io_adapter.openai_file_batches[batch_id]
assert restored_batch["status"] == "in_progress"
assert restored_batch["id"] == batch_id
assert vector_io_adapter.openai_file_batches[batch_id]["file_ids"] == file_ids
async def test_cancelled_batch_persists_in_storage(vector_io_adapter):
"""Test that cancelled batches persist in storage with updated status."""
store_id = "vs_1234"
file_ids = ["file_1", "file_2"]
# Setup vector store
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": {},
"file_ids": [],
}
# Mock attach method and batch processing to avoid actual processing
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
vector_io_adapter._process_file_batch_async = AsyncMock()
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
batch_id = batch.id
# Verify batch is initially saved to persistent storage
saved_batch_key = f"openai_vector_stores_file_batches:v3::{batch_id}"
saved_batch = await vector_io_adapter.kvstore.get(saved_batch_key)
assert saved_batch is not None
# Cancel the batch
cancelled_batch = await vector_io_adapter.openai_cancel_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=store_id,
)
# Verify batch status is cancelled
assert cancelled_batch.status == "cancelled"
# Verify batch persists in storage with cancelled status
updated_batch = await vector_io_adapter.kvstore.get(saved_batch_key)
assert updated_batch is not None
batch_data = json.loads(updated_batch)
assert batch_data["status"] == "cancelled"
# Batch should remain in memory cache (matches vector store pattern)
assert batch_id in vector_io_adapter.openai_file_batches
assert vector_io_adapter.openai_file_batches[batch_id]["status"] == "cancelled"
async def test_only_in_progress_batches_resumed(vector_io_adapter):
"""Test that only in-progress batches are resumed for processing, but all batches are persisted."""
store_id = "vs_1234"
# Setup vector store
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": {},
"file_ids": [],
}
# Mock attach method and batch processing to prevent automatic completion
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
vector_io_adapter._process_file_batch_async = AsyncMock()
# Create multiple batches
batch1 = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
)
batch2 = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_2"])
)
# Complete one batch (should persist with completed status)
batch1_info = vector_io_adapter.openai_file_batches[batch1.id]
batch1_info["status"] = "completed"
await vector_io_adapter._save_openai_vector_store_file_batch(batch1.id, batch1_info)
# Cancel the other batch (should persist with cancelled status)
await vector_io_adapter.openai_cancel_vector_store_file_batch(batch_id=batch2.id, vector_store_id=store_id)
# Create a third batch that stays in progress
batch3 = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_3"])
)
# Simulate restart - clear memory and reload from persistence
vector_io_adapter.openai_file_batches.clear()
# Temporarily restore the real initialize_openai_vector_stores method
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
real_method = OpenAIVectorStoreMixin.initialize_openai_vector_stores
await real_method(vector_io_adapter)
# All batches should be restored from persistence
assert batch1.id in vector_io_adapter.openai_file_batches # completed, persisted
assert batch2.id in vector_io_adapter.openai_file_batches # cancelled, persisted
assert batch3.id in vector_io_adapter.openai_file_batches # in-progress, restored
# Check their statuses
assert vector_io_adapter.openai_file_batches[batch1.id]["status"] == "completed"
assert vector_io_adapter.openai_file_batches[batch2.id]["status"] == "cancelled"
assert vector_io_adapter.openai_file_batches[batch3.id]["status"] == "in_progress"
# Resume functionality is mocked, so we're only testing persistence
async def test_cleanup_expired_file_batches(vector_io_adapter):
"""Test that expired file batches are cleaned up properly."""
store_id = "vs_1234"
# Setup vector store
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": {},
"file_ids": [],
}
# Mock processing to prevent automatic completion
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
vector_io_adapter._process_file_batch_async = AsyncMock()
# Create batches with different ages
import time
current_time = int(time.time())
# Create an old expired batch (10 days old)
old_batch_info = {
"id": "batch_old",
"vector_store_id": store_id,
"status": "completed",
"created_at": current_time - (10 * 24 * 60 * 60), # 10 days ago
"expires_at": current_time - (3 * 24 * 60 * 60), # Expired 3 days ago
"file_ids": ["file_1"],
}
# Create a recent valid batch
new_batch_info = {
"id": "batch_new",
"vector_store_id": store_id,
"status": "completed",
"created_at": current_time - (1 * 24 * 60 * 60), # 1 day ago
"expires_at": current_time + (6 * 24 * 60 * 60), # Expires in 6 days
"file_ids": ["file_2"],
}
# Store both batches in persistent storage
await vector_io_adapter._save_openai_vector_store_file_batch("batch_old", old_batch_info)
await vector_io_adapter._save_openai_vector_store_file_batch("batch_new", new_batch_info)
# Add to in-memory cache
vector_io_adapter.openai_file_batches["batch_old"] = old_batch_info
vector_io_adapter.openai_file_batches["batch_new"] = new_batch_info
# Verify both batches exist before cleanup
assert "batch_old" in vector_io_adapter.openai_file_batches
assert "batch_new" in vector_io_adapter.openai_file_batches
# Run cleanup
await vector_io_adapter._cleanup_expired_file_batches()
# Verify expired batch was removed from memory
assert "batch_old" not in vector_io_adapter.openai_file_batches
assert "batch_new" in vector_io_adapter.openai_file_batches
# Verify expired batch was removed from storage
old_batch_key = "openai_vector_stores_file_batches:v3::batch_old"
new_batch_key = "openai_vector_stores_file_batches:v3::batch_new"
old_stored = await vector_io_adapter.kvstore.get(old_batch_key)
new_stored = await vector_io_adapter.kvstore.get(new_batch_key)
assert old_stored is None # Expired batch should be deleted
assert new_stored is not None # Valid batch should remain
async def test_expired_batch_access_error(vector_io_adapter):
"""Test that accessing expired batches returns clear error message."""
store_id = "vs_1234"
# Setup vector store
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": {},
"file_ids": [],
}
# Create an expired batch
import time
current_time = int(time.time())
expired_batch_info = {
"id": "batch_expired",
"vector_store_id": store_id,
"status": "completed",
"created_at": current_time - (10 * 24 * 60 * 60), # 10 days ago
"expires_at": current_time - (3 * 24 * 60 * 60), # Expired 3 days ago
"file_ids": ["file_1"],
}
# Add to in-memory cache (simulating it was loaded before expiration)
vector_io_adapter.openai_file_batches["batch_expired"] = expired_batch_info
# Try to access expired batch
with pytest.raises(ValueError, match="File batch batch_expired has expired after 7 days from creation"):
vector_io_adapter._get_and_validate_batch("batch_expired", store_id)
async def test_max_concurrent_files_per_batch(vector_io_adapter):
"""Test that file batch processing respects MAX_CONCURRENT_FILES_PER_BATCH limit."""
import asyncio
store_id = "vs_1234"
# Setup vector store
vector_io_adapter.openai_vector_stores[store_id] = {
"id": store_id,
"name": "Test Store",
"files": {},
"file_ids": [],
}
active_files = 0
async def mock_attach_file_with_delay(vector_store_id: str, file_id: str, **kwargs):
"""Mock that tracks concurrency and blocks indefinitely to test concurrency limit."""
nonlocal active_files
active_files += 1
# Block indefinitely to test concurrency limit
await asyncio.sleep(float("inf"))
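        # (These coroutines never return on their own; they are assumed to be cancelled when the
        # test's event loop is torn down.)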
# Replace the attachment method
vector_io_adapter.openai_attach_file_to_vector_store = mock_attach_file_with_delay
# Create a batch with more files than the concurrency limit
file_ids = [f"file_{i}" for i in range(8)] # 8 files, but limit should be 5
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Give time for the semaphore logic to start processing files
await asyncio.sleep(0.2)
# Verify that only MAX_CONCURRENT_FILES_PER_BATCH files are processing concurrently
# The semaphore in _process_files_with_concurrency should limit this
from llama_stack.providers.utils.memory.openai_vector_store_mixin import MAX_CONCURRENT_FILES_PER_BATCH
assert active_files == MAX_CONCURRENT_FILES_PER_BATCH, (
f"Expected {MAX_CONCURRENT_FILES_PER_BATCH} active files, got {active_files}"
)
# Verify batch is in progress
assert batch.status == "in_progress"
assert batch.file_counts.total == 8
assert batch.file_counts.in_progress == 8

View file

@@ -13,7 +13,10 @@ from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from llama_stack.apis.inference.inference import OpenAIEmbeddingData
from llama_stack.apis.inference.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import (
@@ -226,9 +229,14 @@ class TestVectorDBWithIndex:
await vector_db_with_index.insert_chunks(chunks)
mock_inference_api.openai_embeddings.assert_called_once_with(
"test-model without embeddings", ["Test 1", "Test 2"]
)
# Verify openai_embeddings was called with correct params
mock_inference_api.openai_embeddings.assert_called_once()
call_args = mock_inference_api.openai_embeddings.call_args[0]
assert len(call_args) == 1
params = call_args[0]
assert isinstance(params, OpenAIEmbeddingsRequestWithExtraBody)
assert params.model == "test-model without embeddings"
assert params.input == ["Test 1", "Test 2"]
mock_index.add_chunks.assert_called_once()
args = mock_index.add_chunks.call_args[0]
assert args[0] == chunks
@@ -321,9 +329,14 @@ class TestVectorDBWithIndex:
await vector_db_with_index.insert_chunks(chunks)
mock_inference_api.openai_embeddings.assert_called_once_with(
"test-model with partial embeddings", ["Test 1", "Test 3"]
)
# Verify openai_embeddings was called with correct params
mock_inference_api.openai_embeddings.assert_called_once()
call_args = mock_inference_api.openai_embeddings.call_args[0]
assert len(call_args) == 1
params = call_args[0]
assert isinstance(params, OpenAIEmbeddingsRequestWithExtraBody)
assert params.model == "test-model with partial embeddings"
assert params.input == ["Test 1", "Test 3"]
mock_index.add_chunks.assert_called_once()
args = mock_index.add_chunks.call_args[0]
assert len(args[0]) == 3

View file

@@ -9,6 +9,7 @@ import pytest
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.core.datatypes import VectorDBWithOwner
from llama_stack.core.store.registry import (
KEY_FORMAT,
CachedDiskDistributionRegistry,
@@ -116,7 +117,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
provider_resource_id="test_vector_db_2",
provider_id="baz",
)
await cached_disk_dist_registry.register(original_vector_db)
assert await cached_disk_dist_registry.register(original_vector_db)
duplicate_vector_db = VectorDB(
identifier="test_vector_db_2",
@@ -125,7 +126,8 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
provider_resource_id="test_vector_db_2",
provider_id="baz", # Same provider_id
)
await cached_disk_dist_registry.register(duplicate_vector_db)
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"):
await cached_disk_dist_registry.register(duplicate_vector_db)
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result is not None
@@ -229,3 +231,98 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
assert invalid_obj is None
async def test_double_registration_identical_objects(disk_dist_registry):
"""Test that registering identical objects succeeds (idempotent)."""
vector_db = VectorDBWithOwner(
identifier="test_vector_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_id="test-provider",
)
# First registration should succeed
result1 = await disk_dist_registry.register(vector_db)
assert result1 is True
# Second registration of identical object should also succeed (idempotent)
result2 = await disk_dist_registry.register(vector_db)
assert result2 is True
# Verify object exists and is unchanged
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
assert retrieved is not None
assert retrieved.identifier == vector_db.identifier
assert retrieved.embedding_model == vector_db.embedding_model
async def test_double_registration_different_objects(disk_dist_registry):
"""Test that registering different objects with same identifier fails."""
vector_db1 = VectorDBWithOwner(
identifier="test_vector_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_id="test-provider",
)
vector_db2 = VectorDBWithOwner(
identifier="test_vector_db", # Same identifier
embedding_model="different-model", # Different embedding model
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_id="test-provider",
)
# First registration should succeed
result1 = await disk_dist_registry.register(vector_db1)
assert result1 is True
# Second registration with different data should fail
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"):
await disk_dist_registry.register(vector_db2)
# Verify original object is unchanged
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
assert retrieved is not None
assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value
async def test_double_registration_with_cache(cached_disk_dist_registry):
"""Test double registration behavior with caching enabled."""
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import ModelWithOwner
model1 = ModelWithOwner(
identifier="test_model",
provider_resource_id="test_model",
provider_id="test-provider",
model_type=ModelType.llm,
)
model2 = ModelWithOwner(
identifier="test_model", # Same identifier
provider_resource_id="test_model",
provider_id="test-provider",
model_type=ModelType.embedding, # Different type
)
# First registration should succeed and populate cache
result1 = await cached_disk_dist_registry.register(model1)
assert result1 is True
# Verify in cache
cached_model = cached_disk_dist_registry.get_cached("model", "test_model")
assert cached_model is not None
assert cached_model.model_type == ModelType.llm
# Second registration with different data should fail
with pytest.raises(ValueError, match="Object of type 'model' and identifier 'test_model' already exists"):
await cached_disk_dist_registry.register(model2)
# Cache should still contain original model
cached_model_after = cached_disk_dist_registry.get_cached("model", "test_model")
assert cached_model_after is not None
assert cached_model_after.model_type == ModelType.llm

View file

@@ -122,7 +122,7 @@ def mock_impls():
@pytest.fixture
def scope_middleware_with_mocks(mock_auth_endpoint):
def middleware_with_mocks(mock_auth_endpoint):
"""Create AuthenticationMiddleware with mocked route implementations"""
mock_app = AsyncMock()
auth_config = AuthenticationConfig(
@@ -137,18 +137,20 @@ def scope_middleware_with_mocks(mock_auth_endpoint):
# Mock the route_impls to simulate finding routes with required scopes
from llama_stack.schema_utils import WebMethod
scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read")
public_webmethod = WebMethod(route="/test/public", method="GET")
routes = {
("POST", "/test/scoped"): WebMethod(route="/test/scoped", method="POST", required_scope="test.read"),
("GET", "/test/public"): WebMethod(route="/test/public", method="GET"),
("GET", "/health"): WebMethod(route="/health", method="GET", require_authentication=False),
("GET", "/version"): WebMethod(route="/version", method="GET", require_authentication=False),
("GET", "/models/list"): WebMethod(route="/models/list", method="GET", require_authentication=True),
}
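    # (require_authentication=False is assumed to mark endpoints the middleware lets through
    # without a token; this is exercised by the unauthenticated-endpoint tests below.)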
# Mock the route finding logic
def mock_find_matching_route(method, path, route_impls):
if method == "POST" and path == "/test/scoped":
return None, {}, "/test/scoped", scoped_webmethod
elif method == "GET" and path == "/test/public":
return None, {}, "/test/public", public_webmethod
else:
raise ValueError("No matching route")
webmethod = routes.get((method, path))
if webmethod:
return None, {}, path, webmethod
raise ValueError("No matching route")
import llama_stack.core.server.auth
@@ -659,9 +661,9 @@ def test_valid_introspection_with_custom_mapping_authentication(
# Scope-based authorization tests
@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope)
async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key):
async def test_scope_authorization_success(middleware_with_mocks, valid_api_key):
"""Test that user with required scope can access protected endpoint"""
middleware, mock_app = scope_middleware_with_mocks
middleware, mock_app = middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
@@ -680,9 +682,9 @@ async def test_scope_authorization_success(scope_middleware_with_mocks, valid_ap
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key):
async def test_scope_authorization_denied(middleware_with_mocks, valid_api_key):
"""Test that user without required scope gets 403 access denied"""
middleware, mock_app = scope_middleware_with_mocks
middleware, mock_app = middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
@@ -710,9 +712,9 @@ async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key):
async def test_public_endpoint_no_scope_required(middleware_with_mocks, valid_api_key):
"""Test that public endpoints work without specific scopes"""
middleware, mock_app = scope_middleware_with_mocks
middleware, mock_app = middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
@@ -730,9 +732,9 @@ async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, va
mock_send.assert_not_called()
async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks):
async def test_scope_authorization_no_auth_disabled(middleware_with_mocks):
"""Test that when auth is disabled (no user), scope checks are bypassed"""
middleware, mock_app = scope_middleware_with_mocks
middleware, mock_app = middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
@@ -907,3 +909,41 @@ def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mo
request_body = call_args[1]["json"]
assert request_body["apiVersion"] == "authentication.k8s.io/v1"
assert request_body["kind"] == "SelfSubjectReview"
async def test_unauthenticated_endpoint_access_health(middleware_with_mocks):
"""Test that /health endpoints can be accessed without authentication"""
middleware, mock_app = middleware_with_mocks
# Test request to /health without auth header (level prefix v1 is added by router)
scope = {"type": "http", "path": "/health", "headers": [], "method": "GET"}
receive = AsyncMock()
send = AsyncMock()
# Should allow the request to proceed without authentication
await middleware(scope, receive, send)
# Verify that the request was passed to the app
mock_app.assert_called_once_with(scope, receive, send)
# Verify that no error response was sent
assert not any(call[0][0].get("status") == 401 for call in send.call_args_list)
async def test_unauthenticated_endpoint_denied_for_other_paths(middleware_with_mocks):
"""Test that endpoints other than /health and /version require authentication"""
middleware, mock_app = middleware_with_mocks
# Test request to /models/list without auth header
scope = {"type": "http", "path": "/models/list", "headers": [], "method": "GET"}
receive = AsyncMock()
send = AsyncMock()
# Should return 401 error
await middleware(scope, receive, send)
# Verify that the app was NOT called
mock_app.assert_not_called()
# Verify that a 401 error response was sent
assert any(call[0][0].get("status") == 401 for call in send.call_args_list)

View file

@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
async def test_memory_kvstore_persistence_behavior():
"""Test that :memory: database doesn't persist across instances."""
config = SqliteKVStoreConfig(db_path=":memory:")
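    # Each new sqlite connection to ":memory:" gets its own private database, so data written by
    # one store instance should not be visible to another.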
# First instance
store1 = SqliteKVStoreImpl(config)
await store1.initialize()
await store1.set("persist_test", "should_not_persist")
await store1.shutdown()
# Second instance with same config
store2 = SqliteKVStoreImpl(config)
await store2.initialize()
# Data should not be present
result = await store2.get("persist_test")
assert result is None
await store2.shutdown()