mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-25 09:05:37 +00:00
Merge remote-tracking branch 'origin/main' into stores
This commit is contained in:
commit
b72154ce5e
1161 changed files with 609896 additions and 42960 deletions
57
tests/unit/core/routers/test_vector_io.py
Normal file
57
tests/unit/core/routers/test_vector_io.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody
|
||||
from llama_stack.core.routers.vector_io import VectorIORouter
|
||||
|
||||
|
||||
async def test_single_provider_auto_selection():
|
||||
# provider_id automatically selected during vector store create() when only one provider available
|
||||
mock_routing_table = Mock()
|
||||
mock_routing_table.impls_by_provider_id = {"inline::faiss": "mock_provider"}
|
||||
mock_routing_table.get_all_with_type = AsyncMock(
|
||||
return_value=[
|
||||
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
|
||||
]
|
||||
)
|
||||
mock_routing_table.register_vector_db = AsyncMock(
|
||||
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
|
||||
)
|
||||
mock_routing_table.get_provider_impl = AsyncMock(
|
||||
return_value=Mock(openai_create_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
|
||||
)
|
||||
router = VectorIORouter(mock_routing_table)
|
||||
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
|
||||
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
|
||||
)
|
||||
|
||||
result = await router.openai_create_vector_store(request)
|
||||
assert result.id == "vs_123"
|
||||
|
||||
|
||||
async def test_create_vector_stores_multiple_providers_missing_provider_id_error():
|
||||
# if multiple providers are available, vector store create will error without provider_id
|
||||
mock_routing_table = Mock()
|
||||
mock_routing_table.impls_by_provider_id = {
|
||||
"inline::faiss": "mock_provider_1",
|
||||
"inline::sqlite-vec": "mock_provider_2",
|
||||
}
|
||||
mock_routing_table.get_all_with_type = AsyncMock(
|
||||
return_value=[
|
||||
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
|
||||
]
|
||||
)
|
||||
router = VectorIORouter(mock_routing_table)
|
||||
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
|
||||
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Multiple vector_io providers available"):
|
||||
await router.openai_create_vector_store(request)
|
||||
|
|
@ -17,7 +17,6 @@ from llama_stack.apis.datatypes import Api
|
|||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.shields.shields import Shield
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.core.datatypes import RegistryEntrySource
|
||||
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable
|
||||
|
|
@ -25,7 +24,6 @@ from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
|||
from llama_stack.core.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
||||
from llama_stack.core.routing_tables.shields import ShieldsRoutingTable
|
||||
from llama_stack.core.routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
|
||||
|
||||
|
||||
class Impl:
|
||||
|
|
@ -146,31 +144,6 @@ class ToolGroupsImpl(Impl):
|
|||
)
|
||||
|
||||
|
||||
class VectorDBImpl(Impl):
|
||||
def __init__(self):
|
||||
super().__init__(Api.vector_io)
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB):
|
||||
return vector_db
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str):
|
||||
return vector_db_id
|
||||
|
||||
async def openai_create_vector_store(self, **kwargs):
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
|
||||
|
||||
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
|
||||
return VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name=kwargs.get("name", vector_store_id),
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
|
||||
|
||||
async def test_models_routing_table(cached_disk_dist_registry):
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
|
@ -201,6 +174,12 @@ async def test_models_routing_table(cached_disk_dist_registry):
|
|||
non_existent = await table.get_object_by_identifier("model", "non-existent-model")
|
||||
assert non_existent is None
|
||||
|
||||
# Test has_model
|
||||
assert await table.has_model("test_provider/test-model")
|
||||
assert await table.has_model("test_provider/test-model-2")
|
||||
assert not await table.has_model("non-existent-model")
|
||||
assert not await table.has_model("test_provider/non-existent-model")
|
||||
|
||||
await table.unregister_model(model_id="test_provider/test-model")
|
||||
await table.unregister_model(model_id="test_provider/test-model-2")
|
||||
|
||||
|
|
@ -257,40 +236,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
|||
await table.unregister_shield(identifier="non-existent")
|
||||
|
||||
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
# Register multiple vector databases and verify listing
|
||||
vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
||||
vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
|
||||
assert len(vector_dbs.data) == 2
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
assert vdb1.identifier in vector_db_ids
|
||||
assert vdb2.identifier in vector_db_ids
|
||||
|
||||
# Verify they have UUID-based identifiers
|
||||
assert vdb1.identifier.startswith("vs_")
|
||||
assert vdb2.identifier.startswith("vs_")
|
||||
|
||||
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
|
||||
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 0
|
||||
|
||||
|
||||
async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
|
@ -348,6 +293,111 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
|||
assert len(scoring_functions_list_after_deletion.data) == 0
|
||||
|
||||
|
||||
async def test_double_registration_models_positive(cached_disk_dist_registry):
|
||||
"""Test that registering the same model twice with identical data succeeds."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a model
|
||||
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||
|
||||
# Register the exact same model again - should succeed (idempotent)
|
||||
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||
|
||||
# Verify only one model exists
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
assert models.data[0].identifier == "test_provider/test-model"
|
||||
|
||||
|
||||
async def test_double_registration_models_negative(cached_disk_dist_registry):
|
||||
"""Test that registering the same model with different data fails."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a model with specific metadata
|
||||
await table.register_model(model_id="test-model", provider_id="test_provider", metadata={"param1": "value1"})
|
||||
|
||||
# Try to register the same model with different metadata - should fail
|
||||
with pytest.raises(
|
||||
ValueError, match="Object of type 'model' and identifier 'test_provider/test-model' already exists"
|
||||
):
|
||||
await table.register_model(
|
||||
model_id="test-model", provider_id="test_provider", metadata={"param1": "different_value"}
|
||||
)
|
||||
|
||||
|
||||
async def test_double_registration_scoring_functions_positive(cached_disk_dist_registry):
|
||||
"""Test that registering the same scoring function twice with identical data succeeds."""
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a scoring function
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Test scoring function",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
# Register the exact same scoring function again - should succeed (idempotent)
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Test scoring function",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
# Verify only one scoring function exists
|
||||
scoring_functions = await table.list_scoring_functions()
|
||||
assert len(scoring_functions.data) == 1
|
||||
assert scoring_functions.data[0].identifier == "test-scoring-fn"
|
||||
|
||||
|
||||
async def test_double_registration_scoring_functions_negative(cached_disk_dist_registry):
|
||||
"""Test that registering the same scoring function with different data fails."""
|
||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a scoring function
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Test scoring function",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
# Try to register the same scoring function with different description - should fail
|
||||
with pytest.raises(
|
||||
ValueError, match="Object of type 'scoring_function' and identifier 'test-scoring-fn' already exists"
|
||||
):
|
||||
await table.register_scoring_function(
|
||||
scoring_fn_id="test-scoring-fn",
|
||||
provider_id="test_provider",
|
||||
description="Different description",
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
|
||||
async def test_double_registration_different_providers(cached_disk_dist_registry):
|
||||
"""Test that registering objects with same ID but different providers succeeds."""
|
||||
impl1 = InferenceImpl()
|
||||
impl2 = InferenceImpl()
|
||||
table = ModelsRoutingTable({"provider1": impl1, "provider2": impl2}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register same model ID with different providers - should succeed
|
||||
await table.register_model(model_id="shared-model", provider_id="provider1")
|
||||
await table.register_model(model_id="shared-model", provider_id="provider2")
|
||||
|
||||
# Verify both models exist with different identifiers
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
model_ids = {m.identifier for m in models.data}
|
||||
assert "provider1/shared-model" in model_ids
|
||||
assert "provider2/shared-model" in model_ids
|
||||
|
||||
|
||||
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
|
|
|||
|
|
@ -1,381 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Unit tests for the routing tables vector_dbs
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io.vector_io import (
|
||||
VectorStoreContent,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileCounts,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.access_control.datatypes import AccessRule, Scope
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
|
||||
from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
|
||||
|
||||
|
||||
class VectorDBImpl(Impl):
|
||||
def __init__(self):
|
||||
super().__init__(Api.vector_io)
|
||||
self.vector_stores = {}
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB):
|
||||
return vector_db
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str):
|
||||
return vector_db_id
|
||||
|
||||
async def openai_retrieve_vector_store(self, vector_store_id):
|
||||
return VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name="Test Store",
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
|
||||
async def openai_update_vector_store(self, vector_store_id, **kwargs):
|
||||
return VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name="Updated Store",
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store(self, vector_store_id):
|
||||
return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
|
||||
|
||||
async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
|
||||
return VectorStoreSearchResponsePage(
|
||||
object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
|
||||
)
|
||||
|
||||
async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
|
||||
return VectorStoreFileObject(
|
||||
id=file_id,
|
||||
status="completed",
|
||||
chunking_strategy={"type": "auto"},
|
||||
created_at=int(time.time()),
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
||||
async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
|
||||
return [
|
||||
VectorStoreFileObject(
|
||||
id="1",
|
||||
status="completed",
|
||||
chunking_strategy={"type": "auto"},
|
||||
created_at=int(time.time()),
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
]
|
||||
|
||||
async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
|
||||
return VectorStoreFileObject(
|
||||
id=file_id,
|
||||
status="completed",
|
||||
chunking_strategy={"type": "auto"},
|
||||
created_at=int(time.time()),
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
|
||||
return VectorStoreFileContentsResponse(
|
||||
file_id=file_id,
|
||||
filename="Sample File name",
|
||||
attributes={"key": "value"},
|
||||
content=[VectorStoreContent(type="text", text="Sample content")],
|
||||
)
|
||||
|
||||
async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
|
||||
return VectorStoreFileObject(
|
||||
id=file_id,
|
||||
status="completed",
|
||||
chunking_strategy={"type": "auto"},
|
||||
created_at=int(time.time()),
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
|
||||
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name=None,
|
||||
embedding_model=None,
|
||||
embedding_dimension=None,
|
||||
provider_id=None,
|
||||
provider_vector_db_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
|
||||
vector_store = VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name=name or vector_store_id,
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
self.vector_stores[vector_store_id] = vector_store
|
||||
return vector_store
|
||||
|
||||
async def openai_list_vector_stores(self, **kwargs):
|
||||
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
|
||||
|
||||
return VectorStoreListResponse(
|
||||
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
|
||||
)
|
||||
|
||||
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
n = 10
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
# Register multiple vector databases and verify listing
|
||||
vdb_dict = {}
|
||||
for i in range(n):
|
||||
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
|
||||
assert len(vector_dbs.data) == len(vdb_dict)
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
for k in vdb_dict:
|
||||
assert vdb_dict[k].identifier in vector_db_ids
|
||||
for k in vdb_dict:
|
||||
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 0
|
||||
|
||||
|
||||
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
|
||||
n = 10
|
||||
impl = VectorDBImpl()
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
vdb_dict = {}
|
||||
for i in range(n):
|
||||
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
|
||||
vector_stores = await impl.openai_list_vector_stores()
|
||||
vector_store_ids = {v.id for v in vector_stores.data}
|
||||
|
||||
assert vector_db_ids == vector_store_ids, (
|
||||
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
|
||||
)
|
||||
|
||||
for vector_store in vector_stores.data:
|
||||
vector_db = await table.get_vector_db(vector_store.id)
|
||||
assert vector_store.name == vector_db.vector_db_name, (
|
||||
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
|
||||
)
|
||||
|
||||
for vector_db_id in vector_db_ids:
|
||||
await table.unregister_vector_db(vector_db_id)
|
||||
|
||||
assert len((await table.list_vector_dbs()).data) == 0
|
||||
|
||||
|
||||
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
user_provided_id = "my-custom-vector-db"
|
||||
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
|
||||
|
||||
vector_stores = await impl.openai_list_vector_stores()
|
||||
assert len(vector_stores.data) == 1
|
||||
|
||||
vector_store = vector_stores.data[0]
|
||||
|
||||
assert vector_store.name == user_provided_id
|
||||
|
||||
assert vector_store.id.startswith("vs_")
|
||||
assert vector_store.id != user_provided_id
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 1
|
||||
assert vector_dbs.data[0].identifier == vector_store.id
|
||||
|
||||
await table.unregister_vector_db(vector_store.id)
|
||||
|
||||
|
||||
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
|
||||
authorized_table = "vs1"
|
||||
authorized_team = "team1"
|
||||
unauthorized_team = "team2"
|
||||
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
||||
authorized_table = registered_vdb.identifier # Use the actual generated ID
|
||||
|
||||
# Authorized reader
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
res = await table.openai_retrieve_vector_store(authorized_table)
|
||||
assert res == "OK"
|
||||
|
||||
# Authorized updater
|
||||
impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
|
||||
assert res == "UPDATED"
|
||||
|
||||
# Unauthorized reader
|
||||
unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
|
||||
with request_provider_data_context({}, unauthorized_user):
|
||||
with pytest.raises(ValueError):
|
||||
await table.openai_retrieve_vector_store(authorized_table)
|
||||
|
||||
# Unauthorized updater
|
||||
with request_provider_data_context({}, unauthorized_user):
|
||||
with pytest.raises(ValueError):
|
||||
await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
|
||||
|
||||
# Authorized deleter
|
||||
impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
|
||||
assert res == "DELETED"
|
||||
|
||||
# Unauthorized deleter
|
||||
with request_provider_data_context({}, unauthorized_user):
|
||||
with pytest.raises(ValueError):
|
||||
await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
|
||||
|
||||
|
||||
async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
|
||||
policy = [
|
||||
AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
|
||||
AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
|
||||
]
|
||||
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
|
||||
|
||||
vector_db_id = "vs1"
|
||||
file_id = "file-1"
|
||||
|
||||
admin_user = User(principal="admin", attributes={"roles": ["admin"]})
|
||||
read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
|
||||
no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})
|
||||
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
with request_provider_data_context({}, admin_user):
|
||||
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
||||
vector_db_id = registered_vdb.identifier # Use the actual generated ID
|
||||
|
||||
read_methods = [
|
||||
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
|
||||
(table.openai_search_vector_store, (vector_db_id, "query"), {}),
|
||||
(table.openai_list_files_in_vector_store, (vector_db_id,), {}),
|
||||
(table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
|
||||
(table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
|
||||
]
|
||||
update_methods = [
|
||||
(table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
|
||||
(table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
|
||||
(table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
|
||||
]
|
||||
delete_methods = [
|
||||
(table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
|
||||
(table.openai_delete_vector_store, (vector_db_id,), {}),
|
||||
]
|
||||
|
||||
for user in [admin_user, read_only_user]:
|
||||
with request_provider_data_context({}, user):
|
||||
for method, args, kwargs in read_methods:
|
||||
result = await method(*args, **kwargs)
|
||||
assert result is not None, f"Read operation failed with user {user.principal}"
|
||||
|
||||
with request_provider_data_context({}, no_access_user):
|
||||
for method, args, kwargs in read_methods:
|
||||
with pytest.raises(ValueError):
|
||||
await method(*args, **kwargs)
|
||||
|
||||
with request_provider_data_context({}, admin_user):
|
||||
for method, args, kwargs in update_methods:
|
||||
result = await method(*args, **kwargs)
|
||||
assert result is not None, "Update operation failed with admin user"
|
||||
|
||||
with request_provider_data_context({}, admin_user):
|
||||
for method, args, kwargs in delete_methods:
|
||||
result = await method(*args, **kwargs)
|
||||
assert result is not None, "Delete operation failed with admin user"
|
||||
|
||||
for user in [read_only_user, no_access_user]:
|
||||
with request_provider_data_context({}, user):
|
||||
for method, args, kwargs in delete_methods:
|
||||
with pytest.raises(ValueError):
|
||||
await method(*args, **kwargs)
|
||||
318
tests/unit/distribution/test_api_recordings.py
Normal file
318
tests/unit/distribution/test_api_recordings.py
Normal file
|
|
@ -0,0 +1,318 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Import the real Pydantic response types instead of using Mocks
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChoice,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.testing.api_recorder import (
|
||||
APIRecordingMode,
|
||||
ResponseStorage,
|
||||
api_recording,
|
||||
normalize_inference_request,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_storage_dir():
|
||||
"""Create a temporary directory for test recordings."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield Path(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_openai_chat_response():
|
||||
"""Real OpenAI chat completion response using proper Pydantic objects."""
|
||||
return OpenAIChatCompletion(
|
||||
id="chatcmpl-test123",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
index=0,
|
||||
message=OpenAIAssistantMessageParam(
|
||||
role="assistant", content="Hello! I'm doing well, thank you for asking."
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
created=1234567890,
|
||||
model="llama3.2:3b",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_embeddings_response():
|
||||
"""Real OpenAI embeddings response using proper Pydantic objects."""
|
||||
return OpenAIEmbeddingsResponse(
|
||||
object="list",
|
||||
data=[
|
||||
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
|
||||
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
|
||||
],
|
||||
model="nomic-embed-text",
|
||||
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
|
||||
)
|
||||
|
||||
|
||||
class TestInferenceRecording:
|
||||
"""Test the inference recording system."""
|
||||
|
||||
def test_request_normalization(self):
|
||||
"""Test that request normalization produces consistent hashes."""
|
||||
# Test basic normalization
|
||||
hash1 = normalize_inference_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
|
||||
)
|
||||
|
||||
# Same request should produce same hash
|
||||
hash2 = normalize_inference_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
|
||||
)
|
||||
|
||||
assert hash1 == hash2
|
||||
|
||||
# Different content should produce different hash
|
||||
hash3 = normalize_inference_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [{"role": "user", "content": "Different message"}],
|
||||
"temperature": 0.7,
|
||||
},
|
||||
)
|
||||
|
||||
assert hash1 != hash3
|
||||
|
||||
def test_request_normalization_edge_cases(self):
|
||||
"""Test request normalization is precise about request content."""
|
||||
# Test that different whitespace produces different hashes (no normalization)
|
||||
hash1 = normalize_inference_request(
|
||||
"POST",
|
||||
"http://test/v1/chat/completions",
|
||||
{},
|
||||
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
|
||||
)
|
||||
hash2 = normalize_inference_request(
|
||||
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
|
||||
)
|
||||
assert hash1 != hash2 # Different whitespace should produce different hashes
|
||||
|
||||
# Test that different float precision produces different hashes (no rounding)
|
||||
hash3 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
|
||||
hash4 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
|
||||
assert hash3 == hash4 # Small float precision differences should normalize to the same hash
|
||||
|
||||
# String-embedded decimals with excessive precision should also normalize.
|
||||
body_with_precise_scores = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "score: 0.7472640164649847",
|
||||
}
|
||||
]
|
||||
}
|
||||
body_with_precise_scores_variation = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "score: 0.74726414959878",
|
||||
}
|
||||
]
|
||||
}
|
||||
hash5 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, body_with_precise_scores)
|
||||
hash6 = normalize_inference_request(
|
||||
"POST", "http://test/v1/chat/completions", {}, body_with_precise_scores_variation
|
||||
)
|
||||
assert hash5 == hash6
|
||||
|
||||
body_with_close_scores = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "score: 0.662477492560699",
|
||||
}
|
||||
]
|
||||
}
|
||||
body_with_close_scores_variation = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "score: 0.6624775971970099",
|
||||
}
|
||||
]
|
||||
}
|
||||
hash7 = normalize_inference_request("POST", "http://test/v1/chat/completions", {}, body_with_close_scores)
|
||||
hash8 = normalize_inference_request(
|
||||
"POST", "http://test/v1/chat/completions", {}, body_with_close_scores_variation
|
||||
)
|
||||
assert hash7 == hash8
|
||||
|
||||
def test_response_storage(self, temp_storage_dir):
|
||||
"""Test the ResponseStorage class."""
|
||||
temp_storage_dir = temp_storage_dir / "test_response_storage"
|
||||
storage = ResponseStorage(temp_storage_dir)
|
||||
|
||||
# Test storing and retrieving a recording
|
||||
request_hash = "test_hash_123"
|
||||
request_data = {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/chat/completions",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b",
|
||||
}
|
||||
response_data = {"body": {"content": "test response"}, "is_streaming": False}
|
||||
|
||||
storage.store_recording(request_hash, request_data, response_data)
|
||||
|
||||
# Verify file storage and retrieval
|
||||
retrieved = storage.find_recording(request_hash)
|
||||
assert retrieved is not None
|
||||
assert retrieved["request"]["model"] == "llama3.2:3b"
|
||||
assert retrieved["response"]["body"]["content"] == "test response"
|
||||
|
||||
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||
"""Test that recording mode captures and stores responses."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_openai_chat_response
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_recording_mode"
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
# Verify the response was returned correctly
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
|
||||
# Verify recording was stored
|
||||
storage = ResponseStorage(temp_storage_dir)
|
||||
assert storage._get_test_dir().exists()
|
||||
|
||||
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||
"""Test that replay mode returns stored responses without making real calls."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_openai_chat_response
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_mode"
|
||||
# First, record a response
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
# Now test replay mode - should not call the original method
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
|
||||
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
# Verify we got the recorded response
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
|
||||
# Verify the original method was NOT called
|
||||
mock_create_patch.assert_not_called()
|
||||
|
||||
async def test_replay_missing_recording(self, temp_storage_dir):
|
||||
"""Test that replay mode fails when no recording is found."""
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
|
||||
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
with pytest.raises(RuntimeError, match="Recording not found"):
|
||||
await client.chat.completions.create(
|
||||
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
|
||||
)
|
||||
|
||||
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
|
||||
"""Test recording and replay of embeddings calls."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_embeddings_response
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
|
||||
# Record
|
||||
with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
|
||||
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.embeddings.create(
|
||||
model="nomic-embed-text", input=["Hello world", "Test embedding"]
|
||||
)
|
||||
|
||||
assert len(response.data) == 2
|
||||
|
||||
# Replay
|
||||
with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
|
||||
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.embeddings.create(
|
||||
model="nomic-embed-text", input=["Hello world", "Test embedding"]
|
||||
)
|
||||
|
||||
# Verify we got the recorded response
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
||||
|
||||
# Verify original method was not called
|
||||
mock_create_patch.assert_not_called()
|
||||
|
||||
async def test_live_mode(self, real_openai_chat_response):
|
||||
"""Test that live mode passes through to original methods."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_openai_chat_response
|
||||
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with api_recording(mode=APIRecordingMode.LIVE, storage_dir="foo"):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
|
||||
# Verify the response was returned
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
|
|
@ -390,3 +390,467 @@ pip_packages:
|
|||
assert provider.is_external is True
|
||||
# config_class is empty string in partial spec
|
||||
assert provider.config_class == ""
|
||||
|
||||
|
||||
class TestGetExternalProvidersFromModule:
|
||||
"""Test suite for installing external providers from module."""
|
||||
|
||||
def test_stackrunconfig_provider_without_module(self, mock_providers):
|
||||
"""Test that providers without module attribute are skipped."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect()
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="no_module",
|
||||
provider_type="no_module",
|
||||
config={},
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
# Should not add anything to registry
|
||||
assert len(result[Api.inference]) == 0
|
||||
|
||||
def test_stackrunconfig_with_version_spec(self, mock_providers):
|
||||
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
fake_spec = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="versioned_test",
|
||||
config_class="versioned_test.config.VersionedTestConfig",
|
||||
module="versioned_test==1.0.0",
|
||||
)
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "versioned_test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="versioned",
|
||||
provider_type="versioned_test",
|
||||
config={},
|
||||
module="versioned_test==1.0.0",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
assert "versioned_test" in result[Api.inference]
|
||||
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
|
||||
|
||||
def test_buildconfig_does_not_import_module(self, mock_providers):
|
||||
"""Test that BuildConfig does not import the module (building=True)."""
|
||||
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
build_config = BuildConfig(
|
||||
version=2,
|
||||
image_type="container",
|
||||
image_name="test_image",
|
||||
distribution_spec=DistributionSpec(
|
||||
description="test",
|
||||
providers={
|
||||
"inference": [
|
||||
BuildProvider(
|
||||
provider_type="build_test",
|
||||
module="build_test==1.0.0",
|
||||
)
|
||||
]
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Should not call import_module at all when building
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, build_config, building=True)
|
||||
|
||||
# Verify module was NOT imported
|
||||
mock_import.assert_not_called()
|
||||
|
||||
# Verify partial spec was created
|
||||
assert "build_test" in result[Api.inference]
|
||||
provider = result[Api.inference]["build_test"]
|
||||
assert provider.module == "build_test==1.0.0"
|
||||
assert provider.is_external is True
|
||||
assert provider.config_class == ""
|
||||
assert provider.api == Api.inference
|
||||
|
||||
def test_buildconfig_multiple_providers(self, mock_providers):
|
||||
"""Test BuildConfig with multiple providers for the same API."""
|
||||
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
build_config = BuildConfig(
|
||||
version=2,
|
||||
image_type="container",
|
||||
image_name="test_image",
|
||||
distribution_spec=DistributionSpec(
|
||||
description="test",
|
||||
providers={
|
||||
"inference": [
|
||||
BuildProvider(provider_type="provider1", module="provider1"),
|
||||
BuildProvider(provider_type="provider2", module="provider2"),
|
||||
]
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, build_config, building=True)
|
||||
|
||||
mock_import.assert_not_called()
|
||||
assert "provider1" in result[Api.inference]
|
||||
assert "provider2" in result[Api.inference]
|
||||
|
||||
def test_distributionspec_does_not_import_module(self, mock_providers):
|
||||
"""Test that DistributionSpec does not import the module (building=True)."""
|
||||
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
dist_spec = DistributionSpec(
|
||||
description="test distribution",
|
||||
providers={
|
||||
"inference": [
|
||||
BuildProvider(
|
||||
provider_type="dist_test",
|
||||
module="dist_test==2.0.0",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
# Should not call import_module at all when building
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, dist_spec, building=True)
|
||||
|
||||
# Verify module was NOT imported
|
||||
mock_import.assert_not_called()
|
||||
|
||||
# Verify partial spec was created
|
||||
assert "dist_test" in result[Api.inference]
|
||||
provider = result[Api.inference]["dist_test"]
|
||||
assert provider.module == "dist_test==2.0.0"
|
||||
assert provider.is_external is True
|
||||
assert provider.config_class == ""
|
||||
|
||||
def test_list_return_from_get_provider_spec(self, mock_providers):
|
||||
"""Test when get_provider_spec returns a list of specs."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
spec1 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="list_test",
|
||||
config_class="list_test.config.Config1",
|
||||
module="list_test",
|
||||
)
|
||||
spec2 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="list_test_remote",
|
||||
config_class="list_test.config.Config2",
|
||||
module="list_test",
|
||||
)
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "list_test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="list_test",
|
||||
provider_type="list_test",
|
||||
config={},
|
||||
module="list_test",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Only the matching provider_type should be added
|
||||
assert "list_test" in result[Api.inference]
|
||||
assert result[Api.inference]["list_test"].config_class == "list_test.config.Config1"
|
||||
|
||||
def test_list_return_filters_by_provider_type(self, mock_providers):
|
||||
"""Test that list return filters specs by provider_type."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
spec1 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="wanted",
|
||||
config_class="test.Config1",
|
||||
module="test",
|
||||
)
|
||||
spec2 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="unwanted",
|
||||
config_class="test.Config2",
|
||||
module="test",
|
||||
)
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="wanted",
|
||||
provider_type="wanted",
|
||||
config={},
|
||||
module="test",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Only the matching provider_type should be added
|
||||
assert "wanted" in result[Api.inference]
|
||||
assert "unwanted" not in result[Api.inference]
|
||||
|
||||
def test_list_return_adds_multiple_provider_types(self, mock_providers):
|
||||
"""Test that list return adds multiple different provider_types when config requests them."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
# Module returns both inline and remote variants
|
||||
spec1 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="remote::ollama",
|
||||
config_class="test.RemoteConfig",
|
||||
module="test",
|
||||
)
|
||||
spec2 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::ollama",
|
||||
config_class="test.InlineConfig",
|
||||
module="test",
|
||||
)
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="remote_ollama",
|
||||
provider_type="remote::ollama",
|
||||
config={},
|
||||
module="test",
|
||||
),
|
||||
Provider(
|
||||
provider_id="inline_ollama",
|
||||
provider_type="inline::ollama",
|
||||
config={},
|
||||
module="test",
|
||||
),
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Both provider types should be added to registry
|
||||
assert "remote::ollama" in result[Api.inference]
|
||||
assert "inline::ollama" in result[Api.inference]
|
||||
assert result[Api.inference]["remote::ollama"].config_class == "test.RemoteConfig"
|
||||
assert result[Api.inference]["inline::ollama"].config_class == "test.InlineConfig"
|
||||
|
||||
def test_module_not_found_raises_value_error(self, mock_providers):
|
||||
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "missing_module.provider":
|
||||
raise ModuleNotFoundError(name)
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="missing",
|
||||
provider_type="missing",
|
||||
config={},
|
||||
module="missing_module",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
assert "get_provider_spec not found" in str(exc_info.value)
|
||||
|
||||
def test_generic_exception_is_raised(self, mock_providers):
|
||||
"""Test that generic exceptions are properly raised."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def bad_spec():
|
||||
raise RuntimeError("Something went wrong")
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=bad_spec)
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "error_module.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="error",
|
||||
provider_type="error",
|
||||
config={},
|
||||
module="error_module",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
assert "Something went wrong" in str(exc_info.value)
|
||||
|
||||
def test_empty_provider_list(self, mock_providers):
|
||||
"""Test with empty provider list."""
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Should return registry unchanged
|
||||
assert result == registry
|
||||
assert len(result[Api.inference]) == 0
|
||||
|
||||
def test_multiple_apis_with_providers(self, mock_providers):
|
||||
"""Test multiple APIs with providers."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
inference_spec = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inf_test",
|
||||
config_class="inf.Config",
|
||||
module="inf_test",
|
||||
)
|
||||
safety_spec = ProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="safe_test",
|
||||
config_class="safe.Config",
|
||||
module="safe_test",
|
||||
)
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "inf_test.provider":
|
||||
return SimpleNamespace(get_provider_spec=lambda: inference_spec)
|
||||
elif name == "safe_test.provider":
|
||||
return SimpleNamespace(get_provider_spec=lambda: safety_spec)
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="inf",
|
||||
provider_type="inf_test",
|
||||
config={},
|
||||
module="inf_test",
|
||||
)
|
||||
],
|
||||
"safety": [
|
||||
Provider(
|
||||
provider_id="safe",
|
||||
provider_type="safe_test",
|
||||
config={},
|
||||
module="safe_test",
|
||||
)
|
||||
],
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}, Api.safety: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
assert "inf_test" in result[Api.inference]
|
||||
assert "safe_test" in result[Api.safety]
|
||||
|
|
|
|||
|
|
@ -1,382 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
from openai.types.model import Model as OpenAIModel
|
||||
|
||||
# Import the real Pydantic response types instead of using Mocks
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChoice,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.testing.inference_recorder import (
|
||||
InferenceMode,
|
||||
ResponseStorage,
|
||||
inference_recording,
|
||||
normalize_request,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_storage_dir():
|
||||
"""Create a temporary directory for test recordings."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield Path(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_openai_chat_response():
|
||||
"""Real OpenAI chat completion response using proper Pydantic objects."""
|
||||
return OpenAIChatCompletion(
|
||||
id="chatcmpl-test123",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
index=0,
|
||||
message=OpenAIAssistantMessageParam(
|
||||
role="assistant", content="Hello! I'm doing well, thank you for asking."
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
created=1234567890,
|
||||
model="llama3.2:3b",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_embeddings_response():
|
||||
"""Real OpenAI embeddings response using proper Pydantic objects."""
|
||||
return OpenAIEmbeddingsResponse(
|
||||
object="list",
|
||||
data=[
|
||||
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
|
||||
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
|
||||
],
|
||||
model="nomic-embed-text",
|
||||
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
|
||||
)
|
||||
|
||||
|
||||
class TestInferenceRecording:
|
||||
"""Test the inference recording system."""
|
||||
|
||||
def test_request_normalization(self):
|
||||
"""Test that request normalization produces consistent hashes."""
|
||||
# Test basic normalization
|
||||
hash1 = normalize_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
|
||||
)
|
||||
|
||||
# Same request should produce same hash
|
||||
hash2 = normalize_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
|
||||
)
|
||||
|
||||
assert hash1 == hash2
|
||||
|
||||
# Different content should produce different hash
|
||||
hash3 = normalize_request(
|
||||
"POST",
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
{},
|
||||
{
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [{"role": "user", "content": "Different message"}],
|
||||
"temperature": 0.7,
|
||||
},
|
||||
)
|
||||
|
||||
assert hash1 != hash3
|
||||
|
||||
def test_request_normalization_edge_cases(self):
|
||||
"""Test request normalization is precise about request content."""
|
||||
# Test that different whitespace produces different hashes (no normalization)
|
||||
hash1 = normalize_request(
|
||||
"POST",
|
||||
"http://test/v1/chat/completions",
|
||||
{},
|
||||
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
|
||||
)
|
||||
hash2 = normalize_request(
|
||||
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
|
||||
)
|
||||
assert hash1 != hash2 # Different whitespace should produce different hashes
|
||||
|
||||
# Test that different float precision produces different hashes (no rounding)
|
||||
hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
|
||||
hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
|
||||
assert hash3 != hash4 # Different precision should produce different hashes
|
||||
|
||||
def test_response_storage(self, temp_storage_dir):
|
||||
"""Test the ResponseStorage class."""
|
||||
temp_storage_dir = temp_storage_dir / "test_response_storage"
|
||||
storage = ResponseStorage(temp_storage_dir)
|
||||
|
||||
# Test storing and retrieving a recording
|
||||
request_hash = "test_hash_123"
|
||||
request_data = {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/chat/completions",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b",
|
||||
}
|
||||
response_data = {"body": {"content": "test response"}, "is_streaming": False}
|
||||
|
||||
storage.store_recording(request_hash, request_data, response_data)
|
||||
|
||||
# Verify file storage and retrieval
|
||||
retrieved = storage.find_recording(request_hash)
|
||||
assert retrieved is not None
|
||||
assert retrieved["request"]["model"] == "llama3.2:3b"
|
||||
assert retrieved["response"]["body"]["content"] == "test response"
|
||||
|
||||
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||
"""Test that recording mode captures and stores responses."""
|
||||
temp_storage_dir = temp_storage_dir / "test_recording_mode"
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
|
||||
# Verify the response was returned correctly
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
client.chat.completions._post.assert_called_once()
|
||||
|
||||
# Verify recording was stored
|
||||
storage = ResponseStorage(temp_storage_dir)
|
||||
dir = storage._get_test_dir()
|
||||
assert dir.exists()
|
||||
|
||||
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||
"""Test that replay mode returns stored responses without making real calls."""
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_mode"
|
||||
# First, record a response
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
client.chat.completions._post.assert_called_once()
|
||||
|
||||
# Now test replay mode - should not call the original method
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
# Verify we got the recorded response
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
|
||||
# Verify the original method was NOT called
|
||||
client.chat.completions._post.assert_not_called()
|
||||
|
||||
async def test_replay_mode_models(self, temp_storage_dir):
|
||||
"""Test that replay mode returns stored responses without making real model listing calls."""
|
||||
|
||||
async def _async_iterator(models):
|
||||
for model in models:
|
||||
yield model
|
||||
|
||||
models = [
|
||||
OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
|
||||
OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
|
||||
]
|
||||
|
||||
expected_ids = {m.id for m in models}
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_mode_models"
|
||||
|
||||
# baseline - mock works without recording
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_called_once()
|
||||
|
||||
# record the call
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_called_once()
|
||||
|
||||
# replay the call
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_not_called()
|
||||
|
||||
async def test_replay_missing_recording(self, temp_storage_dir):
|
||||
"""Test that replay mode fails when no recording is found."""
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
with pytest.raises(RuntimeError, match="No recorded response found"):
|
||||
await client.chat.completions.create(
|
||||
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
|
||||
)
|
||||
|
||||
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
|
||||
"""Test recording and replay of embeddings calls."""
|
||||
|
||||
# baseline - mock works without recording
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
|
||||
response = await client.embeddings.create(
|
||||
model=real_embeddings_response.model,
|
||||
input=["Hello world", "Test embedding"],
|
||||
encoding_format=NOT_GIVEN,
|
||||
)
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
||||
client.embeddings._post.assert_called_once()
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
|
||||
# Record
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
|
||||
|
||||
response = await client.embeddings.create(
|
||||
model=real_embeddings_response.model,
|
||||
input=["Hello world", "Test embedding"],
|
||||
encoding_format=NOT_GIVEN,
|
||||
dimensions=NOT_GIVEN,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
|
||||
assert len(response.data) == 2
|
||||
|
||||
# Replay
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
|
||||
|
||||
response = await client.embeddings.create(
|
||||
model=real_embeddings_response.model,
|
||||
input=["Hello world", "Test embedding"],
|
||||
)
|
||||
|
||||
# Verify we got the recorded response
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
||||
|
||||
# Verify original method was not called
|
||||
client.embeddings._post.assert_not_called()
|
||||
|
||||
async def test_completions_recording(self, temp_storage_dir):
|
||||
real_completions_response = OpenAICompletion(
|
||||
id="test_completion",
|
||||
object="text_completion",
|
||||
created=1234567890,
|
||||
model="llama3.2:3b",
|
||||
choices=[
|
||||
{
|
||||
"text": "Hello! I'm doing well, thank you for asking.",
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_completions_recording"
|
||||
|
||||
# baseline - mock works without recording
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.completions._post = AsyncMock(return_value=real_completions_response)
|
||||
response = await client.completions.create(
|
||||
model=real_completions_response.model,
|
||||
prompt="Hello, how are you?",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
assert response.choices[0].text == real_completions_response.choices[0].text
|
||||
client.completions._post.assert_called_once()
|
||||
|
||||
# Record
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.completions._post = AsyncMock(return_value=real_completions_response)
|
||||
|
||||
response = await client.completions.create(
|
||||
model=real_completions_response.model,
|
||||
prompt="Hello, how are you?",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
|
||||
assert response.choices[0].text == real_completions_response.choices[0].text
|
||||
client.completions._post.assert_called_once()
|
||||
|
||||
# Replay
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.completions._post = AsyncMock(return_value=real_completions_response)
|
||||
response = await client.completions.create(
|
||||
model=real_completions_response.model,
|
||||
prompt="Hello, how are you?",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
assert response.choices[0].text == real_completions_response.choices[0].text
|
||||
client.completions._post.assert_not_called()
|
||||
|
||||
async def test_live_mode(self, real_openai_chat_response):
|
||||
"""Test that live mode passes through to original methods."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_openai_chat_response
|
||||
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
|
||||
# Verify the response was returned
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
|
|
@ -15,6 +15,7 @@ from llama_stack.apis.agents import (
|
|||
AgentCreateResponse,
|
||||
)
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
|
||||
|
|
@ -33,6 +34,7 @@ def mock_apis():
|
|||
"safety_api": AsyncMock(spec=Safety),
|
||||
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
|
||||
"tool_groups_api": AsyncMock(spec=ToolGroups),
|
||||
"conversations_api": AsyncMock(spec=Conversations),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -59,7 +61,8 @@ async def agents_impl(config, mock_apis):
|
|||
mock_apis["safety_api"],
|
||||
mock_apis["tool_runtime_api"],
|
||||
mock_apis["tool_groups_api"],
|
||||
{},
|
||||
mock_apis["conversations_api"],
|
||||
[],
|
||||
)
|
||||
await impl.initialize()
|
||||
yield impl
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
ListOpenAIResponseInputItem,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
|
|
@ -32,13 +33,14 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import ResponsesStoreConfig
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
|
|
@ -82,9 +84,21 @@ def mock_vector_io_api():
|
|||
return vector_io_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversations_api():
|
||||
"""Mock conversations API for testing."""
|
||||
mock_api = AsyncMock()
|
||||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_responses_impl(
|
||||
mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
|
||||
mock_inference_api,
|
||||
mock_tool_groups_api,
|
||||
mock_tool_runtime_api,
|
||||
mock_responses_store,
|
||||
mock_vector_io_api,
|
||||
mock_conversations_api,
|
||||
):
|
||||
return OpenAIResponsesImpl(
|
||||
inference_api=mock_inference_api,
|
||||
|
|
@ -92,6 +106,7 @@ def openai_responses_impl(
|
|||
tool_runtime_api=mock_tool_runtime_api,
|
||||
responses_store=mock_responses_store,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -147,18 +162,24 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
chunks = [chunk async for chunk in result]
|
||||
|
||||
mock_inference_api.openai_chat_completion.assert_called_once_with(
|
||||
model=model,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||
response_format=None,
|
||||
tools=None,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=model,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||
response_format=None,
|
||||
tools=None,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
stream_options={
|
||||
"include_usage": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Should have content part events for text streaming
|
||||
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
|
||||
assert len(chunks) >= 4
|
||||
# Expected: response.created, response.in_progress, content_part.added, output_text.delta, content_part.done, response.completed
|
||||
assert len(chunks) >= 5
|
||||
assert chunks[0].type == "response.created"
|
||||
assert any(chunk.type == "response.in_progress" for chunk in chunks)
|
||||
|
||||
# Check for content part events
|
||||
content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
|
||||
|
|
@ -169,6 +190,14 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
|
||||
assert len(text_delta_events) >= 1, "Should have text delta events"
|
||||
|
||||
added_event = content_part_added_events[0]
|
||||
done_event = content_part_done_events[0]
|
||||
assert added_event.content_index == 0
|
||||
assert done_event.content_index == 0
|
||||
assert added_event.output_index == done_event.output_index == 0
|
||||
assert added_event.item_id == done_event.item_id
|
||||
assert added_event.response_id == done_event.response_id
|
||||
|
||||
# Verify final event is completion
|
||||
assert chunks[-1].type == "response.completed"
|
||||
|
||||
|
|
@ -177,6 +206,8 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
assert final_response.model == model
|
||||
assert len(final_response.output) == 1
|
||||
assert isinstance(final_response.output[0], OpenAIResponseMessage)
|
||||
assert final_response.output[0].id == added_event.item_id
|
||||
assert final_response.id == added_event.response_id
|
||||
|
||||
openai_responses_impl.responses_store.store_response_object.assert_called_once()
|
||||
assert final_response.output[0].content[0].text == "Dublin"
|
||||
|
|
@ -228,13 +259,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
|
|||
|
||||
# Verify
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?"
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == "What is the capital of Ireland?"
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
|
||||
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
||||
assert second_call.kwargs["messages"][-1].content == "Dublin"
|
||||
assert second_call.kwargs["temperature"] == 0.1
|
||||
second_params = second_call.args[0]
|
||||
assert second_params.messages[-1].content == "Dublin"
|
||||
assert second_params.temperature == 0.1
|
||||
|
||||
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
|
||||
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
|
||||
|
|
@ -303,36 +336,42 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
|||
chunks = [chunk async for chunk in result]
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 6
|
||||
# Should have: response.created, response.in_progress, output_item.added,
|
||||
# function_call_arguments.delta, function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 7
|
||||
|
||||
event_types = [chunk.type for chunk in chunks]
|
||||
assert event_types == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.delta",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
||||
# Check streaming events
|
||||
assert chunks[1].type == "response.output_item.added"
|
||||
assert chunks[2].type == "response.function_call_arguments.delta"
|
||||
assert chunks[3].type == "response.function_call_arguments.done"
|
||||
assert chunks[4].type == "response.output_item.done"
|
||||
|
||||
# Check response.completed event (should have the tool call)
|
||||
assert chunks[5].type == "response.completed"
|
||||
assert len(chunks[5].response.output) == 1
|
||||
assert chunks[5].response.output[0].type == "function_call"
|
||||
assert chunks[5].response.output[0].name == "get_weather"
|
||||
completed_chunk = chunks[-1]
|
||||
assert completed_chunk.type == "response.completed"
|
||||
assert len(completed_chunk.response.output) == 1
|
||||
assert completed_chunk.response.output[0].type == "function_call"
|
||||
assert completed_chunk.response.output[0].name == "get_weather"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_tool_call_function_arguments_none(openai_responses_impl, mock_inference_api):
|
||||
"""Test creating an OpenAI response with a tool call response that has a function that does not accept arguments, or arguments set to None when they are not mandatory."""
|
||||
# Setup
|
||||
"""Test creating an OpenAI response with tool calls that omit arguments."""
|
||||
|
||||
input_text = "What is the time right now?"
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
|
@ -359,9 +398,22 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
|
|||
object="chat.completion.chunk",
|
||||
)
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
def assert_common_expectations(chunks) -> None:
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
assert len(chunks[0].response.output) == 0
|
||||
completed_chunk = chunks[-1]
|
||||
assert completed_chunk.type == "response.completed"
|
||||
assert len(completed_chunk.response.output) == 1
|
||||
assert completed_chunk.response.output[0].type == "function_call"
|
||||
assert completed_chunk.response.output[0].name == "get_current_time"
|
||||
assert completed_chunk.response.output[0].arguments == "{}"
|
||||
|
||||
# Function does not accept arguments
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
|
|
@ -369,46 +421,23 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
|
|||
temperature=0.1,
|
||||
tools=[
|
||||
OpenAIResponseInputToolFunction(
|
||||
name="get_current_time",
|
||||
description="Get current time for system's timezone",
|
||||
parameters={},
|
||||
name="get_current_time", description="Get current time for system's timezone", parameters={}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 5
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
||||
# Check streaming events
|
||||
assert chunks[1].type == "response.output_item.added"
|
||||
assert chunks[2].type == "response.function_call_arguments.done"
|
||||
assert chunks[3].type == "response.output_item.done"
|
||||
|
||||
# Check response.completed event (should have the tool call with arguments set to "{}")
|
||||
assert chunks[4].type == "response.completed"
|
||||
assert len(chunks[4].response.output) == 1
|
||||
assert chunks[4].response.output[0].type == "function_call"
|
||||
assert chunks[4].response.output[0].name == "get_current_time"
|
||||
assert chunks[4].response.output[0].arguments == "{}"
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
assert [chunk.type for chunk in chunks] == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
assert_common_expectations(chunks)
|
||||
|
||||
# Function accepts optional arguments
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
|
|
@ -418,42 +447,47 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
|
|||
OpenAIResponseInputToolFunction(
|
||||
name="get_current_time",
|
||||
description="Get current time for system's timezone",
|
||||
parameters={
|
||||
"timezone": "string",
|
||||
},
|
||||
parameters={"timezone": "string"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
assert [chunk.type for chunk in chunks] == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
assert_common_expectations(chunks)
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 5
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
||||
# Check streaming events
|
||||
assert chunks[1].type == "response.output_item.added"
|
||||
assert chunks[2].type == "response.function_call_arguments.done"
|
||||
assert chunks[3].type == "response.output_item.done"
|
||||
|
||||
# Check response.completed event (should have the tool call with arguments set to "{}")
|
||||
assert chunks[4].type == "response.completed"
|
||||
assert len(chunks[4].response.output) == 1
|
||||
assert chunks[4].response.output[0].type == "function_call"
|
||||
assert chunks[4].response.output[0].name == "get_current_time"
|
||||
assert chunks[4].response.output[0].arguments == "{}"
|
||||
# Function accepts optional arguments with additional optional fields
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
tools=[
|
||||
OpenAIResponseInputToolFunction(
|
||||
name="get_current_time",
|
||||
description="Get current time for system's timezone",
|
||||
parameters={"timezone": "string", "location": "string"},
|
||||
)
|
||||
],
|
||||
)
|
||||
chunks = [chunk async for chunk in result]
|
||||
assert [chunk.type for chunk in chunks] == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
assert_common_expectations(chunks)
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
|
||||
|
||||
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
|
||||
|
|
@ -485,7 +519,9 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
|
|||
|
||||
# Verify the the correct messages were sent to the inference API i.e.
|
||||
# All of the responses message were convered to the chat completion message objects
|
||||
inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"]
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
params = call_args.args[0]
|
||||
inference_messages = params.messages
|
||||
for i, m in enumerate(input_messages):
|
||||
if isinstance(m.content, str):
|
||||
assert inference_messages[i].content == m.content
|
||||
|
|
@ -653,7 +689,8 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 2
|
||||
|
|
@ -691,7 +728,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 4 # 1 system + 3 input messages
|
||||
|
|
@ -751,7 +789,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 4, sent_messages
|
||||
|
|
@ -953,6 +992,58 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
|
|||
assert result.status == "completed"
|
||||
|
||||
|
||||
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
|
||||
async def test_reuse_mcp_tool_list(
|
||||
mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
|
||||
):
|
||||
"""Test that mcp_list_tools can be reused where appropriate."""
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
mock_list_mcp_tools.return_value = ListToolDefsResponse(
|
||||
data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})]
|
||||
)
|
||||
|
||||
res1 = await openai_responses_impl.create_openai_response(
|
||||
input="What is 2+2?",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
store=True,
|
||||
tools=[
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
],
|
||||
)
|
||||
args = mock_responses_store.store_response_object.call_args
|
||||
data = args.kwargs["response_object"].model_dump()
|
||||
data["input"] = [input_item.model_dump() for input_item in args.kwargs["input"]]
|
||||
data["messages"] = [msg.model_dump() for msg in args.kwargs["messages"]]
|
||||
stored = _OpenAIResponseObjectWithInputAndMessages(**data)
|
||||
mock_responses_store.get_response_object.return_value = stored
|
||||
|
||||
res2 = await openai_responses_impl.create_openai_response(
|
||||
previous_response_id=res1.id,
|
||||
input="Now what is 3+3?",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
store=True,
|
||||
tools=[
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
],
|
||||
)
|
||||
assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
|
||||
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
||||
second_params = second_call.args[0]
|
||||
tools_seen = second_params.tools
|
||||
assert len(tools_seen) == 1
|
||||
assert tools_seen[0]["function"]["name"] == "test_tool"
|
||||
assert tools_seen[0]["function"]["description"] == "a test tool"
|
||||
|
||||
assert mock_list_mcp_tools.call_count == 1
|
||||
listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"]
|
||||
assert len(listings) == 1
|
||||
assert listings[0].server_label == "alabel"
|
||||
assert len(listings[0].tools) == 1
|
||||
assert listings[0].tools[0].name == "test_tool"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_format, response_format",
|
||||
[
|
||||
|
|
@ -987,8 +1078,9 @@ async def test_create_openai_response_with_text_format(
|
|||
|
||||
# Verify
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["response_format"] == response_format
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.response_format == response_format
|
||||
|
||||
|
||||
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,331 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
)
|
||||
from llama_stack.apis.common.errors import (
|
||||
ConversationNotFoundError,
|
||||
InvalidConversationIdError,
|
||||
)
|
||||
from llama_stack.apis.conversations.conversations import (
|
||||
ConversationItemList,
|
||||
)
|
||||
|
||||
# Import existing fixtures from the main responses test file
|
||||
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]
|
||||
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def responses_impl_with_conversations(
|
||||
mock_inference_api,
|
||||
mock_tool_groups_api,
|
||||
mock_tool_runtime_api,
|
||||
mock_responses_store,
|
||||
mock_vector_io_api,
|
||||
mock_conversations_api,
|
||||
):
|
||||
"""Create OpenAIResponsesImpl instance with conversations API."""
|
||||
return OpenAIResponsesImpl(
|
||||
inference_api=mock_inference_api,
|
||||
tool_groups_api=mock_tool_groups_api,
|
||||
tool_runtime_api=mock_tool_runtime_api,
|
||||
responses_store=mock_responses_store,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
)
|
||||
|
||||
|
||||
class TestConversationValidation:
|
||||
"""Test conversation ID validation logic."""
|
||||
|
||||
async def test_nonexistent_conversation_raises_error(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test that ConversationNotFoundError is raised for non-existent conversation."""
|
||||
conv_id = "conv_nonexistent"
|
||||
|
||||
# Mock conversation not found
|
||||
mock_conversations_api.list.side_effect = ConversationNotFoundError("conv_nonexistent")
|
||||
|
||||
with pytest.raises(ConversationNotFoundError):
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="Hello", model="test-model", conversation=conv_id, stream=False
|
||||
)
|
||||
|
||||
|
||||
class TestConversationContextLoading:
|
||||
"""Test conversation context loading functionality."""
|
||||
|
||||
async def test_load_conversation_context_simple_input(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test loading conversation context with simple string input."""
|
||||
conv_id = "conv_test123"
|
||||
input_text = "Hello, how are you?"
|
||||
|
||||
# mock items in chronological order (a consequence of order="asc")
|
||||
mock_conversation_items = ConversationItemList(
|
||||
data=[
|
||||
OpenAIResponseMessage(
|
||||
id="msg_1",
|
||||
content=[{"type": "input_text", "text": "Previous user message"}],
|
||||
role="user",
|
||||
status="completed",
|
||||
type="message",
|
||||
),
|
||||
OpenAIResponseMessage(
|
||||
id="msg_2",
|
||||
content=[{"type": "output_text", "text": "Previous assistant response"}],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
),
|
||||
],
|
||||
first_id="msg_1",
|
||||
has_more=False,
|
||||
last_id="msg_2",
|
||||
object="list",
|
||||
)
|
||||
|
||||
mock_conversations_api.list.return_value = mock_conversation_items
|
||||
|
||||
result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
|
||||
|
||||
# should have conversation history + new input
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[0], OpenAIResponseMessage)
|
||||
assert result[0].role == "user"
|
||||
assert isinstance(result[1], OpenAIResponseMessage)
|
||||
assert result[1].role == "assistant"
|
||||
assert isinstance(result[2], OpenAIResponseMessage)
|
||||
assert result[2].role == "user"
|
||||
assert result[2].content == input_text
|
||||
|
||||
async def test_load_conversation_context_api_error(self, responses_impl_with_conversations, mock_conversations_api):
|
||||
"""Test loading conversation context when API call fails."""
|
||||
conv_id = "conv_test123"
|
||||
input_text = "Hello"
|
||||
|
||||
mock_conversations_api.list.side_effect = Exception("API Error")
|
||||
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
|
||||
|
||||
async def test_load_conversation_context_with_list_input(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test loading conversation context with list input."""
|
||||
conv_id = "conv_test123"
|
||||
input_messages = [
|
||||
OpenAIResponseMessage(role="user", content="First message"),
|
||||
OpenAIResponseMessage(role="user", content="Second message"),
|
||||
]
|
||||
|
||||
mock_conversations_api.list.return_value = ConversationItemList(
|
||||
data=[], first_id=None, has_more=False, last_id=None, object="list"
|
||||
)
|
||||
|
||||
result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_messages)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result == input_messages
|
||||
|
||||
async def test_load_conversation_context_empty_conversation(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test loading context from empty conversation."""
|
||||
conv_id = "conv_empty"
|
||||
input_text = "Hello"
|
||||
|
||||
mock_conversations_api.list.return_value = ConversationItemList(
|
||||
data=[], first_id=None, has_more=False, last_id=None, object="list"
|
||||
)
|
||||
|
||||
result = await responses_impl_with_conversations._load_conversation_context(conv_id, input_text)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].role == "user"
|
||||
assert result[0].content == input_text
|
||||
|
||||
|
||||
class TestMessageSyncing:
|
||||
"""Test message syncing to conversations."""
|
||||
|
||||
async def test_sync_response_to_conversation_simple(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test syncing simple response to conversation."""
|
||||
conv_id = "conv_test123"
|
||||
input_text = "What are the 5 Ds of dodgeball?"
|
||||
|
||||
# mock response
|
||||
mock_response = OpenAIResponseObject(
|
||||
id="resp_123",
|
||||
created_at=1234567890,
|
||||
model="test-model",
|
||||
object="response",
|
||||
output=[
|
||||
OpenAIResponseMessage(
|
||||
id="msg_response",
|
||||
content=[
|
||||
OpenAIResponseOutputMessageContentOutputText(
|
||||
text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
],
|
||||
status="completed",
|
||||
)
|
||||
|
||||
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, mock_response)
|
||||
|
||||
# should call add_items with user input and assistant response
|
||||
mock_conversations_api.add_items.assert_called_once()
|
||||
call_args = mock_conversations_api.add_items.call_args
|
||||
|
||||
assert call_args[0][0] == conv_id # conversation_id
|
||||
items = call_args[0][1] # conversation_items
|
||||
|
||||
assert len(items) == 2
|
||||
# User message
|
||||
assert items[0].type == "message"
|
||||
assert items[0].role == "user"
|
||||
assert items[0].content[0].type == "input_text"
|
||||
assert items[0].content[0].text == input_text
|
||||
|
||||
# Assistant message
|
||||
assert items[1].type == "message"
|
||||
assert items[1].role == "assistant"
|
||||
|
||||
async def test_sync_response_to_conversation_api_error(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
mock_conversations_api.add_items.side_effect = Exception("API Error")
|
||||
mock_response = OpenAIResponseObject(
|
||||
id="resp_123", created_at=1234567890, model="test-model", object="response", output=[], status="completed"
|
||||
)
|
||||
|
||||
# matching the behavior of OpenAI here
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await responses_impl_with_conversations._sync_response_to_conversation(
|
||||
"conv_test123", "Hello", mock_response
|
||||
)
|
||||
|
||||
async def test_sync_unsupported_types(self, responses_impl_with_conversations):
|
||||
mock_response = OpenAIResponseObject(
|
||||
id="resp_123", created_at=1234567890, model="test-model", object="response", output=[], status="completed"
|
||||
)
|
||||
|
||||
with pytest.raises(NotImplementedError, match="Unsupported input item type"):
|
||||
await responses_impl_with_conversations._sync_response_to_conversation(
|
||||
"conv_123", [{"not": "message"}], mock_response
|
||||
)
|
||||
|
||||
with pytest.raises(NotImplementedError, match="Unsupported message role: system"):
|
||||
await responses_impl_with_conversations._sync_response_to_conversation(
|
||||
"conv_123", [OpenAIResponseMessage(role="system", content="test")], mock_response
|
||||
)
|
||||
|
||||
|
||||
class TestIntegrationWorkflow:
|
||||
"""Integration tests for the full conversation workflow."""
|
||||
|
||||
async def test_create_response_with_valid_conversation(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test creating a response with a valid conversation parameter."""
|
||||
mock_conversations_api.list.return_value = ConversationItemList(
|
||||
data=[], first_id=None, has_more=False, last_id=None, object="list"
|
||||
)
|
||||
|
||||
async def mock_streaming_response(*args, **kwargs):
|
||||
mock_response = OpenAIResponseObject(
|
||||
id="resp_test123",
|
||||
created_at=1234567890,
|
||||
model="test-model",
|
||||
object="response",
|
||||
output=[
|
||||
OpenAIResponseMessage(
|
||||
id="msg_response",
|
||||
content=[
|
||||
OpenAIResponseOutputMessageContentOutputText(
|
||||
text="Test response", type="output_text", annotations=[]
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
],
|
||||
status="completed",
|
||||
)
|
||||
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")
|
||||
|
||||
responses_impl_with_conversations._create_streaming_response = mock_streaming_response
|
||||
|
||||
input_text = "Hello, how are you?"
|
||||
conversation_id = "conv_test123"
|
||||
|
||||
response = await responses_impl_with_conversations.create_openai_response(
|
||||
input=input_text, model="test-model", conversation=conversation_id, stream=False
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.id == "resp_test123"
|
||||
|
||||
mock_conversations_api.list.assert_called_once_with(conversation_id, order="asc")
|
||||
|
||||
# Note: conversation sync happens in the streaming response flow,
|
||||
# which is complex to mock fully in this unit test
|
||||
|
||||
async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
|
||||
"""Test creating a response with an invalid conversation ID."""
|
||||
with pytest.raises(InvalidConversationIdError) as exc_info:
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="Hello", model="test-model", conversation="invalid_id", stream=False
|
||||
)
|
||||
|
||||
assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)
|
||||
|
||||
async def test_create_response_with_nonexistent_conversation(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test creating a response with a non-existent conversation."""
|
||||
mock_conversations_api.list.side_effect = ConversationNotFoundError("conv_nonexistent")
|
||||
|
||||
with pytest.raises(ConversationNotFoundError) as exc_info:
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
|
||||
)
|
||||
|
||||
assert "not found" in str(exc_info.value)
|
||||
|
||||
async def test_conversation_and_previous_response_id(
|
||||
self, responses_impl_with_conversations, mock_conversations_api, mock_responses_store
|
||||
):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="test", model="test", conversation="conv_123", previous_response_id="resp_123"
|
||||
)
|
||||
|
||||
assert "Mutually exclusive parameters" in str(exc_info.value)
|
||||
assert "previous_response_id" in str(exc_info.value)
|
||||
assert "conversation" in str(exc_info.value)
|
||||
|
|
@ -8,6 +8,7 @@
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseAnnotationFileCitation,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
|
|
@ -35,6 +36,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
_extract_citations_from_text,
|
||||
convert_chat_choice_to_response_message,
|
||||
convert_response_content_to_chat_content,
|
||||
convert_response_input_to_chat_messages,
|
||||
|
|
@ -340,3 +342,26 @@ class TestIsFunctionToolCall:
|
|||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestExtractCitationsFromText:
|
||||
def test_extract_citations_and_annotations(self):
|
||||
text = "Start [not-a-file]. New source <|file-abc123|>. "
|
||||
text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
|
||||
file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}
|
||||
|
||||
annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)
|
||||
|
||||
expected_annotations = [
|
||||
OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
|
||||
OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
|
||||
OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
|
||||
]
|
||||
expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."
|
||||
|
||||
assert cleaned_text == expected_clean_text
|
||||
assert annotations == expected_annotations
|
||||
# OpenAI cites at the end of the sentence
|
||||
assert cleaned_text[expected_annotations[0].index] == "."
|
||||
assert cleaned_text[expected_annotations[1].index] == "?"
|
||||
assert cleaned_text[expected_annotations[2].index] == "!"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,183 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseToolMCP,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
|
||||
|
||||
|
||||
class TestToolContext:
|
||||
def test_no_tools(self):
|
||||
tools = []
|
||||
context = ToolContext(tools)
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 0
|
||||
assert len(context.previous_tools) == 0
|
||||
assert len(context.previous_tool_listings) == 0
|
||||
|
||||
def test_no_previous_tools(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
|
||||
OpenAIResponseInputToolMCP(server_label="label", server_url="url"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 2
|
||||
assert len(context.previous_tools) == 0
|
||||
assert len(context.previous_tool_listings) == 0
|
||||
|
||||
def test_reusable_server(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
|
||||
)
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
|
||||
OpenAIResponseToolMCP(server_label="alabel"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 1
|
||||
assert context.tools_to_process[0].type == "file_search"
|
||||
assert len(context.previous_tools) == 1
|
||||
assert context.previous_tools["test_tool"].server_label == "alabel"
|
||||
assert context.previous_tools["test_tool"].server_url == "aurl"
|
||||
assert len(context.previous_tool_listings) == 1
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "alabel"
|
||||
|
||||
def test_multiple_reusable_servers(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
|
||||
),
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test2",
|
||||
server_label="anotherlabel",
|
||||
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
|
||||
),
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 2
|
||||
assert context.tools_to_process[0].type == "function"
|
||||
assert context.tools_to_process[1].type == "web_search"
|
||||
assert len(context.previous_tools) == 2
|
||||
assert context.previous_tools["test_tool"].server_label == "alabel"
|
||||
assert context.previous_tools["test_tool"].server_url == "aurl"
|
||||
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
|
||||
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
|
||||
assert len(context.previous_tool_listings) == 2
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "alabel"
|
||||
assert len(context.previous_tool_listings[1].tools) == 1
|
||||
assert context.previous_tool_listings[1].server_label == "anotherlabel"
|
||||
|
||||
def test_multiple_servers_only_one_reusable(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test2",
|
||||
server_label="anotherlabel",
|
||||
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
|
||||
)
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 3
|
||||
assert context.tools_to_process[0].type == "function"
|
||||
assert context.tools_to_process[1].type == "web_search"
|
||||
assert context.tools_to_process[2].type == "mcp"
|
||||
assert len(context.previous_tools) == 1
|
||||
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
|
||||
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
|
||||
assert len(context.previous_tool_listings) == 1
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "anotherlabel"
|
||||
|
||||
def test_mismatched_allowed_tools(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl", allowed_tools=["test_tool_2"]),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool_1", input_schema={})]
|
||||
),
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test2",
|
||||
server_label="anotherlabel",
|
||||
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
|
||||
),
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 3
|
||||
assert context.tools_to_process[0].type == "function"
|
||||
assert context.tools_to_process[1].type == "web_search"
|
||||
assert context.tools_to_process[2].type == "mcp"
|
||||
assert len(context.previous_tools) == 1
|
||||
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
|
||||
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
|
||||
assert len(context.previous_tool_listings) == 1
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "anotherlabel"
|
||||
|
|
@ -213,7 +213,6 @@ class TestReferenceBatchesImpl:
|
|||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
|
|
@ -765,3 +764,12 @@ class TestReferenceBatchesImpl:
|
|||
await asyncio.sleep(0.042) # let tasks start
|
||||
|
||||
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
|
||||
|
||||
async def test_create_batch_embeddings_endpoint(self, provider):
|
||||
"""Test that batch creation succeeds with embeddings endpoint."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/embeddings",
|
||||
completion_window="24h",
|
||||
)
|
||||
assert batch.endpoint == "/v1/embeddings"
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
|
||||
|
|
@ -16,74 +18,71 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
|||
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
|
||||
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
|
||||
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
|
||||
|
||||
|
||||
def test_groq_provider_openai_client_caching():
|
||||
"""Ensure the Groq provider does not cache api keys across client requests"""
|
||||
|
||||
config = GroqConfig()
|
||||
inference_adapter = GroqInferenceAdapter(config)
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = (
|
||||
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
|
||||
)
|
||||
|
||||
for api_key in ["test1", "test2"]:
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||
):
|
||||
assert inference_adapter.client.api_key == api_key
|
||||
|
||||
|
||||
def test_openai_provider_openai_client_caching():
|
||||
@pytest.mark.parametrize(
|
||||
"config_cls,adapter_cls,provider_data_validator",
|
||||
[
|
||||
(
|
||||
GroqConfig,
|
||||
GroqInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
),
|
||||
(
|
||||
OpenAIConfig,
|
||||
OpenAIInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
),
|
||||
(
|
||||
TogetherImplConfig,
|
||||
TogetherInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
),
|
||||
(
|
||||
LlamaCompatConfig,
|
||||
LlamaCompatInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
|
||||
"""Ensure the OpenAI provider does not cache api keys across client requests"""
|
||||
|
||||
config = OpenAIConfig()
|
||||
inference_adapter = OpenAIInferenceAdapter(config)
|
||||
inference_adapter = adapter_cls(config=config_cls())
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = (
|
||||
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator"
|
||||
)
|
||||
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
|
||||
|
||||
for api_key in ["test1", "test2"]:
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||
):
|
||||
openai_client = inference_adapter.client
|
||||
assert openai_client.api_key == api_key
|
||||
|
||||
|
||||
def test_together_provider_openai_client_caching():
|
||||
"""Ensure the Together provider does not cache api keys across client requests"""
|
||||
|
||||
config = TogetherImplConfig()
|
||||
inference_adapter = TogetherInferenceAdapter(config)
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = (
|
||||
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator"
|
||||
)
|
||||
|
||||
for api_key in ["test1", "test2"]:
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"together_api_key": api_key})}):
|
||||
together_client = inference_adapter._get_client()
|
||||
assert together_client.client.api_key == api_key
|
||||
openai_client = inference_adapter._get_openai_client()
|
||||
assert openai_client.api_key == api_key
|
||||
|
||||
|
||||
def test_llama_compat_provider_openai_client_caching():
|
||||
"""Ensure the LlamaCompat provider does not cache api keys across client requests"""
|
||||
config = LlamaCompatConfig()
|
||||
inference_adapter = LlamaCompatInferenceAdapter(config)
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = (
|
||||
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator"
|
||||
)
|
||||
|
||||
for api_key in ["test1", "test2"]:
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
|
||||
assert inference_adapter.client.api_key == api_key
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_cls,adapter_cls,provider_data_validator",
|
||||
[
|
||||
(
|
||||
WatsonXConfig,
|
||||
WatsonXInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
|
||||
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
|
||||
assumption that there is an OpenAI-compatible client object."""
|
||||
|
||||
inference_adapter = adapter_cls(config=config_cls())
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
|
||||
|
||||
for api_key in ["test1", "test2"]:
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||
):
|
||||
assert inference_adapter.get_api_key() == api_key
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class TestOpenAIBaseURLConfig:
|
|||
def test_default_base_url_without_env_var(self):
|
||||
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
|
||||
config = OpenAIConfig(api_key="test-key")
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter = OpenAIInferenceAdapter(config=config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
assert adapter.get_base_url() == "https://api.openai.com/v1"
|
||||
|
|
@ -27,7 +27,7 @@ class TestOpenAIBaseURLConfig:
|
|||
"""Test that the adapter uses a custom base URL when provided in config."""
|
||||
custom_url = "https://custom.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter = OpenAIInferenceAdapter(config=config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
assert adapter.get_base_url() == custom_url
|
||||
|
|
@ -39,7 +39,7 @@ class TestOpenAIBaseURLConfig:
|
|||
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
|
||||
processed_config = replace_env_vars(config_data)
|
||||
config = OpenAIConfig.model_validate(processed_config)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter = OpenAIInferenceAdapter(config=config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
assert adapter.get_base_url() == "https://env.openai.com/v1"
|
||||
|
|
@ -49,7 +49,7 @@ class TestOpenAIBaseURLConfig:
|
|||
"""Test that explicit config value overrides environment variable."""
|
||||
custom_url = "https://config.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter = OpenAIInferenceAdapter(config=config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Config should take precedence over environment variable
|
||||
|
|
@ -60,7 +60,7 @@ class TestOpenAIBaseURLConfig:
|
|||
"""Test that the OpenAI client is initialized with the configured base URL."""
|
||||
custom_url = "https://test.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter = OpenAIInferenceAdapter(config=config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
|
||||
|
|
@ -80,7 +80,7 @@ class TestOpenAIBaseURLConfig:
|
|||
"""Test that check_model_availability uses the configured base URL."""
|
||||
custom_url = "https://test.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter = OpenAIInferenceAdapter(config=config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Mock the get_api_key method
|
||||
|
|
@ -122,7 +122,7 @@ class TestOpenAIBaseURLConfig:
|
|||
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
|
||||
processed_config = replace_env_vars(config_data)
|
||||
config = OpenAIConfig.model_validate(processed_config)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter = OpenAIInferenceAdapter(config=config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Mock the get_api_key method
|
||||
|
|
|
|||
|
|
@ -5,45 +5,27 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
Choice as OpenAIChoiceChunk,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDelta as OpenAIChoiceDelta,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
|
||||
)
|
||||
from openai.types.model import Model as OpenAIModel
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponseEventType,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChoice,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionChoice,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
ToolChoice,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.core.routers.inference import InferenceRouter
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import (
|
||||
VLLMInferenceAdapter,
|
||||
_process_vllm_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
|
||||
|
||||
# These are unit test for the remote vllm provider
|
||||
# implementation. This should only contain tests which are specific to
|
||||
|
|
@ -56,37 +38,15 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
|
|||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_openai_models_list():
|
||||
with patch("openai.resources.models.AsyncModels.list") as mock_list:
|
||||
yield mock_list
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def vllm_inference_adapter():
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
inference_adapter = VLLMInferenceAdapter(config)
|
||||
inference_adapter = VLLMInferenceAdapter(config=config)
|
||||
inference_adapter.model_store = AsyncMock()
|
||||
# Mock the __provider_spec__ attribute that would normally be set by the resolver
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_type = "vllm-inference"
|
||||
inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
|
||||
await inference_adapter.initialize()
|
||||
return inference_adapter
|
||||
|
||||
|
||||
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
|
||||
async def mock_openai_models():
|
||||
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
|
||||
|
||||
mock_openai_models_list.return_value = mock_openai_models()
|
||||
|
||||
foo_model = Model(identifier="foo", provider_resource_id="foo", provider_id="vllm-inference")
|
||||
|
||||
await vllm_inference_adapter.register_model(foo_model)
|
||||
mock_openai_models_list.assert_called()
|
||||
|
||||
|
||||
async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
||||
"""
|
||||
Test that we set tool_choice to none when no tools are in use
|
||||
|
|
@ -102,416 +62,20 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
|||
mock_client_property.return_value = mock_client
|
||||
|
||||
# No tools but auto tool choice
|
||||
await vllm_inference_adapter.openai_chat_completion(
|
||||
"mock-model",
|
||||
[],
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
stream=False,
|
||||
tools=None,
|
||||
tool_choice=ToolChoice.auto.value,
|
||||
)
|
||||
await vllm_inference_adapter.openai_chat_completion(params)
|
||||
mock_client.chat.completions.create.assert_called()
|
||||
call_args = mock_client.chat.completions.create.call_args
|
||||
# Ensure tool_choice gets converted to none for older vLLM versions
|
||||
assert call_args.kwargs["tool_choice"] == ToolChoice.none.value
|
||||
|
||||
|
||||
async def test_tool_call_delta_empty_tool_call_buf():
|
||||
"""
|
||||
Test that we don't generate extra chunks when processing a
|
||||
tool call response that didn't call any tools. Previously we would
|
||||
emit chunks with spurious ToolCallParseStatus.succeeded or
|
||||
ToolCallParseStatus.failed when processing chunks that didn't
|
||||
actually make any tool calls.
|
||||
"""
|
||||
|
||||
async def mock_stream():
|
||||
delta = OpenAIChoiceDelta(content="", tool_calls=None)
|
||||
choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
|
||||
mock_chunk = OpenAIChatCompletionChunk(
|
||||
id="chunk-1",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=choices,
|
||||
)
|
||||
for chunk in [mock_chunk]:
|
||||
yield chunk
|
||||
|
||||
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0].event.event_type.value == "start"
|
||||
assert chunks[1].event.event_type.value == "complete"
|
||||
assert chunks[1].event.stop_reason == StopReason.end_of_turn
|
||||
|
||||
|
||||
async def test_tool_call_delta_streaming_arguments_dict():
|
||||
async def mock_stream():
|
||||
mock_chunk_1 = OpenAIChatCompletionChunk(
|
||||
id="chunk-1",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(
|
||||
content="",
|
||||
tool_calls=[
|
||||
OpenAIChoiceDeltaToolCall(
|
||||
id="tc_1",
|
||||
index=1,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
name="power",
|
||||
arguments="",
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
finish_reason=None,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
)
|
||||
mock_chunk_2 = OpenAIChatCompletionChunk(
|
||||
id="chunk-2",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(
|
||||
content="",
|
||||
tool_calls=[
|
||||
OpenAIChoiceDeltaToolCall(
|
||||
id="tc_1",
|
||||
index=1,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
name="power",
|
||||
arguments='{"number": 28, "power": 3}',
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
finish_reason=None,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
)
|
||||
mock_chunk_3 = OpenAIChatCompletionChunk(
|
||||
id="chunk-3",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
|
||||
)
|
||||
],
|
||||
)
|
||||
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
|
||||
yield chunk
|
||||
|
||||
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
|
||||
assert len(chunks) == 3
|
||||
assert chunks[0].event.event_type.value == "start"
|
||||
assert chunks[1].event.event_type.value == "progress"
|
||||
assert chunks[1].event.delta.type == "tool_call"
|
||||
assert chunks[1].event.delta.parse_status.value == "succeeded"
|
||||
assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
|
||||
assert chunks[2].event.event_type.value == "complete"
|
||||
|
||||
|
||||
async def test_multiple_tool_calls():
|
||||
async def mock_stream():
|
||||
mock_chunk_1 = OpenAIChatCompletionChunk(
|
||||
id="chunk-1",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(
|
||||
content="",
|
||||
tool_calls=[
|
||||
OpenAIChoiceDeltaToolCall(
|
||||
id="",
|
||||
index=1,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
name="power",
|
||||
arguments='{"number": 28, "power": 3}',
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
finish_reason=None,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
)
|
||||
mock_chunk_2 = OpenAIChatCompletionChunk(
|
||||
id="chunk-2",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(
|
||||
content="",
|
||||
tool_calls=[
|
||||
OpenAIChoiceDeltaToolCall(
|
||||
id="",
|
||||
index=2,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
name="multiple",
|
||||
arguments='{"first_number": 4, "second_number": 7}',
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
finish_reason=None,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
)
|
||||
mock_chunk_3 = OpenAIChatCompletionChunk(
|
||||
id="chunk-3",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
|
||||
)
|
||||
],
|
||||
)
|
||||
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
|
||||
yield chunk
|
||||
|
||||
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
|
||||
assert len(chunks) == 4
|
||||
assert chunks[0].event.event_type.value == "start"
|
||||
assert chunks[1].event.event_type.value == "progress"
|
||||
assert chunks[1].event.delta.type == "tool_call"
|
||||
assert chunks[1].event.delta.parse_status.value == "succeeded"
|
||||
assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
|
||||
assert chunks[2].event.event_type.value == "progress"
|
||||
assert chunks[2].event.delta.type == "tool_call"
|
||||
assert chunks[2].event.delta.parse_status.value == "succeeded"
|
||||
assert chunks[2].event.delta.tool_call.arguments == '{"first_number": 4, "second_number": 7}'
|
||||
assert chunks[3].event.event_type.value == "complete"
|
||||
|
||||
|
||||
async def test_process_vllm_chat_completion_stream_response_no_choices():
|
||||
"""
|
||||
Test that we don't error out when vLLM returns no choices for a
|
||||
completion request. This can happen when there's an error thrown
|
||||
in vLLM for example.
|
||||
"""
|
||||
|
||||
async def mock_stream():
|
||||
choices = []
|
||||
mock_chunk = OpenAIChatCompletionChunk(
|
||||
id="chunk-1",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=choices,
|
||||
)
|
||||
for chunk in [mock_chunk]:
|
||||
yield chunk
|
||||
|
||||
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].event.event_type.value == "start"
|
||||
|
||||
|
||||
async def test_get_params_empty_tools(vllm_inference_adapter):
|
||||
request = ChatCompletionRequest(
|
||||
tools=[],
|
||||
model="test_model",
|
||||
messages=[UserMessage(content="test")],
|
||||
)
|
||||
params = await vllm_inference_adapter._get_params(request)
|
||||
assert "tools" not in params
|
||||
|
||||
|
||||
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
|
||||
"""
|
||||
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
|
||||
contains the finish reason (i.e., the last one).
|
||||
We want to make sure the tool call is executed in this case, and the parameters are passed correctly.
|
||||
"""
|
||||
|
||||
mock_tool_name = "mock_tool"
|
||||
mock_tool_arguments = {"arg1": 0, "arg2": 100}
|
||||
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
|
||||
|
||||
async def mock_stream():
|
||||
mock_chunks = [
|
||||
OpenAIChatCompletionChunk(
|
||||
id="chunk-1",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
{
|
||||
"delta": {
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": "mock_id",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": mock_tool_name,
|
||||
"arguments": None,
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
),
|
||||
OpenAIChatCompletionChunk(
|
||||
id="chunk-1",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
{
|
||||
"delta": {
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": None,
|
||||
"function": {
|
||||
"name": None,
|
||||
"arguments": mock_tool_arguments_str,
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
"logprobs": None,
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
for chunk in mock_chunks:
|
||||
yield chunk
|
||||
|
||||
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
|
||||
assert len(chunks) == 3
|
||||
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
|
||||
assert chunks[-2].event.delta.type == "tool_call"
|
||||
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
|
||||
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str
|
||||
|
||||
|
||||
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
|
||||
"""
|
||||
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
|
||||
finish reason.
|
||||
We want to make sure that this case is recognized and handled correctly, i.e., as a valid end of message.
|
||||
"""
|
||||
|
||||
mock_tool_name = "mock_tool"
|
||||
mock_tool_arguments = {"arg1": 0, "arg2": 100}
|
||||
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
|
||||
|
||||
async def mock_stream():
|
||||
mock_chunks = [
|
||||
OpenAIChatCompletionChunk(
|
||||
id="chunk-1",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
{
|
||||
"delta": {
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": "mock_id",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": mock_tool_name,
|
||||
"arguments": mock_tool_arguments_str,
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
for chunk in mock_chunks:
|
||||
yield chunk
|
||||
|
||||
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
|
||||
assert len(chunks) == 3
|
||||
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
|
||||
assert chunks[-2].event.delta.type == "tool_call"
|
||||
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
|
||||
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str
|
||||
|
||||
|
||||
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
|
||||
"""
|
||||
Tests the edge case where no arguments are provided for the tool call.
|
||||
Tool calls with no arguments should be treated as regular tool calls, which was not the case until now.
|
||||
"""
|
||||
mock_tool_name = "mock_tool"
|
||||
|
||||
async def mock_stream():
|
||||
mock_chunks = [
|
||||
OpenAIChatCompletionChunk(
|
||||
id="chunk-1",
|
||||
created=1,
|
||||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
{
|
||||
"delta": {
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": "mock_id",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": mock_tool_name,
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
for chunk in mock_chunks:
|
||||
yield chunk
|
||||
|
||||
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
|
||||
assert len(chunks) == 3
|
||||
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
|
||||
assert chunks[-2].event.delta.type == "tool_call"
|
||||
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
|
||||
assert chunks[-2].event.delta.tool_call.arguments == "{}"
|
||||
|
||||
|
||||
async def test_health_status_success(vllm_inference_adapter):
|
||||
"""
|
||||
Test the health method of VLLM InferenceAdapter when the connection is successful.
|
||||
|
|
@ -614,9 +178,12 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
|
|||
)
|
||||
|
||||
async def do_inference():
|
||||
await vllm_inference_adapter.openai_chat_completion(
|
||||
"mock-model", messages=["one fish", "two fish"], stream=False
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
messages=[{"role": "user", "content": "one fish two fish"}],
|
||||
stream=False,
|
||||
)
|
||||
await vllm_inference_adapter.openai_chat_completion(params)
|
||||
|
||||
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
|
||||
mock_client = MagicMock()
|
||||
|
|
@ -631,105 +198,146 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
|
|||
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
|
||||
|
||||
|
||||
async def test_should_refresh_models():
|
||||
async def test_vllm_completion_extra_body():
|
||||
"""
|
||||
Test the should_refresh_models method with different refresh_models configurations.
|
||||
|
||||
This test verifies that:
|
||||
1. When refresh_models is True, should_refresh_models returns True regardless of api_token
|
||||
2. When refresh_models is False, should_refresh_models returns False regardless of api_token
|
||||
Test that vLLM-specific guided_choice and prompt_logprobs parameters are correctly forwarded
|
||||
via extra_body to the underlying OpenAI client through the InferenceRouter.
|
||||
"""
|
||||
# Set up the vLLM adapter
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
vllm_adapter = VLLMInferenceAdapter(config=config)
|
||||
vllm_adapter.__provider_id__ = "vllm"
|
||||
await vllm_adapter.initialize()
|
||||
|
||||
# Test case 1: refresh_models is True, api_token is None
|
||||
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
|
||||
adapter1 = VLLMInferenceAdapter(config1)
|
||||
result1 = await adapter1.should_refresh_models()
|
||||
assert result1 is True, "should_refresh_models should return True when refresh_models is True"
|
||||
# Create a mock model store
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
|
||||
mock_model_store.get_model.return_value = mock_model
|
||||
mock_model_store.has_model.return_value = True
|
||||
|
||||
# Test case 2: refresh_models is True, api_token is empty string
|
||||
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
|
||||
adapter2 = VLLMInferenceAdapter(config2)
|
||||
result2 = await adapter2.should_refresh_models()
|
||||
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
|
||||
# Create a mock dist_registry
|
||||
mock_dist_registry = MagicMock()
|
||||
mock_dist_registry.get = AsyncMock(return_value=mock_model)
|
||||
mock_dist_registry.set = AsyncMock()
|
||||
|
||||
# Test case 3: refresh_models is True, api_token is "fake" (default)
|
||||
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
|
||||
adapter3 = VLLMInferenceAdapter(config3)
|
||||
result3 = await adapter3.should_refresh_models()
|
||||
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
|
||||
# Set up the routing table
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"vllm": vllm_adapter},
|
||||
dist_registry=mock_dist_registry,
|
||||
policy=[],
|
||||
)
|
||||
# Inject the model store into the adapter
|
||||
vllm_adapter.model_store = routing_table
|
||||
|
||||
# Test case 4: refresh_models is True, api_token is real token
|
||||
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
|
||||
adapter4 = VLLMInferenceAdapter(config4)
|
||||
result4 = await adapter4.should_refresh_models()
|
||||
assert result4 is True, "should_refresh_models should return True when refresh_models is True"
|
||||
# Create the InferenceRouter
|
||||
router = InferenceRouter(routing_table=routing_table)
|
||||
|
||||
# Test case 5: refresh_models is False, api_token is real token
|
||||
config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
|
||||
adapter5 = VLLMInferenceAdapter(config5)
|
||||
result5 = await adapter5.should_refresh_models()
|
||||
assert result5 is False, "should_refresh_models should return False when refresh_models is False"
|
||||
|
||||
|
||||
async def test_provider_data_var_context_propagation(vllm_inference_adapter):
|
||||
"""
|
||||
Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
|
||||
This ensures that dynamic provider data (like API tokens) can be passed through context.
|
||||
Note: The base URL is always taken from config.url, not from provider data.
|
||||
"""
|
||||
# Mock the AsyncOpenAI class to capture provider data
|
||||
with (
|
||||
patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
|
||||
patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create = AsyncMock()
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
# Mock provider data to return test data
|
||||
mock_provider_data = MagicMock()
|
||||
mock_provider_data.vllm_api_token = "test-token-123"
|
||||
mock_provider_data.vllm_url = "http://test-server:8000/v1"
|
||||
mock_get_provider_data.return_value = mock_provider_data
|
||||
|
||||
# Mock the model
|
||||
mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
|
||||
vllm_inference_adapter.model_store.get_model.return_value = mock_model
|
||||
|
||||
try:
|
||||
# Execute chat completion
|
||||
await vllm_inference_adapter.openai_chat_completion(
|
||||
model="test-model",
|
||||
messages=[UserMessage(content="Hello")],
|
||||
stream=False,
|
||||
# Patch the OpenAI client
|
||||
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
|
||||
mock_client = MagicMock()
|
||||
mock_client.completions.create = AsyncMock(
|
||||
return_value=OpenAICompletion(
|
||||
id="cmpl-abc123",
|
||||
created=1,
|
||||
model="mock-model",
|
||||
choices=[
|
||||
OpenAICompletionChoice(
|
||||
text="joy",
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
mock_client_property.return_value = mock_client
|
||||
|
||||
# Verify that ALL client calls were made with the correct parameters
|
||||
calls = mock_openai_class.call_args_list
|
||||
incorrect_calls = []
|
||||
# Test with guided_choice and prompt_logprobs as extra fields
|
||||
params = OpenAICompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
prompt="I am feeling happy",
|
||||
stream=False,
|
||||
guided_choice=["joy", "sadness"],
|
||||
prompt_logprobs=5,
|
||||
)
|
||||
await router.openai_completion(params)
|
||||
|
||||
for i, call in enumerate(calls):
|
||||
api_key = call[1]["api_key"]
|
||||
base_url = call[1]["base_url"]
|
||||
# Verify that the client was called with extra_body containing both parameters
|
||||
mock_client.completions.create.assert_called_once()
|
||||
call_kwargs = mock_client.completions.create.call_args.kwargs
|
||||
assert "extra_body" in call_kwargs
|
||||
assert "guided_choice" in call_kwargs["extra_body"]
|
||||
assert call_kwargs["extra_body"]["guided_choice"] == ["joy", "sadness"]
|
||||
assert "prompt_logprobs" in call_kwargs["extra_body"]
|
||||
assert call_kwargs["extra_body"]["prompt_logprobs"] == 5
|
||||
|
||||
if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
|
||||
incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
|
||||
|
||||
if incorrect_calls:
|
||||
error_msg = (
|
||||
f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
|
||||
)
|
||||
for incorrect_call in incorrect_calls:
|
||||
error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
|
||||
error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
|
||||
raise AssertionError(error_msg)
|
||||
async def test_vllm_chat_completion_extra_body():
|
||||
"""
|
||||
Test that vLLM-specific parameters (e.g., chat_template_kwargs) are correctly forwarded
|
||||
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
|
||||
"""
|
||||
# Set up the vLLM adapter
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
vllm_adapter = VLLMInferenceAdapter(config=config)
|
||||
vllm_adapter.__provider_id__ = "vllm"
|
||||
await vllm_adapter.initialize()
|
||||
|
||||
# Ensure at least one call was made
|
||||
assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
|
||||
# Create a mock model store
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
|
||||
mock_model_store.get_model.return_value = mock_model
|
||||
mock_model_store.has_model.return_value = True
|
||||
|
||||
# Verify that chat completion was called
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
# Create a mock dist_registry
|
||||
mock_dist_registry = MagicMock()
|
||||
mock_dist_registry.get = AsyncMock(return_value=mock_model)
|
||||
mock_dist_registry.set = AsyncMock()
|
||||
|
||||
finally:
|
||||
# Clean up context
|
||||
pass
|
||||
# Set up the routing table
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"vllm": vllm_adapter},
|
||||
dist_registry=mock_dist_registry,
|
||||
policy=[],
|
||||
)
|
||||
# Inject the model store into the adapter
|
||||
vllm_adapter.model_store = routing_table
|
||||
|
||||
# Create the InferenceRouter
|
||||
router = InferenceRouter(routing_table=routing_table)
|
||||
|
||||
# Patch the OpenAI client
|
||||
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=OpenAIChatCompletion(
|
||||
id="chatcmpl-abc123",
|
||||
created=1,
|
||||
model="mock-model",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(
|
||||
content="test response",
|
||||
),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
mock_client_property.return_value = mock_client
|
||||
|
||||
# Test with chat_template_kwargs as extra field
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
stream=False,
|
||||
chat_template_kwargs={"thinking": True},
|
||||
)
|
||||
await router.openai_chat_completion(params)
|
||||
|
||||
# Verify that the client was called with extra_body containing chat_template_kwargs
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_client.chat.completions.create.call_args.kwargs
|
||||
assert "extra_body" in call_kwargs
|
||||
assert "chat_template_kwargs" in call_kwargs["extra_body"]
|
||||
assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}
|
||||
|
|
|
|||
|
|
@ -5,14 +5,17 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
|
||||
from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
|
|
@ -29,7 +32,7 @@ class OpenAIMixinImpl(OpenAIMixin):
|
|||
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
||||
"""Test implementation with embedding model metadata"""
|
||||
|
||||
embedding_model_metadata = {
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
}
|
||||
|
|
@ -38,13 +41,15 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
|||
@pytest.fixture
|
||||
def mixin():
|
||||
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
||||
mixin_instance = OpenAIMixinImpl()
|
||||
config = RemoteInferenceProviderConfig()
|
||||
mixin_instance = OpenAIMixinImpl(config=config)
|
||||
|
||||
# just enough to satisfy _get_provider_model_id calls
|
||||
mock_model_store = MagicMock()
|
||||
# Mock model_store with async methods
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model = MagicMock()
|
||||
mock_model.provider_resource_id = "test-provider-resource-id"
|
||||
mock_model_store.get_model = AsyncMock(return_value=mock_model)
|
||||
mock_model_store.has_model = AsyncMock(return_value=False) # Default to False, tests can override
|
||||
mixin_instance.model_store = mock_model_store
|
||||
|
||||
return mixin_instance
|
||||
|
|
@ -53,7 +58,8 @@ def mixin():
|
|||
@pytest.fixture
|
||||
def mixin_with_embeddings():
|
||||
"""Create a test instance of OpenAIMixin with embedding model metadata"""
|
||||
return OpenAIMixinWithEmbeddingsImpl()
|
||||
config = RemoteInferenceProviderConfig()
|
||||
return OpenAIMixinWithEmbeddingsImpl(config=config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -184,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability:
|
|||
|
||||
assert len(mixin._model_cache) == 3
|
||||
|
||||
async def test_check_model_availability_with_pre_registered_model(
|
||||
self, mixin, mock_client_with_models, mock_client_context
|
||||
):
|
||||
"""Test that check_model_availability returns True for pre-registered models in model_store"""
|
||||
# Mock model_store.has_model to return True for a specific model
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model_store.has_model = AsyncMock(return_value=True)
|
||||
mixin.model_store = mock_model_store
|
||||
|
||||
# Test that pre-registered model is found without calling the provider's API
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
assert await mixin.check_model_availability("pre-registered-model")
|
||||
# Should not call the provider's list_models since model was found in store
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
mock_model_store.has_model.assert_called_once_with("pre-registered-model")
|
||||
|
||||
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
|
||||
self, mixin, mock_client_with_models, mock_client_context
|
||||
):
|
||||
"""Test that check_model_availability falls back to provider when model not in store"""
|
||||
# Mock model_store.has_model to return False
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model_store.has_model = AsyncMock(return_value=False)
|
||||
mixin.model_store = mock_model_store
|
||||
|
||||
# Test that it falls back to provider's model cache
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
assert await mixin.check_model_availability("some-mock-model-id")
|
||||
# Should call the provider's list_models since model was not found in store
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
|
||||
|
||||
|
||||
class TestOpenAIMixinCacheBehavior:
|
||||
"""Test cases for cache behavior and edge cases"""
|
||||
|
|
@ -231,7 +271,8 @@ class TestOpenAIMixinImagePreprocessing:
|
|||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
mock_localize.return_value = (b"fake_image_data", "jpeg")
|
||||
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
|
||||
await mixin.openai_chat_completion(params)
|
||||
|
||||
mock_localize.assert_called_once_with("http://example.com/image.jpg")
|
||||
|
||||
|
|
@ -263,7 +304,8 @@ class TestOpenAIMixinImagePreprocessing:
|
|||
|
||||
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
|
||||
await mixin.openai_chat_completion(params)
|
||||
|
||||
mock_localize.assert_not_called()
|
||||
|
||||
|
|
@ -461,10 +503,16 @@ class TestOpenAIMixinModelRegistration:
|
|||
assert result is None
|
||||
|
||||
async def test_should_refresh_models(self, mixin):
|
||||
"""Test should_refresh_models method (should always return False)"""
|
||||
"""Test should_refresh_models method returns config value"""
|
||||
# Default config has refresh_models=False
|
||||
result = await mixin.should_refresh_models()
|
||||
assert result is False
|
||||
|
||||
config_with_refresh = RemoteInferenceProviderConfig(refresh_models=True)
|
||||
mixin_with_refresh = OpenAIMixinImpl(config=config_with_refresh)
|
||||
result_with_refresh = await mixin_with_refresh.should_refresh_models()
|
||||
assert result_with_refresh is True
|
||||
|
||||
async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
|
||||
"""Test that errors from provider API are properly propagated during registration"""
|
||||
model = Model(
|
||||
|
|
@ -498,13 +546,145 @@ class OpenAIMixinWithProviderData(OpenAIMixinImpl):
|
|||
return "default-base-url"
|
||||
|
||||
|
||||
class CustomListProviderModelIdsImplementation(OpenAIMixinImpl):
|
||||
"""Test implementation with custom list_provider_model_ids override"""
|
||||
|
||||
custom_model_ids: Any
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
"""Return custom model IDs list"""
|
||||
return self.custom_model_ids
|
||||
|
||||
|
||||
class TestOpenAIMixinCustomListProviderModelIds:
|
||||
"""Test cases for custom list_provider_model_ids() implementation functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def custom_model_ids_list(self):
|
||||
"""Create a list of custom model ID strings"""
|
||||
return ["custom-model-1", "custom-model-2", "custom-embedding"]
|
||||
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
"""Create RemoteInferenceProviderConfig instance"""
|
||||
return RemoteInferenceProviderConfig()
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self, custom_model_ids_list, config):
|
||||
"""Create mixin instance with custom list_provider_model_ids implementation"""
|
||||
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=custom_model_ids_list)
|
||||
mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}}
|
||||
return mixin
|
||||
|
||||
async def test_is_used(self, adapter, custom_model_ids_list):
|
||||
"""Test that custom list_provider_model_ids() implementation is used instead of client.models.list()"""
|
||||
result = await adapter.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 3
|
||||
|
||||
assert set(custom_model_ids_list) == {m.identifier for m in result}
|
||||
|
||||
async def test_populates_cache(self, adapter, custom_model_ids_list):
|
||||
"""Test that custom list_provider_model_ids() results are cached"""
|
||||
assert len(adapter._model_cache) == 0
|
||||
|
||||
await adapter.list_models()
|
||||
|
||||
assert set(custom_model_ids_list) == set(adapter._model_cache.keys())
|
||||
|
||||
async def test_respects_allowed_models(self, config):
|
||||
"""Test that custom list_provider_model_ids() respects allowed_models filtering"""
|
||||
mixin = CustomListProviderModelIdsImplementation(
|
||||
config=config, custom_model_ids=["model-1", "model-2", "model-3"]
|
||||
)
|
||||
mixin.allowed_models = ["model-1"]
|
||||
|
||||
result = await mixin.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].identifier == "model-1"
|
||||
|
||||
async def test_with_empty_list(self, config):
|
||||
"""Test that custom list_provider_model_ids() handles empty list correctly"""
|
||||
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[])
|
||||
|
||||
result = await mixin.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 0
|
||||
assert len(mixin._model_cache) == 0
|
||||
|
||||
async def test_wrong_type_raises_error(self, config):
|
||||
"""Test that list_provider_model_ids() returning unhashable items results in an error"""
|
||||
mixin = CustomListProviderModelIdsImplementation(
|
||||
config=config, custom_model_ids=["valid-model", ["nested", "list"]]
|
||||
)
|
||||
with pytest.raises(Exception, match="is not a string"):
|
||||
await mixin.list_models()
|
||||
|
||||
mixin = CustomListProviderModelIdsImplementation(
|
||||
config=config, custom_model_ids=[{"key": "value"}, "valid-model"]
|
||||
)
|
||||
with pytest.raises(Exception, match="is not a string"):
|
||||
await mixin.list_models()
|
||||
|
||||
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=["valid-model", 42.0])
|
||||
with pytest.raises(Exception, match="is not a string"):
|
||||
await mixin.list_models()
|
||||
|
||||
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[None])
|
||||
with pytest.raises(Exception, match="is not a string"):
|
||||
await mixin.list_models()
|
||||
|
||||
async def test_non_iterable_raises_error(self, config):
|
||||
"""Test that list_provider_model_ids() returning non-iterable type raises error"""
|
||||
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=42)
|
||||
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int",
|
||||
):
|
||||
await mixin.list_models()
|
||||
|
||||
async def test_accepts_various_iterables(self, config):
|
||||
"""Test that list_provider_model_ids() accepts tuples, sets, generators, etc."""
|
||||
|
||||
tuples = CustomListProviderModelIdsImplementation(
|
||||
config=config, custom_model_ids=("model-1", "model-2", "model-3")
|
||||
)
|
||||
result = await tuples.list_models()
|
||||
assert result is not None
|
||||
assert len(result) == 3
|
||||
|
||||
class GeneratorAdapter(OpenAIMixinImpl):
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
def gen():
|
||||
yield "gen-model-1"
|
||||
yield "gen-model-2"
|
||||
|
||||
return gen()
|
||||
|
||||
mixin = GeneratorAdapter(config=config)
|
||||
result = await mixin.list_models()
|
||||
assert result is not None
|
||||
assert len(result) == 2
|
||||
|
||||
sets = CustomListProviderModelIdsImplementation(config=config, custom_model_ids={"set-model-1", "set-model-2"})
|
||||
result = await sets.list_models()
|
||||
assert result is not None
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestOpenAIMixinProviderDataApiKey:
|
||||
"""Test cases for provider_data_api_key_field functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_provider_data_field(self):
|
||||
"""Mixin instance with provider_data_api_key_field set"""
|
||||
mixin_instance = OpenAIMixinWithProviderData()
|
||||
config = RemoteInferenceProviderConfig()
|
||||
mixin_instance = OpenAIMixinWithProviderData(config=config)
|
||||
|
||||
# Mock provider_spec for provider data validation
|
||||
mock_provider_spec = MagicMock()
|
||||
|
|
@ -542,7 +722,7 @@ class TestOpenAIMixinProviderDataApiKey:
|
|||
):
|
||||
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||
with pytest.raises(ValueError, match="API key is not set"):
|
||||
with pytest.raises(ValueError, match="API key not provided"):
|
||||
_ = mixin_with_provider_data_field_and_none_api_key.client
|
||||
|
||||
def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.stack import replace_env_vars
|
||||
from llama_stack.providers.remote.inference.anthropic import AnthropicConfig
|
||||
from llama_stack.providers.remote.inference.azure import AzureConfig
|
||||
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
|
||||
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
|
||||
from llama_stack.providers.remote.inference.databricks import DatabricksImplConfig
|
||||
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
|
||||
from llama_stack.providers.remote.inference.gemini import GeminiConfig
|
||||
from llama_stack.providers.remote.inference.groq import GroqConfig
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat import LlamaCompatConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
from llama_stack.providers.remote.inference.openai import OpenAIConfig
|
||||
from llama_stack.providers.remote.inference.runpod import RunpodImplConfig
|
||||
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
|
||||
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
|
||||
from llama_stack.providers.remote.inference.together import TogetherImplConfig
|
||||
from llama_stack.providers.remote.inference.vertexai import VertexAIConfig
|
||||
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||
|
||||
|
||||
class TestRemoteInferenceProviderConfig:
|
||||
@pytest.mark.parametrize(
|
||||
"config_cls,alias_name,env_name,extra_config",
|
||||
[
|
||||
(AnthropicConfig, "api_key", "ANTHROPIC_API_KEY", {}),
|
||||
(AzureConfig, "api_key", "AZURE_API_KEY", {"api_base": "HTTP://FAKE"}),
|
||||
(BedrockConfig, None, None, {}),
|
||||
(CerebrasImplConfig, "api_key", "CEREBRAS_API_KEY", {}),
|
||||
(DatabricksImplConfig, "api_token", "DATABRICKS_TOKEN", {}),
|
||||
(FireworksImplConfig, "api_key", "FIREWORKS_API_KEY", {}),
|
||||
(GeminiConfig, "api_key", "GEMINI_API_KEY", {}),
|
||||
(GroqConfig, "api_key", "GROQ_API_KEY", {}),
|
||||
(LlamaCompatConfig, "api_key", "LLAMA_API_KEY", {}),
|
||||
(NVIDIAConfig, "api_key", "NVIDIA_API_KEY", {}),
|
||||
(OllamaImplConfig, None, None, {}),
|
||||
(OpenAIConfig, "api_key", "OPENAI_API_KEY", {}),
|
||||
(RunpodImplConfig, "api_token", "RUNPOD_API_TOKEN", {}),
|
||||
(SambaNovaImplConfig, "api_key", "SAMBANOVA_API_KEY", {}),
|
||||
(TGIImplConfig, None, None, {"url": "FAKE"}),
|
||||
(TogetherImplConfig, "api_key", "TOGETHER_API_KEY", {}),
|
||||
(VertexAIConfig, None, None, {"project": "FAKE", "location": "FAKE"}),
|
||||
(VLLMInferenceAdapterConfig, "api_token", "VLLM_API_TOKEN", {}),
|
||||
(WatsonXConfig, "api_key", "WATSONX_API_KEY", {}),
|
||||
],
|
||||
)
|
||||
def test_provider_config_auth_credentials(self, monkeypatch, config_cls, alias_name, env_name, extra_config):
|
||||
"""Test that the config class correctly maps the alias to auth_credential."""
|
||||
secret_value = config_cls.__name__
|
||||
|
||||
if alias_name is None:
|
||||
pytest.skip("No alias name provided for this config class.")
|
||||
|
||||
config = config_cls(**{alias_name: secret_value, **extra_config})
|
||||
assert config.auth_credential is not None
|
||||
assert config.auth_credential.get_secret_value() == secret_value
|
||||
|
||||
schema = config_cls.model_json_schema()
|
||||
assert alias_name in schema["properties"]
|
||||
assert "auth_credential" not in schema["properties"]
|
||||
|
||||
if env_name:
|
||||
monkeypatch.setenv(env_name, secret_value)
|
||||
sample_config = config_cls.sample_run_config()
|
||||
expanded_config = replace_env_vars(sample_config)
|
||||
config_from_sample = config_cls(**{**expanded_config, **extra_config})
|
||||
assert config_from_sample.auth_credential is not None
|
||||
assert config_from_sample.auth_credential.get_secret_value() == secret_value
|
||||
|
|
@ -9,32 +9,22 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from chromadb import PersistentClient
|
||||
from pymilvus import MilvusClient, connections
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
EMBEDDING_DIMENSION = 384
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
MILVUS_ALIAS = "test_milvus"
|
||||
|
||||
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
|
||||
@pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"])
|
||||
def vector_provider(request):
|
||||
return request.param
|
||||
|
||||
|
|
@ -145,10 +135,10 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
|
||||
async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = SQLiteVectorIOConfig(
|
||||
db_path=sqlite_vec_db_path,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
kvstore=unique_kvstore_config,
|
||||
)
|
||||
adapter = SQLiteVecVectorIOAdapter(
|
||||
config=config,
|
||||
|
|
@ -170,46 +160,6 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_d
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def milvus_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
|
||||
client = MilvusClient(milvus_vec_db_path)
|
||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
|
||||
index = MilvusIndex(client, name, consistency_level="Strong")
|
||||
index.db_path = milvus_vec_db_path
|
||||
yield index
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
|
||||
config = MilvusVectorIOConfig(
|
||||
db_path=milvus_vec_db_path,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
adapter = MilvusVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=adapter.metadata_collection_name,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=128,
|
||||
)
|
||||
)
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def faiss_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
|
||||
|
|
@ -246,98 +196,6 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chroma_vec_db_path(tmp_path_factory):
|
||||
persist_dir = tmp_path_factory.mktemp(f"chroma_{np.random.randint(1e6)}")
|
||||
return str(persist_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
|
||||
client = PersistentClient(path=chroma_vec_db_path)
|
||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||
collection = await maybe_await(client.get_or_create_collection(name))
|
||||
index = ChromaIndex(client=client, collection=collection)
|
||||
await index.initialize()
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
|
||||
config = ChromaVectorIOConfig(
|
||||
db_path=chroma_vec_db_path,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
adapter = ChromaVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=f"chroma_test_collection_{random.randint(1, 1_000_000)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_vec_db_path(tmp_path_factory):
|
||||
import uuid
|
||||
|
||||
db_path = str(tmp_path_factory.getbasetemp() / f"test_qdrant_{uuid.uuid4()}.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
|
||||
import uuid
|
||||
|
||||
config = QdrantVectorIOConfig(
|
||||
db_path=qdrant_vec_db_path,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
adapter = QdrantVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
adapter.test_collection_id = collection_id
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
|
||||
import uuid
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex
|
||||
|
||||
client = AsyncQdrantClient(path=qdrant_vec_db_path)
|
||||
collection_name = f"qdrant_test_collection_{uuid.uuid4()}"
|
||||
index = QdrantIndex(client, collection_name)
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_psycopg2_connection():
|
||||
connection = MagicMock()
|
||||
|
|
@ -386,14 +244,14 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
|
||||
async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = PGVectorVectorIOConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
db="test_db",
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
kvstore=unique_kvstore_config,
|
||||
)
|
||||
|
||||
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
|
||||
|
|
@ -450,81 +308,12 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def weaviate_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def weaviate_vec_index(weaviate_vec_db_path):
|
||||
import pytest_socket
|
||||
import weaviate
|
||||
|
||||
pytest_socket.enable_socket()
|
||||
client = weaviate.connect_to_embedded(
|
||||
hostname="localhost",
|
||||
port=8080,
|
||||
grpc_port=50051,
|
||||
persistence_data_path=weaviate_vec_db_path,
|
||||
)
|
||||
index = WeaviateIndex(client=client, collection_name="Testcollection")
|
||||
await index.initialize()
|
||||
yield index
|
||||
await index.delete()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
|
||||
import pytest_socket
|
||||
import weaviate
|
||||
|
||||
pytest_socket.enable_socket()
|
||||
|
||||
client = weaviate.connect_to_embedded(
|
||||
hostname="localhost",
|
||||
port=8080,
|
||||
grpc_port=50051,
|
||||
persistence_data_path=weaviate_vec_db_path,
|
||||
)
|
||||
|
||||
config = WeaviateVectorIOConfig(
|
||||
weaviate_cluster_url="localhost:8080",
|
||||
weaviate_api_key=None,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
adapter = WeaviateVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
adapter.test_collection_id = collection_id
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_io_adapter(vector_provider, request):
|
||||
vector_provider_dict = {
|
||||
"milvus": "milvus_vec_adapter",
|
||||
"faiss": "faiss_vec_adapter",
|
||||
"sqlite_vec": "sqlite_vec_adapter",
|
||||
"chroma": "chroma_vec_adapter",
|
||||
"qdrant": "qdrant_vec_adapter",
|
||||
"pgvector": "pgvector_vec_adapter",
|
||||
"weaviate": "weaviate_vec_adapter",
|
||||
}
|
||||
return request.getfixturevalue(vector_provider_dict[vector_provider])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,326 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
|
||||
# Mock the entire pymilvus module
|
||||
pymilvus_mock = MagicMock()
|
||||
pymilvus_mock.DataType = MagicMock()
|
||||
pymilvus_mock.MilvusClient = MagicMock
|
||||
pymilvus_mock.RRFRanker = MagicMock
|
||||
pymilvus_mock.WeightedRanker = MagicMock
|
||||
pymilvus_mock.AnnSearchRequest = MagicMock
|
||||
|
||||
# Apply the mock before importing MilvusIndex
|
||||
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
||||
|
||||
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
# tests/integration/vector_io/
|
||||
#
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest tests/unit/providers/vector_io/test_milvus.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
MILVUS_PROVIDER = "milvus"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_milvus_client() -> MagicMock:
|
||||
"""Create a mock Milvus client with common method behaviors."""
|
||||
client = MagicMock()
|
||||
|
||||
# Mock collection operations
|
||||
client.has_collection.return_value = False # Initially no collection
|
||||
client.create_collection.return_value = None
|
||||
client.drop_collection.return_value = None
|
||||
|
||||
# Mock insert operation
|
||||
client.insert.return_value = {"insert_count": 10}
|
||||
|
||||
# Mock search operation - return mock results (data should be dict, not JSON string)
|
||||
client.search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Mock query operation for keyword search (data should be dict, not JSON string)
|
||||
client.query.return_value = [
|
||||
{
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||
"score": 0.9,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
|
||||
"score": 0.8,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk3",
|
||||
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
|
||||
"score": 0.7,
|
||||
},
|
||||
]
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_index(mock_milvus_client):
|
||||
"""Create a MilvusIndex with mocked client."""
|
||||
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
|
||||
yield index
|
||||
# No real cleanup needed since we're using mocks
|
||||
|
||||
|
||||
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
# Setup: collection doesn't exist initially, then exists after creation
|
||||
mock_milvus_client.has_collection.side_effect = [False, True]
|
||||
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Verify collection was created and data was inserted
|
||||
mock_milvus_client.create_collection.assert_called_once()
|
||||
mock_milvus_client.insert.assert_called_once()
|
||||
|
||||
# Verify the insert call had the right number of chunks
|
||||
insert_call = mock_milvus_client.insert.call_args
|
||||
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
||||
|
||||
|
||||
async def test_query_chunks_vector(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
# Setup: Add chunks first
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test vector search
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
mock_milvus_client.search.assert_called_once()
|
||||
|
||||
|
||||
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test keyword search
|
||||
query_string = "Sentence 5"
|
||||
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
|
||||
|
||||
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Force BM25 search to fail
|
||||
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
|
||||
|
||||
# Mock simple text search results
|
||||
mock_milvus_client.query.return_value = [
|
||||
{
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
|
||||
},
|
||||
]
|
||||
|
||||
# Test keyword search that should fall back to simple text search
|
||||
query_string = "Python"
|
||||
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
||||
|
||||
# Verify response structure
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) > 0, "Fallback search should return results"
|
||||
|
||||
# Verify that simple text search was used (query method called instead of search)
|
||||
mock_milvus_client.query.assert_called_once()
|
||||
mock_milvus_client.search.assert_called_once() # Called once but failed
|
||||
|
||||
# Verify the query uses parameterized filter with filter_params
|
||||
query_call_args = mock_milvus_client.query.call_args
|
||||
assert "filter" in query_call_args[1], "Query should include filter for text search"
|
||||
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
|
||||
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
|
||||
|
||||
# Verify all returned chunks have score 1.0 (simple binary scoring)
|
||||
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
|
||||
|
||||
|
||||
async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||
# Test collection deletion
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
|
||||
await milvus_index.delete()
|
||||
|
||||
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
||||
|
||||
|
||||
async def test_query_hybrid_search_rrf(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with RRF reranker."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with RRF reranker
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=2,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
assert len(response.scores) == 2
|
||||
|
||||
# Verify hybrid search was called with correct parameters
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
|
||||
# Check that the request contains both vector and BM25 search requests
|
||||
reqs = call_args[1]["reqs"]
|
||||
assert len(reqs) == 2
|
||||
assert reqs[0].anns_field == "vector"
|
||||
assert reqs[1].anns_field == "sparse"
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
||||
|
||||
async def test_query_hybrid_search_weighted(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with weighted reranker."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with weighted reranker
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=2,
|
||||
score_threshold=0.0,
|
||||
reranker_type="weighted",
|
||||
reranker_params={"alpha": 0.7},
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
assert len(response.scores) == 2
|
||||
|
||||
# Verify hybrid search was called with correct parameters
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
||||
|
||||
async def test_query_hybrid_search_default_rrf(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with default reranker (should be RRF)
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="unknown_type", # Should default to RRF
|
||||
reranker_params=None, # Should use default impact_factor
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 1
|
||||
|
||||
# Verify hybrid search was called with RRF reranker
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex
|
||||
|
||||
PGVECTOR_PROVIDER = "pgvector"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_dimension():
|
||||
"""Default embedding dimension for tests."""
|
||||
return 384
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def pgvector_index(embedding_dimension, mock_psycopg2_connection):
|
||||
"""Create a PGVectorIndex instance with mocked database connection."""
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
# Use explicit COSINE distance metric for consistent testing
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
|
||||
|
||||
return index, cursor
|
||||
|
||||
|
||||
class TestPGVectorIndex:
|
||||
def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2")
|
||||
assert index.distance_metric == "L2"
|
||||
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
|
||||
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID")
|
||||
|
||||
def test_get_pgvector_search_function(self, pgvector_index):
|
||||
index, cursor = pgvector_index
|
||||
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
|
||||
|
||||
for metric, function in supported_metrics.items():
|
||||
index.distance_metric = metric
|
||||
assert index.get_pgvector_search_function() == function
|
||||
|
||||
def test_check_distance_metric_availability(self, pgvector_index):
|
||||
index, cursor = pgvector_index
|
||||
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
|
||||
|
||||
for metric in supported_metrics:
|
||||
index.check_distance_metric_availability(metric)
|
||||
|
||||
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
|
||||
index.check_distance_metric_availability("INVALID")
|
||||
|
||||
def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"):
|
||||
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC")
|
||||
|
||||
with pytest.raises(ValueError, match="Supported metrics are:"):
|
||||
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN")
|
||||
|
||||
try:
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
|
||||
assert index.distance_metric == "COSINE"
|
||||
except ValueError:
|
||||
pytest.fail("Valid distance metric 'COSINE' should not raise ValueError")
|
||||
|
||||
def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"]
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
for metric in supported_metrics:
|
||||
try:
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric)
|
||||
assert index.distance_metric == metric
|
||||
|
||||
expected_operators = {
|
||||
"L2": "<->",
|
||||
"L1": "<+>",
|
||||
"COSINE": "<=>",
|
||||
"INNER_PRODUCT": "<#>",
|
||||
"HAMMING": "<~>",
|
||||
"JACCARD": "<%>",
|
||||
}
|
||||
assert index.get_pgvector_search_function() == expected_operators[metric]
|
||||
except Exception as e:
|
||||
pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}")
|
||||
|
|
@ -1,147 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingData, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage
|
||||
from llama_stack.apis.vector_io import (
|
||||
QueryChunksResponse,
|
||||
VectorDB,
|
||||
VectorDBStore,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.qdrant.config import (
|
||||
QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
|
||||
QdrantVectorIOAdapter,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
# tests/integration/vector_io/
|
||||
#
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest tests/unit/providers/vector_io/test_qdrant.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
|
||||
kvstore_config = SqliteKVStoreConfig(db_name=os.path.join(tmp_path, "test_kvstore.db"))
|
||||
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"), kvstore=kvstore_config)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db(vector_db_id) -> MagicMock:
|
||||
mock_vector_db = MagicMock(spec=VectorDB)
|
||||
mock_vector_db.embedding_model = "embedding_model"
|
||||
mock_vector_db.identifier = vector_db_id
|
||||
mock_vector_db.embedding_dimension = 384
|
||||
mock_vector_db.model_dump_json.return_value = (
|
||||
'{"identifier": "'
|
||||
+ vector_db_id
|
||||
+ '", "provider_id": "qdrant", "embedding_model": "embedding_model", "embedding_dimension": 384}'
|
||||
)
|
||||
return mock_vector_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db_store(mock_vector_db) -> MagicMock:
|
||||
mock_store = MagicMock(spec=VectorDBStore)
|
||||
mock_store.get_vector_db = AsyncMock(return_value=mock_vector_db)
|
||||
return mock_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_service(sample_embeddings):
|
||||
mock_api_service = MagicMock(spec=Inference)
|
||||
mock_api_service.openai_embeddings = AsyncMock(
|
||||
return_value=OpenAIEmbeddingsResponse(
|
||||
model="mock-embedding-model",
|
||||
data=[OpenAIEmbeddingData(embedding=sample, index=i) for i, sample in enumerate(sample_embeddings)],
|
||||
usage=OpenAIEmbeddingUsage(prompt_tokens=10, total_tokens=10),
|
||||
)
|
||||
)
|
||||
return mock_api_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
|
||||
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
|
||||
adapter.vector_db_store = mock_vector_db_store
|
||||
await adapter.initialize()
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
__QUERY = "Sample query"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
|
||||
async def test_qdrant_adapter_returns_expected_chunks(
|
||||
qdrant_adapter: QdrantVectorIOAdapter,
|
||||
vector_db_id,
|
||||
sample_chunks,
|
||||
sample_embeddings,
|
||||
max_query_chunks,
|
||||
expected_chunks,
|
||||
) -> None:
|
||||
assert qdrant_adapter is not None
|
||||
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
|
||||
|
||||
index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
|
||||
assert index is not None
|
||||
|
||||
response = await qdrant_adapter.query_chunks(
|
||||
query=__QUERY,
|
||||
vector_db_id=vector_db_id,
|
||||
params={"max_chunks": max_query_chunks, "mode": "vector"},
|
||||
)
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == expected_chunks
|
||||
|
||||
|
||||
# To by-pass attempt to convert a Mock to JSON
|
||||
def _prepare_for_json(value: Any) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
|
||||
async def test_qdrant_register_and_unregister_vector_db(
|
||||
qdrant_adapter: QdrantVectorIOAdapter,
|
||||
mock_vector_db,
|
||||
sample_chunks,
|
||||
) -> None:
|
||||
# Initially, no collections
|
||||
vector_db_id = mock_vector_db.identifier
|
||||
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
|
||||
|
||||
# Register does not create a collection
|
||||
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
||||
await qdrant_adapter.register_vector_db(mock_vector_db)
|
||||
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
||||
|
||||
# First insert creates the collection
|
||||
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
|
||||
assert await qdrant_adapter.client.collection_exists(vector_db_id)
|
||||
|
||||
# Unregister deletes the collection
|
||||
await qdrant_adapter.unregister_vector_db(vector_db_id)
|
||||
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
||||
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
|
||||
|
|
@ -6,16 +6,23 @@
|
|||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||
QueryChunksResponse,
|
||||
VectorStoreChunkingStrategyAuto,
|
||||
VectorStoreFileObject,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
|
||||
|
||||
# This test is a unit test for the inline VectoerIO providers. This should only contain
|
||||
# This test is a unit test for the inline VectorIO providers. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
# tests/integration/vector_io/
|
||||
#
|
||||
|
|
@ -25,6 +32,16 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREF
|
|||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_resume_file_batches(request):
|
||||
"""Mock the resume functionality to prevent stale file batches from being processed during tests."""
|
||||
with patch(
|
||||
"llama_stack.providers.utils.memory.openai_vector_store_mixin.OpenAIVectorStoreMixin._resume_incomplete_batches",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
async def test_initialize_index(vector_index):
|
||||
await vector_index.initialize()
|
||||
|
||||
|
|
@ -88,12 +105,8 @@ async def test_register_and_unregister_vector_db(vector_io_adapter):
|
|||
|
||||
async def test_query_unregistered_raises(vector_io_adapter, vector_provider):
|
||||
fake_emb = np.zeros(8, dtype=np.float32)
|
||||
if vector_provider == "chroma":
|
||||
with pytest.raises(AttributeError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
else:
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
|
||||
|
||||
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
||||
|
|
@ -294,3 +307,657 @@ async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, t
|
|||
assert loaded_file_info == {}
|
||||
loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
|
||||
assert loaded_contents == []
|
||||
|
||||
|
||||
async def test_create_vector_store_file_batch(vector_io_adapter):
|
||||
"""Test creating a file batch."""
|
||||
store_id = "vs_1234"
|
||||
file_ids = ["file_1", "file_2", "file_3"]
|
||||
|
||||
# Setup vector store
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": {},
|
||||
"file_ids": [],
|
||||
}
|
||||
|
||||
# Mock attach method and batch processing to avoid actual processing
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
vector_io_adapter._process_file_batch_async = AsyncMock()
|
||||
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
assert batch.vector_store_id == store_id
|
||||
assert batch.status == "in_progress"
|
||||
assert batch.file_counts.total == len(file_ids)
|
||||
assert batch.file_counts.in_progress == len(file_ids)
|
||||
assert batch.id in vector_io_adapter.openai_file_batches
|
||||
|
||||
|
||||
async def test_retrieve_vector_store_file_batch(vector_io_adapter):
|
||||
"""Test retrieving a file batch."""
|
||||
store_id = "vs_1234"
|
||||
file_ids = ["file_1", "file_2"]
|
||||
|
||||
# Setup vector store
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": {},
|
||||
"file_ids": [],
|
||||
}
|
||||
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
|
||||
# Create batch first
|
||||
created_batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Retrieve batch
|
||||
retrieved_batch = await vector_io_adapter.openai_retrieve_vector_store_file_batch(
|
||||
batch_id=created_batch.id,
|
||||
vector_store_id=store_id,
|
||||
)
|
||||
|
||||
assert retrieved_batch.id == created_batch.id
|
||||
assert retrieved_batch.vector_store_id == store_id
|
||||
assert retrieved_batch.status == "in_progress"
|
||||
|
||||
|
||||
async def test_cancel_vector_store_file_batch(vector_io_adapter):
|
||||
"""Test cancelling a file batch."""
|
||||
store_id = "vs_1234"
|
||||
file_ids = ["file_1"]
|
||||
|
||||
# Setup vector store
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": {},
|
||||
"file_ids": [],
|
||||
}
|
||||
|
||||
# Mock both file attachment and batch processing to prevent automatic completion
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
vector_io_adapter._process_file_batch_async = AsyncMock()
|
||||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Cancel batch
|
||||
cancelled_batch = await vector_io_adapter.openai_cancel_vector_store_file_batch(
|
||||
batch_id=batch.id,
|
||||
vector_store_id=store_id,
|
||||
)
|
||||
|
||||
assert cancelled_batch.status == "cancelled"
|
||||
|
||||
|
||||
async def test_list_files_in_vector_store_file_batch(vector_io_adapter):
|
||||
"""Test listing files in a batch."""
|
||||
store_id = "vs_1234"
|
||||
file_ids = ["file_1", "file_2"]
|
||||
|
||||
# Setup vector store with files
|
||||
files = {}
|
||||
for i, file_id in enumerate(file_ids):
|
||||
files[file_id] = VectorStoreFileObject(
|
||||
id=file_id,
|
||||
object="vector_store.file",
|
||||
usage_bytes=1000,
|
||||
created_at=int(time.time()) + i,
|
||||
vector_store_id=store_id,
|
||||
status="completed",
|
||||
chunking_strategy=VectorStoreChunkingStrategyAuto(),
|
||||
)
|
||||
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": files,
|
||||
"file_ids": file_ids,
|
||||
}
|
||||
|
||||
# Mock file loading
|
||||
vector_io_adapter._load_openai_vector_store_file = AsyncMock(
|
||||
side_effect=lambda vs_id, f_id: files[f_id].model_dump()
|
||||
)
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# List files
|
||||
response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
|
||||
batch_id=batch.id,
|
||||
vector_store_id=store_id,
|
||||
)
|
||||
|
||||
assert len(response.data) == len(file_ids)
|
||||
assert response.first_id is not None
|
||||
assert response.last_id is not None
|
||||
|
||||
|
||||
async def test_file_batch_validation_errors(vector_io_adapter):
|
||||
"""Test file batch validation errors."""
|
||||
# Test nonexistent vector store
|
||||
with pytest.raises(VectorStoreNotFoundError):
|
||||
await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id="nonexistent",
|
||||
params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"]),
|
||||
)
|
||||
|
||||
# Setup store for remaining tests
|
||||
store_id = "vs_test"
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {"id": store_id, "files": {}, "file_ids": []}
|
||||
|
||||
# Test nonexistent batch
|
||||
with pytest.raises(ValueError, match="File batch .* not found"):
|
||||
await vector_io_adapter.openai_retrieve_vector_store_file_batch(
|
||||
batch_id="nonexistent_batch",
|
||||
vector_store_id=store_id,
|
||||
)
|
||||
|
||||
# Test wrong vector store for batch
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
|
||||
)
|
||||
|
||||
# Create wrong_store so it exists but the batch doesn't belong to it
|
||||
wrong_store_id = "wrong_store"
|
||||
vector_io_adapter.openai_vector_stores[wrong_store_id] = {"id": wrong_store_id, "files": {}, "file_ids": []}
|
||||
|
||||
with pytest.raises(ValueError, match="does not belong to vector store"):
|
||||
await vector_io_adapter.openai_retrieve_vector_store_file_batch(
|
||||
batch_id=batch.id,
|
||||
vector_store_id=wrong_store_id,
|
||||
)
|
||||
|
||||
|
||||
async def test_file_batch_pagination(vector_io_adapter):
|
||||
"""Test file batch pagination."""
|
||||
store_id = "vs_1234"
|
||||
file_ids = ["file_1", "file_2", "file_3", "file_4", "file_5"]
|
||||
|
||||
# Setup vector store with multiple files
|
||||
files = {}
|
||||
for i, file_id in enumerate(file_ids):
|
||||
files[file_id] = VectorStoreFileObject(
|
||||
id=file_id,
|
||||
object="vector_store.file",
|
||||
usage_bytes=1000,
|
||||
created_at=int(time.time()) + i,
|
||||
vector_store_id=store_id,
|
||||
status="completed",
|
||||
chunking_strategy=VectorStoreChunkingStrategyAuto(),
|
||||
)
|
||||
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": files,
|
||||
"file_ids": file_ids,
|
||||
}
|
||||
|
||||
# Mock file loading
|
||||
vector_io_adapter._load_openai_vector_store_file = AsyncMock(
|
||||
side_effect=lambda vs_id, f_id: files[f_id].model_dump()
|
||||
)
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Test pagination with limit
|
||||
response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
|
||||
batch_id=batch.id,
|
||||
vector_store_id=store_id,
|
||||
limit=3,
|
||||
)
|
||||
|
||||
assert len(response.data) == 3
|
||||
assert response.has_more is True
|
||||
|
||||
# Test pagination with after cursor
|
||||
first_page = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
|
||||
batch_id=batch.id,
|
||||
vector_store_id=store_id,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
second_page = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
|
||||
batch_id=batch.id,
|
||||
vector_store_id=store_id,
|
||||
limit=2,
|
||||
after=first_page.last_id,
|
||||
)
|
||||
|
||||
assert len(first_page.data) == 2
|
||||
assert len(second_page.data) == 2
|
||||
# Ensure no overlap between pages
|
||||
first_page_ids = {file_obj.id for file_obj in first_page.data}
|
||||
second_page_ids = {file_obj.id for file_obj in second_page.data}
|
||||
assert first_page_ids.isdisjoint(second_page_ids)
|
||||
# Verify we got all expected files across both pages (in desc order: file_5, file_4, file_3, file_2, file_1)
|
||||
all_returned_ids = first_page_ids | second_page_ids
|
||||
assert all_returned_ids == {"file_2", "file_3", "file_4", "file_5"}
|
||||
|
||||
|
||||
async def test_file_batch_status_filtering(vector_io_adapter):
|
||||
"""Test file batch status filtering."""
|
||||
store_id = "vs_1234"
|
||||
file_ids = ["file_1", "file_2", "file_3"]
|
||||
|
||||
# Setup vector store with files having different statuses
|
||||
files = {}
|
||||
statuses = ["completed", "in_progress", "completed"]
|
||||
for i, (file_id, status) in enumerate(zip(file_ids, statuses, strict=False)):
|
||||
files[file_id] = VectorStoreFileObject(
|
||||
id=file_id,
|
||||
object="vector_store.file",
|
||||
usage_bytes=1000,
|
||||
created_at=int(time.time()) + i,
|
||||
vector_store_id=store_id,
|
||||
status=status,
|
||||
chunking_strategy=VectorStoreChunkingStrategyAuto(),
|
||||
)
|
||||
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": files,
|
||||
"file_ids": file_ids,
|
||||
}
|
||||
|
||||
# Mock file loading
|
||||
vector_io_adapter._load_openai_vector_store_file = AsyncMock(
|
||||
side_effect=lambda vs_id, f_id: files[f_id].model_dump()
|
||||
)
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Test filtering by completed status
|
||||
response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
|
||||
batch_id=batch.id,
|
||||
vector_store_id=store_id,
|
||||
filter="completed",
|
||||
)
|
||||
|
||||
assert len(response.data) == 2 # Only 2 completed files
|
||||
for file_obj in response.data:
|
||||
assert file_obj.status == "completed"
|
||||
|
||||
# Test filtering by in_progress status
|
||||
response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
|
||||
batch_id=batch.id,
|
||||
vector_store_id=store_id,
|
||||
filter="in_progress",
|
||||
)
|
||||
|
||||
assert len(response.data) == 1 # Only 1 in_progress file
|
||||
assert response.data[0].status == "in_progress"
|
||||
|
||||
|
||||
async def test_cancel_completed_batch_fails(vector_io_adapter):
|
||||
"""Test that cancelling completed batch fails."""
|
||||
store_id = "vs_1234"
|
||||
file_ids = ["file_1"]
|
||||
|
||||
# Setup vector store
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": {},
|
||||
"file_ids": [],
|
||||
}
|
||||
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Manually update status to completed
|
||||
batch_info = vector_io_adapter.openai_file_batches[batch.id]
|
||||
batch_info["status"] = "completed"
|
||||
|
||||
# Try to cancel - should fail
|
||||
with pytest.raises(ValueError, match="Cannot cancel batch .* with status completed"):
|
||||
await vector_io_adapter.openai_cancel_vector_store_file_batch(
|
||||
batch_id=batch.id,
|
||||
vector_store_id=store_id,
|
||||
)
|
||||
|
||||
|
||||
async def test_file_batch_persistence_across_restarts(vector_io_adapter):
|
||||
"""Test that in-progress file batches are persisted and resumed after restart."""
|
||||
store_id = "vs_1234"
|
||||
file_ids = ["file_1", "file_2"]
|
||||
|
||||
# Setup vector store
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": {},
|
||||
"file_ids": [],
|
||||
}
|
||||
|
||||
# Mock attach method and batch processing to avoid actual processing
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
vector_io_adapter._process_file_batch_async = AsyncMock()
|
||||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
batch_id = batch.id
|
||||
|
||||
# Verify batch is saved to persistent storage
|
||||
assert batch_id in vector_io_adapter.openai_file_batches
|
||||
saved_batch_key = f"openai_vector_stores_file_batches:v3::{batch_id}"
|
||||
saved_batch = await vector_io_adapter.kvstore.get(saved_batch_key)
|
||||
assert saved_batch is not None
|
||||
|
||||
# Verify the saved batch data contains all necessary information
|
||||
saved_data = json.loads(saved_batch)
|
||||
assert saved_data["id"] == batch_id
|
||||
assert saved_data["status"] == "in_progress"
|
||||
assert saved_data["file_ids"] == file_ids
|
||||
|
||||
# Simulate restart - clear in-memory cache and reload from persistence
|
||||
vector_io_adapter.openai_file_batches.clear()
|
||||
|
||||
# Temporarily restore the real initialize_openai_vector_stores method
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
||||
real_method = OpenAIVectorStoreMixin.initialize_openai_vector_stores
|
||||
await real_method(vector_io_adapter)
|
||||
|
||||
# Re-mock the processing method to prevent any resumed batches from processing
|
||||
vector_io_adapter._process_file_batch_async = AsyncMock()
|
||||
|
||||
# Verify batch was restored
|
||||
assert batch_id in vector_io_adapter.openai_file_batches
|
||||
restored_batch = vector_io_adapter.openai_file_batches[batch_id]
|
||||
assert restored_batch["status"] == "in_progress"
|
||||
assert restored_batch["id"] == batch_id
|
||||
assert vector_io_adapter.openai_file_batches[batch_id]["file_ids"] == file_ids
|
||||
|
||||
|
||||
async def test_cancelled_batch_persists_in_storage(vector_io_adapter):
|
||||
"""Test that cancelled batches persist in storage with updated status."""
|
||||
store_id = "vs_1234"
|
||||
file_ids = ["file_1", "file_2"]
|
||||
|
||||
# Setup vector store
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": {},
|
||||
"file_ids": [],
|
||||
}
|
||||
|
||||
# Mock attach method and batch processing to avoid actual processing
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
vector_io_adapter._process_file_batch_async = AsyncMock()
|
||||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
batch_id = batch.id
|
||||
|
||||
# Verify batch is initially saved to persistent storage
|
||||
saved_batch_key = f"openai_vector_stores_file_batches:v3::{batch_id}"
|
||||
saved_batch = await vector_io_adapter.kvstore.get(saved_batch_key)
|
||||
assert saved_batch is not None
|
||||
|
||||
# Cancel the batch
|
||||
cancelled_batch = await vector_io_adapter.openai_cancel_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=store_id,
|
||||
)
|
||||
|
||||
# Verify batch status is cancelled
|
||||
assert cancelled_batch.status == "cancelled"
|
||||
|
||||
# Verify batch persists in storage with cancelled status
|
||||
updated_batch = await vector_io_adapter.kvstore.get(saved_batch_key)
|
||||
assert updated_batch is not None
|
||||
batch_data = json.loads(updated_batch)
|
||||
assert batch_data["status"] == "cancelled"
|
||||
|
||||
# Batch should remain in memory cache (matches vector store pattern)
|
||||
assert batch_id in vector_io_adapter.openai_file_batches
|
||||
assert vector_io_adapter.openai_file_batches[batch_id]["status"] == "cancelled"
|
||||
|
||||
|
||||
async def test_only_in_progress_batches_resumed(vector_io_adapter):
|
||||
"""Test that only in-progress batches are resumed for processing, but all batches are persisted."""
|
||||
store_id = "vs_1234"
|
||||
|
||||
# Setup vector store
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": {},
|
||||
"file_ids": [],
|
||||
}
|
||||
|
||||
# Mock attach method and batch processing to prevent automatic completion
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
vector_io_adapter._process_file_batch_async = AsyncMock()
|
||||
|
||||
# Create multiple batches
|
||||
batch1 = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
|
||||
)
|
||||
batch2 = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_2"])
|
||||
)
|
||||
|
||||
# Complete one batch (should persist with completed status)
|
||||
batch1_info = vector_io_adapter.openai_file_batches[batch1.id]
|
||||
batch1_info["status"] = "completed"
|
||||
await vector_io_adapter._save_openai_vector_store_file_batch(batch1.id, batch1_info)
|
||||
|
||||
# Cancel the other batch (should persist with cancelled status)
|
||||
await vector_io_adapter.openai_cancel_vector_store_file_batch(batch_id=batch2.id, vector_store_id=store_id)
|
||||
|
||||
# Create a third batch that stays in progress
|
||||
batch3 = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_3"])
|
||||
)
|
||||
|
||||
# Simulate restart - clear memory and reload from persistence
|
||||
vector_io_adapter.openai_file_batches.clear()
|
||||
|
||||
# Temporarily restore the real initialize_openai_vector_stores method
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
||||
real_method = OpenAIVectorStoreMixin.initialize_openai_vector_stores
|
||||
await real_method(vector_io_adapter)
|
||||
|
||||
# All batches should be restored from persistence
|
||||
assert batch1.id in vector_io_adapter.openai_file_batches # completed, persisted
|
||||
assert batch2.id in vector_io_adapter.openai_file_batches # cancelled, persisted
|
||||
assert batch3.id in vector_io_adapter.openai_file_batches # in-progress, restored
|
||||
|
||||
# Check their statuses
|
||||
assert vector_io_adapter.openai_file_batches[batch1.id]["status"] == "completed"
|
||||
assert vector_io_adapter.openai_file_batches[batch2.id]["status"] == "cancelled"
|
||||
assert vector_io_adapter.openai_file_batches[batch3.id]["status"] == "in_progress"
|
||||
|
||||
# Resume functionality is mocked, so we're only testing persistence
|
||||
|
||||
|
||||
async def test_cleanup_expired_file_batches(vector_io_adapter):
|
||||
"""Test that expired file batches are cleaned up properly."""
|
||||
store_id = "vs_1234"
|
||||
|
||||
# Setup vector store
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": {},
|
||||
"file_ids": [],
|
||||
}
|
||||
|
||||
# Mock processing to prevent automatic completion
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
vector_io_adapter._process_file_batch_async = AsyncMock()
|
||||
|
||||
# Create batches with different ages
|
||||
import time
|
||||
|
||||
current_time = int(time.time())
|
||||
|
||||
# Create an old expired batch (10 days old)
|
||||
old_batch_info = {
|
||||
"id": "batch_old",
|
||||
"vector_store_id": store_id,
|
||||
"status": "completed",
|
||||
"created_at": current_time - (10 * 24 * 60 * 60), # 10 days ago
|
||||
"expires_at": current_time - (3 * 24 * 60 * 60), # Expired 3 days ago
|
||||
"file_ids": ["file_1"],
|
||||
}
|
||||
|
||||
# Create a recent valid batch
|
||||
new_batch_info = {
|
||||
"id": "batch_new",
|
||||
"vector_store_id": store_id,
|
||||
"status": "completed",
|
||||
"created_at": current_time - (1 * 24 * 60 * 60), # 1 day ago
|
||||
"expires_at": current_time + (6 * 24 * 60 * 60), # Expires in 6 days
|
||||
"file_ids": ["file_2"],
|
||||
}
|
||||
|
||||
# Store both batches in persistent storage
|
||||
await vector_io_adapter._save_openai_vector_store_file_batch("batch_old", old_batch_info)
|
||||
await vector_io_adapter._save_openai_vector_store_file_batch("batch_new", new_batch_info)
|
||||
|
||||
# Add to in-memory cache
|
||||
vector_io_adapter.openai_file_batches["batch_old"] = old_batch_info
|
||||
vector_io_adapter.openai_file_batches["batch_new"] = new_batch_info
|
||||
|
||||
# Verify both batches exist before cleanup
|
||||
assert "batch_old" in vector_io_adapter.openai_file_batches
|
||||
assert "batch_new" in vector_io_adapter.openai_file_batches
|
||||
|
||||
# Run cleanup
|
||||
await vector_io_adapter._cleanup_expired_file_batches()
|
||||
|
||||
# Verify expired batch was removed from memory
|
||||
assert "batch_old" not in vector_io_adapter.openai_file_batches
|
||||
assert "batch_new" in vector_io_adapter.openai_file_batches
|
||||
|
||||
# Verify expired batch was removed from storage
|
||||
old_batch_key = "openai_vector_stores_file_batches:v3::batch_old"
|
||||
new_batch_key = "openai_vector_stores_file_batches:v3::batch_new"
|
||||
|
||||
old_stored = await vector_io_adapter.kvstore.get(old_batch_key)
|
||||
new_stored = await vector_io_adapter.kvstore.get(new_batch_key)
|
||||
|
||||
assert old_stored is None # Expired batch should be deleted
|
||||
assert new_stored is not None # Valid batch should remain
|
||||
|
||||
|
||||
async def test_expired_batch_access_error(vector_io_adapter):
|
||||
"""Test that accessing expired batches returns clear error message."""
|
||||
store_id = "vs_1234"
|
||||
|
||||
# Setup vector store
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": {},
|
||||
"file_ids": [],
|
||||
}
|
||||
|
||||
# Create an expired batch
|
||||
import time
|
||||
|
||||
current_time = int(time.time())
|
||||
|
||||
expired_batch_info = {
|
||||
"id": "batch_expired",
|
||||
"vector_store_id": store_id,
|
||||
"status": "completed",
|
||||
"created_at": current_time - (10 * 24 * 60 * 60), # 10 days ago
|
||||
"expires_at": current_time - (3 * 24 * 60 * 60), # Expired 3 days ago
|
||||
"file_ids": ["file_1"],
|
||||
}
|
||||
|
||||
# Add to in-memory cache (simulating it was loaded before expiration)
|
||||
vector_io_adapter.openai_file_batches["batch_expired"] = expired_batch_info
|
||||
|
||||
# Try to access expired batch
|
||||
with pytest.raises(ValueError, match="File batch batch_expired has expired after 7 days from creation"):
|
||||
vector_io_adapter._get_and_validate_batch("batch_expired", store_id)
|
||||
|
||||
|
||||
async def test_max_concurrent_files_per_batch(vector_io_adapter):
|
||||
"""Test that file batch processing respects MAX_CONCURRENT_FILES_PER_BATCH limit."""
|
||||
import asyncio
|
||||
|
||||
store_id = "vs_1234"
|
||||
|
||||
# Setup vector store
|
||||
vector_io_adapter.openai_vector_stores[store_id] = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"files": {},
|
||||
"file_ids": [],
|
||||
}
|
||||
|
||||
active_files = 0
|
||||
|
||||
async def mock_attach_file_with_delay(vector_store_id: str, file_id: str, **kwargs):
|
||||
"""Mock that tracks concurrency and blocks indefinitely to test concurrency limit."""
|
||||
nonlocal active_files
|
||||
active_files += 1
|
||||
|
||||
# Block indefinitely to test concurrency limit
|
||||
await asyncio.sleep(float("inf"))
|
||||
|
||||
# Replace the attachment method
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = mock_attach_file_with_delay
|
||||
|
||||
# Create a batch with more files than the concurrency limit
|
||||
file_ids = [f"file_{i}" for i in range(8)] # 8 files, but limit should be 5
|
||||
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Give time for the semaphore logic to start processing files
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Verify that only MAX_CONCURRENT_FILES_PER_BATCH files are processing concurrently
|
||||
# The semaphore in _process_files_with_concurrency should limit this
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import MAX_CONCURRENT_FILES_PER_BATCH
|
||||
|
||||
assert active_files == MAX_CONCURRENT_FILES_PER_BATCH, (
|
||||
f"Expected {MAX_CONCURRENT_FILES_PER_BATCH} active files, got {active_files}"
|
||||
)
|
||||
|
||||
# Verify batch is in progress
|
||||
assert batch.status == "in_progress"
|
||||
assert batch.file_counts.total == 8
|
||||
assert batch.file_counts.in_progress == 8
|
||||
|
|
|
|||
|
|
@ -13,7 +13,10 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingData
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
)
|
||||
from llama_stack.apis.tools import RAGDocument
|
||||
from llama_stack.apis.vector_io import Chunk
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
|
|
@ -226,9 +229,14 @@ class TestVectorDBWithIndex:
|
|||
|
||||
await vector_db_with_index.insert_chunks(chunks)
|
||||
|
||||
mock_inference_api.openai_embeddings.assert_called_once_with(
|
||||
"test-model without embeddings", ["Test 1", "Test 2"]
|
||||
)
|
||||
# Verify openai_embeddings was called with correct params
|
||||
mock_inference_api.openai_embeddings.assert_called_once()
|
||||
call_args = mock_inference_api.openai_embeddings.call_args[0]
|
||||
assert len(call_args) == 1
|
||||
params = call_args[0]
|
||||
assert isinstance(params, OpenAIEmbeddingsRequestWithExtraBody)
|
||||
assert params.model == "test-model without embeddings"
|
||||
assert params.input == ["Test 1", "Test 2"]
|
||||
mock_index.add_chunks.assert_called_once()
|
||||
args = mock_index.add_chunks.call_args[0]
|
||||
assert args[0] == chunks
|
||||
|
|
@ -321,9 +329,14 @@ class TestVectorDBWithIndex:
|
|||
|
||||
await vector_db_with_index.insert_chunks(chunks)
|
||||
|
||||
mock_inference_api.openai_embeddings.assert_called_once_with(
|
||||
"test-model with partial embeddings", ["Test 1", "Test 3"]
|
||||
)
|
||||
# Verify openai_embeddings was called with correct params
|
||||
mock_inference_api.openai_embeddings.assert_called_once()
|
||||
call_args = mock_inference_api.openai_embeddings.call_args[0]
|
||||
assert len(call_args) == 1
|
||||
params = call_args[0]
|
||||
assert isinstance(params, OpenAIEmbeddingsRequestWithExtraBody)
|
||||
assert params.model == "test-model with partial embeddings"
|
||||
assert params.input == ["Test 1", "Test 3"]
|
||||
mock_index.add_chunks.assert_called_once()
|
||||
args = mock_index.add_chunks.call_args[0]
|
||||
assert len(args[0]) == 3
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
|
||||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.core.datatypes import VectorDBWithOwner
|
||||
from llama_stack.core.store.registry import (
|
||||
KEY_FORMAT,
|
||||
CachedDiskDistributionRegistry,
|
||||
|
|
@ -116,7 +117,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
|||
provider_resource_id="test_vector_db_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_disk_dist_registry.register(original_vector_db)
|
||||
assert await cached_disk_dist_registry.register(original_vector_db)
|
||||
|
||||
duplicate_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
|
|
@ -125,7 +126,8 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
|||
provider_resource_id="test_vector_db_2",
|
||||
provider_id="baz", # Same provider_id
|
||||
)
|
||||
await cached_disk_dist_registry.register(duplicate_vector_db)
|
||||
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"):
|
||||
await cached_disk_dist_registry.register(duplicate_vector_db)
|
||||
|
||||
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result is not None
|
||||
|
|
@ -229,3 +231,98 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
|||
|
||||
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
|
||||
assert invalid_obj is None
|
||||
|
||||
|
||||
async def test_double_registration_identical_objects(disk_dist_registry):
|
||||
"""Test that registering identical objects succeeds (idempotent)."""
|
||||
vector_db = VectorDBWithOwner(
|
||||
identifier="test_vector_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
# First registration should succeed
|
||||
result1 = await disk_dist_registry.register(vector_db)
|
||||
assert result1 is True
|
||||
|
||||
# Second registration of identical object should also succeed (idempotent)
|
||||
result2 = await disk_dist_registry.register(vector_db)
|
||||
assert result2 is True
|
||||
|
||||
# Verify object exists and is unchanged
|
||||
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
|
||||
assert retrieved is not None
|
||||
assert retrieved.identifier == vector_db.identifier
|
||||
assert retrieved.embedding_model == vector_db.embedding_model
|
||||
|
||||
|
||||
async def test_double_registration_different_objects(disk_dist_registry):
|
||||
"""Test that registering different objects with same identifier fails."""
|
||||
vector_db1 = VectorDBWithOwner(
|
||||
identifier="test_vector_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
vector_db2 = VectorDBWithOwner(
|
||||
identifier="test_vector_db", # Same identifier
|
||||
embedding_model="different-model", # Different embedding model
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
# First registration should succeed
|
||||
result1 = await disk_dist_registry.register(vector_db1)
|
||||
assert result1 is True
|
||||
|
||||
# Second registration with different data should fail
|
||||
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"):
|
||||
await disk_dist_registry.register(vector_db2)
|
||||
|
||||
# Verify original object is unchanged
|
||||
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
|
||||
assert retrieved is not None
|
||||
assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value
|
||||
|
||||
|
||||
async def test_double_registration_with_cache(cached_disk_dist_registry):
|
||||
"""Test double registration behavior with caching enabled."""
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import ModelWithOwner
|
||||
|
||||
model1 = ModelWithOwner(
|
||||
identifier="test_model",
|
||||
provider_resource_id="test_model",
|
||||
provider_id="test-provider",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
model2 = ModelWithOwner(
|
||||
identifier="test_model", # Same identifier
|
||||
provider_resource_id="test_model",
|
||||
provider_id="test-provider",
|
||||
model_type=ModelType.embedding, # Different type
|
||||
)
|
||||
|
||||
# First registration should succeed and populate cache
|
||||
result1 = await cached_disk_dist_registry.register(model1)
|
||||
assert result1 is True
|
||||
|
||||
# Verify in cache
|
||||
cached_model = cached_disk_dist_registry.get_cached("model", "test_model")
|
||||
assert cached_model is not None
|
||||
assert cached_model.model_type == ModelType.llm
|
||||
|
||||
# Second registration with different data should fail
|
||||
with pytest.raises(ValueError, match="Object of type 'model' and identifier 'test_model' already exists"):
|
||||
await cached_disk_dist_registry.register(model2)
|
||||
|
||||
# Cache should still contain original model
|
||||
cached_model_after = cached_disk_dist_registry.get_cached("model", "test_model")
|
||||
assert cached_model_after is not None
|
||||
assert cached_model_after.model_type == ModelType.llm
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ def mock_impls():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def scope_middleware_with_mocks(mock_auth_endpoint):
|
||||
def middleware_with_mocks(mock_auth_endpoint):
|
||||
"""Create AuthenticationMiddleware with mocked route implementations"""
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
|
|
@ -137,18 +137,20 @@ def scope_middleware_with_mocks(mock_auth_endpoint):
|
|||
# Mock the route_impls to simulate finding routes with required scopes
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read")
|
||||
|
||||
public_webmethod = WebMethod(route="/test/public", method="GET")
|
||||
routes = {
|
||||
("POST", "/test/scoped"): WebMethod(route="/test/scoped", method="POST", required_scope="test.read"),
|
||||
("GET", "/test/public"): WebMethod(route="/test/public", method="GET"),
|
||||
("GET", "/health"): WebMethod(route="/health", method="GET", require_authentication=False),
|
||||
("GET", "/version"): WebMethod(route="/version", method="GET", require_authentication=False),
|
||||
("GET", "/models/list"): WebMethod(route="/models/list", method="GET", require_authentication=True),
|
||||
}
|
||||
|
||||
# Mock the route finding logic
|
||||
def mock_find_matching_route(method, path, route_impls):
|
||||
if method == "POST" and path == "/test/scoped":
|
||||
return None, {}, "/test/scoped", scoped_webmethod
|
||||
elif method == "GET" and path == "/test/public":
|
||||
return None, {}, "/test/public", public_webmethod
|
||||
else:
|
||||
raise ValueError("No matching route")
|
||||
webmethod = routes.get((method, path))
|
||||
if webmethod:
|
||||
return None, {}, path, webmethod
|
||||
raise ValueError("No matching route")
|
||||
|
||||
import llama_stack.core.server.auth
|
||||
|
||||
|
|
@ -659,9 +661,9 @@ def test_valid_introspection_with_custom_mapping_authentication(
|
|||
|
||||
# Scope-based authorization tests
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope)
|
||||
async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key):
|
||||
async def test_scope_authorization_success(middleware_with_mocks, valid_api_key):
|
||||
"""Test that user with required scope can access protected endpoint"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
|
@ -680,9 +682,9 @@ async def test_scope_authorization_success(scope_middleware_with_mocks, valid_ap
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||
async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key):
|
||||
async def test_scope_authorization_denied(middleware_with_mocks, valid_api_key):
|
||||
"""Test that user without required scope gets 403 access denied"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
|
@ -710,9 +712,9 @@ async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||
async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key):
|
||||
async def test_public_endpoint_no_scope_required(middleware_with_mocks, valid_api_key):
|
||||
"""Test that public endpoints work without specific scopes"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
|
@ -730,9 +732,9 @@ async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, va
|
|||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks):
|
||||
async def test_scope_authorization_no_auth_disabled(middleware_with_mocks):
|
||||
"""Test that when auth is disabled (no user), scope checks are bypassed"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
|
@ -907,3 +909,41 @@ def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mo
|
|||
request_body = call_args[1]["json"]
|
||||
assert request_body["apiVersion"] == "authentication.k8s.io/v1"
|
||||
assert request_body["kind"] == "SelfSubjectReview"
|
||||
|
||||
|
||||
async def test_unauthenticated_endpoint_access_health(middleware_with_mocks):
|
||||
"""Test that /health endpoints can be accessed without authentication"""
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
|
||||
# Test request to /health without auth header (level prefix v1 is added by router)
|
||||
scope = {"type": "http", "path": "/health", "headers": [], "method": "GET"}
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
# Should allow the request to proceed without authentication
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Verify that the request was passed to the app
|
||||
mock_app.assert_called_once_with(scope, receive, send)
|
||||
|
||||
# Verify that no error response was sent
|
||||
assert not any(call[0][0].get("status") == 401 for call in send.call_args_list)
|
||||
|
||||
|
||||
async def test_unauthenticated_endpoint_denied_for_other_paths(middleware_with_mocks):
|
||||
"""Test that endpoints other than /health and /version require authentication"""
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
|
||||
# Test request to /models/list without auth header
|
||||
scope = {"type": "http", "path": "/models/list", "headers": [], "method": "GET"}
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
# Should return 401 error
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Verify that the app was NOT called
|
||||
mock_app.assert_not_called()
|
||||
|
||||
# Verify that a 401 error response was sent
|
||||
assert any(call[0][0].get("status") == 401 for call in send.call_args_list)
|
||||
|
|
|
|||
30
tests/unit/utils/kvstore/test_sqlite_memory.py
Normal file
30
tests/unit/utils/kvstore/test_sqlite_memory.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
|
||||
|
||||
|
||||
async def test_memory_kvstore_persistence_behavior():
|
||||
"""Test that :memory: database doesn't persist across instances."""
|
||||
config = SqliteKVStoreConfig(db_path=":memory:")
|
||||
|
||||
# First instance
|
||||
store1 = SqliteKVStoreImpl(config)
|
||||
await store1.initialize()
|
||||
await store1.set("persist_test", "should_not_persist")
|
||||
await store1.shutdown()
|
||||
|
||||
# Second instance with same config
|
||||
store2 = SqliteKVStoreImpl(config)
|
||||
await store2.initialize()
|
||||
|
||||
# Data should not be present
|
||||
result = await store2.get("persist_test")
|
||||
assert result is None
|
||||
|
||||
await store2.shutdown()
|
||||
Loading…
Add table
Add a link
Reference in a new issue