diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index ec5918268..fa96688c0 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -95,6 +95,7 @@ def pytest_addoption(parser): parser.addoption( "--embedding-dimension", type=int, + default=384, help="Output dimensionality of the embedding model to use for testing. Default: 384", ) parser.addoption( diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py index 46f4f8768..2affe2a2d 100644 --- a/tests/integration/tool_runtime/test_rag_tool.py +++ b/tests/integration/tool_runtime/test_rag_tool.py @@ -63,12 +63,14 @@ def assert_valid_text_response(response): assert all(isinstance(chunk.text, str) for chunk in response.content) -def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id): +def test_vector_db_insert_inline_and_query( + client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension +): vector_db_id = "test_vector_db" client_with_empty_registry.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, - embedding_dimension=384, + embedding_dimension=embedding_dimension, ) client_with_empty_registry.tool_runtime.rag_tool.insert( @@ -116,7 +118,9 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do assert all(score >= 0.01 for score in response4.scores) -def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_documents, embedding_model_id): +def test_vector_db_insert_from_url_and_query( + client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension +): providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] assert len(providers) > 0 @@ -125,7 +129,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_ client_with_empty_registry.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, - embedding_dimension=384, + embedding_dimension=embedding_dimension, ) # list to check memory bank is successfully registered @@ -170,7 +174,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_ assert any("llama2" in chunk.content.lower() for chunk in response2.chunks) -def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id): +def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension): providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] assert len(providers) > 0 @@ -179,7 +183,7 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i client_with_empty_registry.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, - embedding_dimension=384, + embedding_dimension=embedding_dimension, ) available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index f550cf666..95fcb8db5 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -46,13 +46,13 @@ def client_with_empty_registry(client_with_models): clear_registry() -def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id): +def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension): # Register a memory bank first vector_db_id = "test_vector_db" client_with_empty_registry.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, - embedding_dimension=384, + embedding_dimension=embedding_dimension, ) # Retrieve the memory bank and validate its properties @@ -63,12 +63,12 @@ def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id): assert response.provider_resource_id == vector_db_id -def test_vector_db_register(client_with_empty_registry, embedding_model_id): +def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension): vector_db_id = "test_vector_db" client_with_empty_registry.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, - embedding_dimension=384, + embedding_dimension=embedding_dimension, ) vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] @@ -90,12 +90,12 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id): ("How does machine learning improve over time?", "doc2"), ], ) -def test_insert_chunks(client_with_empty_registry, embedding_model_id, sample_chunks, test_case): +def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case): vector_db_id = "test_vector_db" client_with_empty_registry.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, - embedding_dimension=384, + embedding_dimension=embedding_dimension, ) client_with_empty_registry.vector_io.insert( @@ -122,19 +122,19 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, sample_ch assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}" -def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id): +def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension): vector_db_id = "test_precomputed_embeddings_db" client_with_empty_registry.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, - embedding_dimension=384, + embedding_dimension=embedding_dimension, ) chunks_with_embeddings = [ Chunk( content="This is a test chunk with precomputed embedding.", metadata={"document_id": "doc1", "source": "precomputed"}, - embedding=[0.1] * 384, + embedding=[0.1] * int(embedding_dimension), ), ] @@ -156,19 +156,21 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e assert response.chunks[0].metadata["source"] == "precomputed" -def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(client_with_empty_registry, embedding_model_id): +def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( + client_with_empty_registry, embedding_model_id, embedding_dimension +): vector_db_id = "test_precomputed_embeddings_db" client_with_empty_registry.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, - embedding_dimension=384, + embedding_dimension=embedding_dimension, ) chunks_with_embeddings = [ Chunk( content="duplicate", metadata={"document_id": "doc1", "source": "precomputed"}, - embedding=[0.1] * 384, + embedding=[0.1] * int(embedding_dimension), ), ]