From e2fe39aee108c1796ddc1be1e44acd25c800082e Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Fri, 5 Sep 2025 07:40:34 -0600 Subject: [PATCH 1/4] feat!: Migrate Vector DB IDs to Vector Store IDs (breaking change) (#3253) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This change migrates the VectorDB id generation to Vector Stores. This is a breaking change for **_some users_** that may have application code using the `vector_db_id` parameter in the request of the VectorDB protocol instead of the `VectorDB.identifier` in the response. By default we will now create a Vector Store every time we register a VectorDB. The caveat with this approach is that this maps the `vector_db_id` → `vector_store.name`. This is a reasonable tradeoff to transition users towards OpenAI Vector Stores. As an added benefit, registering VectorDBs will result in them appearing in the VectorStores admin UI. ### Why? This PR makes the `POST` API call to `/v1/vector-dbs` swap the `vector_db_id` parameter in the **request body** into the VectorStore's name field and sets the `vector_db_id` to the generated vector store id (e.g., `vs_038247dd-4bbb-4dbb-a6be-d5ecfd46cfdb`). That means that users would have to do something like follows in their application code: ```python res = client.vector_dbs.register( vector_db_id='my-vector-db-id', embedding_model='ollama/all-minilm:l6-v2', embedding_dimension=384, ) vector_db_id = res.identifier ``` And then the rest of their code would behave, including `VectorIO`'s insert protocol using `vector_db_id` in the request. An alternative implementation would be to just delete the `vector_db_id` parameter in `VectorDB` but the end result would still require users having to write `vector_db_id = res.identifier` since `VectorStores.create()` generates the ID for you. So this approach felt the easiest way to migrate users towards VectorStores (subsequent PRs will be added to trigger `files.create()` and `vector_stores.files.create()`). ## Test Plan Unit tests and integration tests have been added. Signed-off-by: Francisco Javier Arceo --- llama_stack/core/routing_tables/vector_dbs.py | 26 +++- tests/integration/vector_io/test_vector_io.py | 75 +++++++---- .../routers/test_routing_tables.py | 30 ++++- .../routing_tables/test_vector_dbs.py | 127 ++++++++++++++++-- 4 files changed, 209 insertions(+), 49 deletions(-) diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index 00f71b4fe..497894064 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): provider_vector_db_id: str | None = None, vector_db_name: str | None = None, ) -> VectorDB: - provider_vector_db_id = provider_vector_db_id or vector_db_id if provider_id is None: if len(self.impls_by_provider_id) > 0: provider_id = list(self.impls_by_provider_id.keys())[0] @@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding) if "embedding_dimension" not in model.metadata: raise ValueError(f"Model {embedding_model} does not have an embedding dimension") + + provider = self.impls_by_provider_id[provider_id] + logger.warning( + "VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly." + ) + vector_store = await provider.openai_create_vector_store( + name=vector_db_name or vector_db_id, + embedding_model=embedding_model, + embedding_dimension=model.metadata["embedding_dimension"], + provider_id=provider_id, + provider_vector_db_id=provider_vector_db_id, + ) + + vector_store_id = vector_store.id + actual_provider_vector_db_id = provider_vector_db_id or vector_store_id + logger.warning( + f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name" + ) + vector_db_data = { - "identifier": vector_db_id, + "identifier": vector_store_id, "type": ResourceType.vector_db.value, "provider_id": provider_id, - "provider_resource_id": provider_vector_db_id, + "provider_resource_id": actual_provider_vector_db_id, "embedding_model": embedding_model, "embedding_dimension": model.metadata["embedding_dimension"], - "vector_db_name": vector_db_name, + "vector_db_name": vector_store.name, } vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data) await self.register_object(vector_db) diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index 07faa0db1..979eff6bb 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -47,34 +47,45 @@ def client_with_empty_registry(client_with_models): def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension): - # Register a memory bank first - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + # Retrieve the memory bank and validate its properties - response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=vector_db_id) + response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=actual_vector_db_id) assert response is not None - assert response.identifier == vector_db_id + assert response.identifier == actual_vector_db_id assert response.embedding_model == embedding_model_id - assert response.provider_resource_id == vector_db_id + assert response.identifier.startswith("vs_") def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension): - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) - vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - assert vector_dbs_after_register == [vector_db_id] + actual_vector_db_id = response.identifier + assert actual_vector_db_id.startswith("vs_") + assert actual_vector_db_id != vector_db_name - client_with_empty_registry.vector_dbs.unregister(vector_db_id=vector_db_id) + vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] + assert vector_dbs_after_register == [actual_vector_db_id] + + vector_stores = client_with_empty_registry.vector_stores.list() + assert len(vector_stores.data) == 1 + vector_store = vector_stores.data[0] + assert vector_store.id == actual_vector_db_id + assert vector_store.name == vector_db_name + + client_with_empty_registry.vector_dbs.unregister(vector_db_id=actual_vector_db_id) vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] assert len(vector_dbs) == 0 @@ -91,20 +102,22 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe ], ) def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case): - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + client_with_empty_registry.vector_io.insert( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunks=sample_chunks, ) response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="What is the capital of France?", ) assert response is not None @@ -113,7 +126,7 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding query, expected_doc_id = test_case response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query=query, ) assert response is not None @@ -128,13 +141,15 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e "remote::qdrant": {"score_threshold": -1.0}, "inline::qdrant": {"score_threshold": -1.0}, } - vector_db_id = "test_precomputed_embeddings_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_precomputed_embeddings_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + chunks_with_embeddings = [ Chunk( content="This is a test chunk with precomputed embedding.", @@ -144,13 +159,13 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e ] client_with_empty_registry.vector_io.insert( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunks=chunks_with_embeddings, ) provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="precomputed embedding test", params=vector_io_provider_params_dict.get(provider, None), ) @@ -173,13 +188,15 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( "remote::qdrant": {"score_threshold": 0.0}, "inline::qdrant": {"score_threshold": 0.0}, } - vector_db_id = "test_precomputed_embeddings_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_precomputed_embeddings_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + chunks_with_embeddings = [ Chunk( content="duplicate", @@ -189,13 +206,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( ] client_with_empty_registry.vector_io.insert( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunks=chunks_with_embeddings, ) provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="duplicate", params=vector_io_provider_params_dict.get(provider, None), ) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 2652f5c8d..1ceee81c6 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -146,6 +146,20 @@ class VectorDBImpl(Impl): async def unregister_vector_db(self, vector_db_id: str): return vector_db_id + async def openai_create_vector_store(self, **kwargs): + import time + import uuid + + from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject + + vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}" + return VectorStoreObject( + id=vector_store_id, + name=kwargs.get("name", vector_store_id), + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + async def test_models_routing_table(cached_disk_dist_registry): table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) @@ -247,17 +261,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry): ) # Register multiple vector databases and verify listing - await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model") - await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model") + vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model") + vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model") vector_dbs = await table.list_vector_dbs() assert len(vector_dbs.data) == 2 vector_db_ids = {v.identifier for v in vector_dbs.data} - assert "test-vectordb" in vector_db_ids - assert "test-vectordb-2" in vector_db_ids + assert vdb1.identifier in vector_db_ids + assert vdb2.identifier in vector_db_ids - await table.unregister_vector_db(vector_db_id="test-vectordb") - await table.unregister_vector_db(vector_db_id="test-vectordb-2") + # Verify they have UUID-based identifiers + assert vdb1.identifier.startswith("vs_") + assert vdb2.identifier.startswith("vs_") + + await table.unregister_vector_db(vector_db_id=vdb1.identifier) + await table.unregister_vector_db(vector_db_id=vdb2.identifier) vector_dbs = await table.list_vector_dbs() assert len(vector_dbs.data) == 0 diff --git a/tests/unit/distribution/routing_tables/test_vector_dbs.py b/tests/unit/distribution/routing_tables/test_vector_dbs.py index 789eda433..3444f64c2 100644 --- a/tests/unit/distribution/routing_tables/test_vector_dbs.py +++ b/tests/unit/distribution/routing_tables/test_vector_dbs.py @@ -7,6 +7,7 @@ # Unit tests for the routing tables vector_dbs import time +import uuid from unittest.mock import AsyncMock import pytest @@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI class VectorDBImpl(Impl): def __init__(self): super().__init__(Api.vector_io) + self.vector_stores = {} async def register_vector_db(self, vector_db: VectorDB): return vector_db @@ -114,8 +116,35 @@ class VectorDBImpl(Impl): async def openai_delete_vector_store_file(self, vector_store_id, file_id): return VectorStoreFileDeleteResponse(id=file_id, deleted=True) + async def openai_create_vector_store( + self, + name=None, + embedding_model=None, + embedding_dimension=None, + provider_id=None, + provider_vector_db_id=None, + **kwargs, + ): + vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}" + vector_store = VectorStoreObject( + id=vector_store_id, + name=name or vector_store_id, + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + self.vector_stores[vector_store_id] = vector_store + return vector_store + + async def openai_list_vector_stores(self, **kwargs): + from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse + + return VectorStoreListResponse( + data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None + ) + async def test_vectordbs_routing_table(cached_disk_dist_registry): + n = 10 table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) await table.initialize() @@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry): ) # Register multiple vector databases and verify listing - await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") - await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") + vdb_dict = {} + for i in range(n): + vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model") + vector_dbs = await table.list_vector_dbs() - assert len(vector_dbs.data) == 2 + assert len(vector_dbs.data) == len(vdb_dict) vector_db_ids = {v.identifier for v in vector_dbs.data} - assert "test-vectordb" in vector_db_ids - assert "test-vectordb-2" in vector_db_ids - - await table.unregister_vector_db(vector_db_id="test-vectordb") - await table.unregister_vector_db(vector_db_id="test-vectordb-2") + for k in vdb_dict: + assert vdb_dict[k].identifier in vector_db_ids + for k in vdb_dict: + await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier) vector_dbs = await table.list_vector_dbs() assert len(vector_dbs.data) == 0 +async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry): + n = 10 + impl = VectorDBImpl() + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + vdb_dict = {} + for i in range(n): + vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model") + + vector_dbs = await table.list_vector_dbs() + vector_db_ids = {v.identifier for v in vector_dbs.data} + + vector_stores = await impl.openai_list_vector_stores() + vector_store_ids = {v.id for v in vector_stores.data} + + assert vector_db_ids == vector_store_ids, ( + f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}" + ) + + for vector_store in vector_stores.data: + vector_db = await table.get_vector_db(vector_store.id) + assert vector_store.name == vector_db.vector_db_name, ( + f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}" + ) + + for vector_db_id in vector_db_ids: + await table.unregister_vector_db(vector_db_id) + + assert len((await table.list_vector_dbs()).data) == 0 + + +async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry): + impl = VectorDBImpl() + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + user_provided_id = "my-custom-vector-db" + await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model") + + vector_stores = await impl.openai_list_vector_stores() + assert len(vector_stores.data) == 1 + + vector_store = vector_stores.data[0] + + assert vector_store.name == user_provided_id + + assert vector_store.id.startswith("vs_") + assert vector_store.id != user_provided_id + + vector_dbs = await table.list_vector_dbs() + assert len(vector_dbs.data) == 1 + assert vector_dbs.data[0].identifier == vector_store.id + + await table.unregister_vector_db(vector_store.id) + + async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry): impl = VectorDBImpl() impl.openai_retrieve_vector_store = AsyncMock(return_value="OK") @@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registr authorized_user = User(principal="alice", attributes={"roles": [authorized_team]}) with request_provider_data_context({}, authorized_user): - _ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") + registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") + authorized_table = registered_vdb.identifier # Use the actual generated ID # Authorized reader with request_provider_data_context({}, authorized_user): @@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_regis ) with request_provider_data_context({}, admin_user): - await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") + registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") + vector_db_id = registered_vdb.identifier # Use the actual generated ID read_methods = [ (table.openai_retrieve_vector_store, (vector_db_id,), {}), From df1526991f6d5cfca05d3c4f1077b67f4832d93e Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 5 Sep 2025 14:59:57 -0400 Subject: [PATCH 2/4] feat(batches, completions): add /v1/completions support to /v1/batches (#3309) # What does this PR do? add support for /v1/completions to the /v1/batches api ## Test Plan ci --- .../inline/batches/reference/batches.py | 69 +++++++++++++------ tests/integration/batches/test_batches.py | 55 +++++++++++++++ .../recordings/responses/41e27b9b5d09.json | 42 +++++++++++ .../unit/providers/batches/test_reference.py | 65 +++++++++++++++-- 4 files changed, 205 insertions(+), 26 deletions(-) create mode 100644 tests/integration/recordings/responses/41e27b9b5d09.json diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py index 26f0ad15a..e049518a4 100644 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches): # TODO: set expiration time for garbage collection - if endpoint not in ["/v1/chat/completions"]: + if endpoint not in ["/v1/chat/completions", "/v1/completions"]: raise ValueError( - f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", + f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint", ) if completion_window != "24h": @@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches): ) valid = False - for param, expected_type, type_string in [ - ("model", str, "a string"), - # messages is specific to /v1/chat/completions - # we could skip validating messages here and let inference fail. however, - # that would be a very expensive way to find out messages is wrong. - ("messages", list, "an array"), # TODO: allow messages to be a string? - ]: + if batch.endpoint == "/v1/chat/completions": + required_params = [ + ("model", str, "a string"), + # messages is specific to /v1/chat/completions + # we could skip validating messages here and let inference fail. however, + # that would be a very expensive way to find out messages is wrong. + ("messages", list, "an array"), # TODO: allow messages to be a string? + ] + else: # /v1/completions + required_params = [ + ("model", str, "a string"), + ("prompt", str, "a string"), # TODO: allow prompt to be a list of strings?? + ] + + for param, expected_type, type_string in required_params: if param not in body: errors.append( BatchError( @@ -591,20 +599,37 @@ class ReferenceBatchesImpl(Batches): try: # TODO(SECURITY): review body for security issues - request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] - chat_response = await self.inference_api.openai_chat_completion(**request.body) + if request.url == "/v1/chat/completions": + request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] + chat_response = await self.inference_api.openai_chat_completion(**request.body) - # this is for mypy, we don't allow streaming so we'll get the right type - assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" - return { - "id": request_id, - "custom_id": request.custom_id, - "response": { - "status_code": 200, - "request_id": request_id, # TODO: should this be different? - "body": chat_response.model_dump_json(), - }, - } + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, # TODO: should this be different? + "body": chat_response.model_dump_json(), + }, + } + else: # /v1/completions + completion_response = await self.inference_api.openai_completion(**request.body) + + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(completion_response, "model_dump_json"), ( + "Completion response must have model_dump_json method" + ) + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, + "body": completion_response.model_dump_json(), + }, + } except Exception as e: logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") return { diff --git a/tests/integration/batches/test_batches.py b/tests/integration/batches/test_batches.py index 59811b7a4..d55a68bd3 100644 --- a/tests/integration/batches/test_batches.py +++ b/tests/integration/batches/test_batches.py @@ -268,3 +268,58 @@ class TestBatchesIntegration: deleted_error_file = openai_client.files.delete(final_batch.error_file_id) assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully" + + def test_batch_e2e_completions(self, openai_client, batch_helper, text_model_id): + """Run an end-to-end batch with a single successful text completion request.""" + request_body = {"model": text_model_id, "prompt": "Say completions", "max_tokens": 20} + + batch_requests = [ + { + "custom_id": "success-1", + "method": "POST", + "url": "/v1/completions", + "body": request_body, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/completions", + completion_window="24h", + metadata={"test": "e2e_completions_success"}, + ) + + final_batch = batch_helper.wait_for( + batch.id, + max_wait_time=3 * 60, + expected_statuses={"completed"}, + timeout_action="skip", + ) + + assert final_batch.status == "completed" + assert final_batch.request_counts is not None + assert final_batch.request_counts.total == 1 + assert final_batch.request_counts.completed == 1 + assert final_batch.output_file_id is not None + + output_content = openai_client.files.content(final_batch.output_file_id) + if isinstance(output_content, str): + output_text = output_content + else: + output_text = output_content.content.decode("utf-8") + + output_lines = output_text.strip().split("\n") + assert len(output_lines) == 1 + + result = json.loads(output_lines[0]) + assert result["custom_id"] == "success-1" + assert "response" in result + assert result["response"]["status_code"] == 200 + + deleted_output_file = openai_client.files.delete(final_batch.output_file_id) + assert deleted_output_file.deleted + + if final_batch.error_file_id is not None: + deleted_error_file = openai_client.files.delete(final_batch.error_file_id) + assert deleted_error_file.deleted diff --git a/tests/integration/recordings/responses/41e27b9b5d09.json b/tests/integration/recordings/responses/41e27b9b5d09.json new file mode 100644 index 000000000..45d140843 --- /dev/null +++ b/tests/integration/recordings/responses/41e27b9b5d09.json @@ -0,0 +1,42 @@ +{ + "request": { + "method": "POST", + "url": "http://0.0.0.0:11434/v1/v1/completions", + "headers": {}, + "body": { + "model": "llama3.2:3b-instruct-fp16", + "prompt": "Say completions", + "max_tokens": 20 + }, + "endpoint": "/v1/completions", + "model": "llama3.2:3b-instruct-fp16" + }, + "response": { + "body": { + "__type__": "openai.types.completion.Completion", + "__data__": { + "id": "cmpl-271", + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "text": "You want me to respond with a completion, but you didn't specify what I should complete. Could" + } + ], + "created": 1756846620, + "model": "llama3.2:3b-instruct-fp16", + "object": "text_completion", + "system_fingerprint": "fp_ollama", + "usage": { + "completion_tokens": 20, + "prompt_tokens": 28, + "total_tokens": 48, + "completion_tokens_details": null, + "prompt_tokens_details": null + } + } + }, + "is_streaming": false + } +} diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py index 0ca866f7b..dfef5e040 100644 --- a/tests/unit/providers/batches/test_reference.py +++ b/tests/unit/providers/batches/test_reference.py @@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated: * test_validate_input_url_mismatch (negative) * test_validate_input_multiple_errors_per_request (negative) * test_validate_input_invalid_request_format (negative) - * test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation) + * test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions) + * test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions) * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation) The tests use temporary SQLite databases for isolation and mock external @@ -213,7 +214,6 @@ class TestReferenceBatchesImpl: "endpoint", [ "/v1/embeddings", - "/v1/completions", "/v1/invalid/endpoint", "", ], @@ -499,8 +499,10 @@ class TestReferenceBatchesImpl: ("messages", "body.messages", "invalid_request", "Messages parameter is required"), ], ) - async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message): - """Test _validate_input when file contains request with missing required parameters.""" + async def test_validate_input_missing_parameters_chat_completions( + self, provider, param_name, param_path, error_code, error_message + ): + """Test _validate_input when file contains request with missing required parameters for chat completions.""" provider.files_api.openai_retrieve_file = AsyncMock() mock_response = MagicMock() @@ -541,6 +543,61 @@ class TestReferenceBatchesImpl: assert errors[0].message == error_message assert errors[0].param == param_path + @pytest.mark.parametrize( + "param_name,param_path,error_code,error_message", + [ + ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"), + ("method", "method", "missing_required_parameter", "Missing required parameter: method"), + ("url", "url", "missing_required_parameter", "Missing required parameter: url"), + ("body", "body", "missing_required_parameter", "Missing required parameter: body"), + ("model", "body.model", "invalid_request", "Model parameter is required"), + ("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"), + ], + ) + async def test_validate_input_missing_parameters_completions( + self, provider, param_name, param_path, error_code, error_message + ): + """Test _validate_input when file contains request with missing required parameters for text completions.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + + base_request = { + "custom_id": "req-1", + "method": "POST", + "url": "/v1/completions", + "body": {"model": "test-model", "prompt": "Hello"}, + } + + # Remove the specific parameter being tested + if "." in param_path: + top_level, nested_param = param_path.split(".", 1) + del base_request[top_level][nested_param] + else: + del base_request[param_name] + + mock_response.body = json.dumps(base_request).encode() + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/completions", + input_file_id=f"missing_{param_name}_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == error_code + assert errors[0].line == 1 + assert errors[0].message == error_message + assert errors[0].param == param_path + async def test_validate_input_url_mismatch(self, provider): """Test _validate_input when file contains request with URL that doesn't match batch endpoint.""" provider.files_api.openai_retrieve_file = AsyncMock() From 0c2757a05b504bafef1bc589376712b9ac9a1c52 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 5 Sep 2025 15:00:09 -0400 Subject: [PATCH 3/4] chore(sambanova test): skip with_n tests for sambanova, it is not implemented server-side (#3342) # What does this PR do? skip a test that cannot pass for sambanova see https://docs-legacy.sambanova.ai/sambastudio/latest/open-ai-api.html\#_example_requests_using_openai_client ## Test Plan ci --- .../inference/test_openai_completion.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 62185e470..bb447b3c1 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -58,6 +58,15 @@ def skip_if_model_doesnt_support_suffix(client_with_models, model_id): pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.") +def skip_if_doesnt_support_n(client_with_models, model_id): + provider = provider_from_model(client_with_models, model_id) + if provider.provider_type in ( + "remote::sambanova", + "remote::ollama", + ): + pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.") + + def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id): provider = provider_from_model(client_with_models, model_id) if provider.provider_type in ( @@ -262,10 +271,7 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex ) def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case): skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) - - provider = provider_from_model(client_with_models, text_model_id) - if provider.provider_type == "remote::ollama": - pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.") + skip_if_doesnt_support_n(client_with_models, text_model_id) tc = TestCase(test_case) question = tc["question"] From 47b640370e275dd178c0bb9fcf822467e032120d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 5 Sep 2025 13:58:49 -0700 Subject: [PATCH 4/4] feat(tests): introduce a test "suite" concept to encompass dirs, options (#3339) Our integration tests need to be 'grouped' because each group often needs a specific set of models it works with. We separated vision tests due to this, and we have a separate set of tests which test "Responses" API. This PR makes this system a bit more official so it is very easy to target these groups and apply all testing infrastructure towards all the groups (for example, record-replay) uniformly. There are three suites declared: - base - vision - responses Note that our CI currently runs the "base" and "vision" suites. You can use the `--suite` option when running pytest (or any of the testing scripts or workflows.) For example: ``` OLLAMA_URL=http://localhost:11434 \ pytest -s -v tests/integration/ --stack-config starter --suite vision ``` --- .../actions/run-and-record-tests/action.yml | 30 ++-- .github/actions/setup-ollama/action.yml | 8 +- .../actions/setup-test-environment/action.yml | 8 +- .github/workflows/README.md | 2 +- .github/workflows/integration-tests.yml | 20 +-- .../workflows/record-integration-tests.yml | 32 ++--- scripts/github/schedule-record-workflow.sh | 32 +++-- scripts/integration-tests.sh | 133 +++++++----------- tests/README.md | 2 +- tests/integration/README.md | 21 +++ tests/integration/conftest.py | 75 ++++++++-- .../{non_ci => }/responses/__init__.py | 0 .../responses/fixtures/__init__.py | 0 .../responses/fixtures/fixtures.py | 0 .../fixtures/images/vision_test_1.jpg | Bin .../fixtures/images/vision_test_2.jpg | Bin .../fixtures/images/vision_test_3.jpg | Bin .../fixtures/pdfs/llama_stack_and_models.pdf | Bin .../responses/fixtures/test_cases.py | 0 .../{non_ci => }/responses/helpers.py | 0 .../responses/streaming_assertions.py | 0 .../responses/test_basic_responses.py | 0 .../responses/test_file_search.py | 0 .../responses/test_tool_responses.py | 0 tests/integration/suites.py | 53 +++++++ 25 files changed, 255 insertions(+), 161 deletions(-) rename tests/integration/{non_ci => }/responses/__init__.py (100%) rename tests/integration/{non_ci => }/responses/fixtures/__init__.py (100%) rename tests/integration/{non_ci => }/responses/fixtures/fixtures.py (100%) rename tests/integration/{non_ci => }/responses/fixtures/images/vision_test_1.jpg (100%) rename tests/integration/{non_ci => }/responses/fixtures/images/vision_test_2.jpg (100%) rename tests/integration/{non_ci => }/responses/fixtures/images/vision_test_3.jpg (100%) rename tests/integration/{non_ci => }/responses/fixtures/pdfs/llama_stack_and_models.pdf (100%) rename tests/integration/{non_ci => }/responses/fixtures/test_cases.py (100%) rename tests/integration/{non_ci => }/responses/helpers.py (100%) rename tests/integration/{non_ci => }/responses/streaming_assertions.py (100%) rename tests/integration/{non_ci => }/responses/test_basic_responses.py (100%) rename tests/integration/{non_ci => }/responses/test_file_search.py (100%) rename tests/integration/{non_ci => }/responses/test_tool_responses.py (100%) create mode 100644 tests/integration/suites.py diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml index 60550cfdc..7f028b104 100644 --- a/.github/actions/run-and-record-tests/action.yml +++ b/.github/actions/run-and-record-tests/action.yml @@ -2,13 +2,6 @@ name: 'Run and Record Tests' description: 'Run integration tests and handle recording/artifact upload' inputs: - test-subdirs: - description: 'Comma-separated list of test subdirectories to run' - required: true - test-pattern: - description: 'Regex pattern to pass to pytest -k' - required: false - default: '' stack-config: description: 'Stack configuration to use' required: true @@ -18,10 +11,18 @@ inputs: inference-mode: description: 'Inference mode (record or replay)' required: true - run-vision-tests: - description: 'Whether to run vision tests' + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' required: false - default: 'false' + default: '' + test-subdirs: + description: 'Comma-separated list of test subdirectories to run; overrides test-suite' + required: false + default: '' + test-pattern: + description: 'Regex pattern to pass to pytest -k' + required: false + default: '' runs: using: 'composite' @@ -42,7 +43,7 @@ runs: --test-subdirs '${{ inputs.test-subdirs }}' \ --test-pattern '${{ inputs.test-pattern }}' \ --inference-mode '${{ inputs.inference-mode }}' \ - ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \ + --test-suite '${{ inputs.test-suite }}' \ | tee pytest-${{ inputs.inference-mode }}.log @@ -57,12 +58,7 @@ runs: echo "New recordings detected, committing and pushing" git add tests/integration/recordings/ - if [ "${{ inputs.run-vision-tests }}" == "true" ]; then - git commit -m "Recordings update from CI (vision)" - else - git commit -m "Recordings update from CI" - fi - + git commit -m "Recordings update from CI (test-suite: ${{ inputs.test-suite }})" git fetch origin ${{ github.ref_name }} git rebase origin/${{ github.ref_name }} echo "Rebased successfully" diff --git a/.github/actions/setup-ollama/action.yml b/.github/actions/setup-ollama/action.yml index e57876cb0..dc2f87e8c 100644 --- a/.github/actions/setup-ollama/action.yml +++ b/.github/actions/setup-ollama/action.yml @@ -1,17 +1,17 @@ name: Setup Ollama description: Start Ollama inputs: - run-vision-tests: - description: 'Run vision tests: "true" or "false"' + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' required: false - default: 'false' + default: '' runs: using: "composite" steps: - name: Start Ollama shell: bash run: | - if [ "${{ inputs.run-vision-tests }}" == "true" ]; then + if [ "${{ inputs.test-suite }}" == "vision" ]; then image="ollama-with-vision-model" else image="ollama-with-models" diff --git a/.github/actions/setup-test-environment/action.yml b/.github/actions/setup-test-environment/action.yml index d830e3d13..3be76f009 100644 --- a/.github/actions/setup-test-environment/action.yml +++ b/.github/actions/setup-test-environment/action.yml @@ -12,10 +12,10 @@ inputs: description: 'Provider to setup (ollama or vllm)' required: true default: 'ollama' - run-vision-tests: - description: 'Whether to setup provider for vision tests' + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' required: false - default: 'false' + default: '' inference-mode: description: 'Inference mode (record or replay)' required: true @@ -33,7 +33,7 @@ runs: if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }} uses: ./.github/actions/setup-ollama with: - run-vision-tests: ${{ inputs.run-vision-tests }} + test-suite: ${{ inputs.test-suite }} - name: Setup vllm if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }} diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 8344d12a4..2e0df58b8 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -8,7 +8,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl | Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script | | Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication | | SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore | -| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suite from tests/integration in replay mode | +| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode | | Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers | | Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks | | Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build | diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 57e582b20..bb53eea2f 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -1,6 +1,6 @@ name: Integration Tests (Replay) -run-name: Run the integration test suite from tests/integration in replay mode +run-name: Run the integration test suites from tests/integration in replay mode on: push: @@ -32,14 +32,6 @@ on: description: 'Test against a specific provider' type: string default: 'ollama' - test-subdirs: - description: 'Comma-separated list of test subdirectories to run' - type: string - default: '' - test-pattern: - description: 'Regex pattern to pass to pytest -k' - type: string - default: '' concurrency: # Skip concurrency for pushes to main - each commit should be tested independently @@ -50,7 +42,7 @@ jobs: run-replay-mode-tests: runs-on: ubuntu-latest - name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }} + name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.test-suite) }} strategy: fail-fast: false @@ -61,7 +53,7 @@ jobs: # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12 python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} - run-vision-tests: [true, false] + test-suite: [base, vision] steps: - name: Checkout repository @@ -73,15 +65,13 @@ jobs: python-version: ${{ matrix.python-version }} client-version: ${{ matrix.client-version }} provider: ${{ matrix.provider }} - run-vision-tests: ${{ matrix.run-vision-tests }} + test-suite: ${{ matrix.test-suite }} inference-mode: 'replay' - name: Run tests uses: ./.github/actions/run-and-record-tests with: - test-subdirs: ${{ inputs.test-subdirs }} - test-pattern: ${{ inputs.test-pattern }} stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }} provider: ${{ matrix.provider }} inference-mode: 'replay' - run-vision-tests: ${{ matrix.run-vision-tests }} + test-suite: ${{ matrix.test-suite }} diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml index d4f5586e2..01797a54b 100644 --- a/.github/workflows/record-integration-tests.yml +++ b/.github/workflows/record-integration-tests.yml @@ -10,18 +10,18 @@ run-name: Run the integration test suite from tests/integration on: workflow_dispatch: inputs: - test-subdirs: - description: 'Comma-separated list of test subdirectories to run' - type: string - default: '' test-provider: description: 'Test against a specific provider' type: string default: 'ollama' - run-vision-tests: - description: 'Whether to run vision tests' - type: boolean - default: false + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' + type: string + default: '' + test-subdirs: + description: 'Comma-separated list of test subdirectories to run; overrides test-suite' + type: string + default: '' test-pattern: description: 'Regex pattern to pass to pytest -k' type: string @@ -38,11 +38,11 @@ jobs: - name: Echo workflow inputs run: | echo "::group::Workflow Inputs" - echo "test-subdirs: ${{ inputs.test-subdirs }}" - echo "test-provider: ${{ inputs.test-provider }}" - echo "run-vision-tests: ${{ inputs.run-vision-tests }}" - echo "test-pattern: ${{ inputs.test-pattern }}" echo "branch: ${{ github.ref_name }}" + echo "test-provider: ${{ inputs.test-provider }}" + echo "test-suite: ${{ inputs.test-suite }}" + echo "test-subdirs: ${{ inputs.test-subdirs }}" + echo "test-pattern: ${{ inputs.test-pattern }}" echo "::endgroup::" - name: Checkout repository @@ -56,15 +56,15 @@ jobs: python-version: "3.12" # Use single Python version for recording client-version: "latest" provider: ${{ inputs.test-provider || 'ollama' }} - run-vision-tests: ${{ inputs.run-vision-tests }} + test-suite: ${{ inputs.test-suite }} inference-mode: 'record' - name: Run and record tests uses: ./.github/actions/run-and-record-tests with: - test-pattern: ${{ inputs.test-pattern }} - test-subdirs: ${{ inputs.test-subdirs }} stack-config: 'server:ci-tests' # recording must be done with server since more tests are run provider: ${{ inputs.test-provider || 'ollama' }} inference-mode: 'record' - run-vision-tests: ${{ inputs.run-vision-tests }} + test-suite: ${{ inputs.test-suite }} + test-subdirs: ${{ inputs.test-subdirs }} + test-pattern: ${{ inputs.test-pattern }} diff --git a/scripts/github/schedule-record-workflow.sh b/scripts/github/schedule-record-workflow.sh index e381b60b6..09e055611 100755 --- a/scripts/github/schedule-record-workflow.sh +++ b/scripts/github/schedule-record-workflow.sh @@ -15,7 +15,7 @@ set -euo pipefail BRANCH="" TEST_SUBDIRS="" TEST_PROVIDER="ollama" -RUN_VISION_TESTS=false +TEST_SUITE="base" TEST_PATTERN="" # Help function @@ -27,9 +27,9 @@ Trigger the integration test recording workflow remotely. This way you do not ne OPTIONS: -b, --branch BRANCH Branch to run the workflow on (defaults to current branch) - -s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (REQUIRED) -p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama) - -v, --run-vision-tests Include vision tests in the recording + -t, --test-suite SUITE Test suite to use: base, responses, vision, etc. (default: base) + -s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (overrides suite) -k, --test-pattern PATTERN Regex pattern to pass to pytest -k -h, --help Show this help message @@ -38,7 +38,7 @@ EXAMPLES: $0 --test-subdirs "agents" # Record tests for specific branch with vision tests - $0 -b my-feature-branch --test-subdirs "inference" --run-vision-tests + $0 -b my-feature-branch --test-suite vision # Record multiple test subdirectories with specific provider $0 --test-subdirs "agents,inference" --test-provider vllm @@ -71,9 +71,9 @@ while [[ $# -gt 0 ]]; do TEST_PROVIDER="$2" shift 2 ;; - -v|--run-vision-tests) - RUN_VISION_TESTS=true - shift + -t|--test-suite) + TEST_SUITE="$2" + shift 2 ;; -k|--test-pattern) TEST_PATTERN="$2" @@ -92,11 +92,11 @@ while [[ $# -gt 0 ]]; do done # Validate required parameters -if [[ -z "$TEST_SUBDIRS" ]]; then - echo "Error: --test-subdirs is required" - echo "Please specify which test subdirectories to run, e.g.:" +if [[ -z "$TEST_SUBDIRS" && -z "$TEST_SUITE" ]]; then + echo "Error: --test-subdirs or --test-suite is required" + echo "Please specify which test subdirectories to run or test suite to use, e.g.:" echo " $0 --test-subdirs \"agents,inference\"" - echo " $0 --test-subdirs \"inference\" --run-vision-tests" + echo " $0 --test-suite vision" echo "" exit 1 fi @@ -239,17 +239,19 @@ echo "Triggering integration test recording workflow..." echo "Branch: $BRANCH" echo "Test provider: $TEST_PROVIDER" echo "Test subdirs: $TEST_SUBDIRS" -echo "Run vision tests: $RUN_VISION_TESTS" +echo "Test suite: $TEST_SUITE" echo "Test pattern: ${TEST_PATTERN:-"(none)"}" echo "" # Prepare inputs for gh workflow run -INPUTS="-f test-subdirs='$TEST_SUBDIRS'" +if [[ -n "$TEST_SUBDIRS" ]]; then + INPUTS="-f test-subdirs='$TEST_SUBDIRS'" +fi if [[ -n "$TEST_PROVIDER" ]]; then INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'" fi -if [[ "$RUN_VISION_TESTS" == "true" ]]; then - INPUTS="$INPUTS -f run-vision-tests=true" +if [[ -n "$TEST_SUITE" ]]; then + INPUTS="$INPUTS -f test-suite='$TEST_SUITE'" fi if [[ -n "$TEST_PATTERN" ]]; then INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'" diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh index 104ba5cf3..ab7e37579 100755 --- a/scripts/integration-tests.sh +++ b/scripts/integration-tests.sh @@ -16,7 +16,7 @@ STACK_CONFIG="" PROVIDER="" TEST_SUBDIRS="" TEST_PATTERN="" -RUN_VISION_TESTS="false" +TEST_SUITE="base" INFERENCE_MODE="replay" EXTRA_PARAMS="" @@ -28,12 +28,16 @@ Usage: $0 [OPTIONS] Options: --stack-config STRING Stack configuration to use (required) --provider STRING Provider to use (ollama, vllm, etc.) (required) - --test-subdirs STRING Comma-separated list of test subdirectories to run (default: 'inference') - --run-vision-tests Run vision tests instead of regular tests + --test-suite STRING Comma-separated list of test suites to run (default: 'base') --inference-mode STRING Inference mode: record or replay (default: replay) + --test-subdirs STRING Comma-separated list of test subdirectories to run (overrides suite) --test-pattern STRING Regex pattern to pass to pytest -k --help Show this help message +Suites are defined in tests/integration/suites.py. They are used to narrow the collection of tests and provide default model options. + +You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite. + Examples: # Basic inference tests with ollama $0 --stack-config server:ci-tests --provider ollama @@ -42,7 +46,7 @@ Examples: $0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents' # Vision tests with ollama - $0 --stack-config server:ci-tests --provider ollama --run-vision-tests + $0 --stack-config server:ci-tests --provider ollama --test-suite vision # Record mode for updating test recordings $0 --stack-config server:ci-tests --provider ollama --inference-mode record @@ -64,9 +68,9 @@ while [[ $# -gt 0 ]]; do TEST_SUBDIRS="$2" shift 2 ;; - --run-vision-tests) - RUN_VISION_TESTS="true" - shift + --test-suite) + TEST_SUITE="$2" + shift 2 ;; --inference-mode) INFERENCE_MODE="$2" @@ -92,22 +96,25 @@ done # Validate required parameters if [[ -z "$STACK_CONFIG" ]]; then echo "Error: --stack-config is required" - usage exit 1 fi if [[ -z "$PROVIDER" ]]; then echo "Error: --provider is required" - usage + exit 1 +fi + +if [[ -z "$TEST_SUITE" && -z "$TEST_SUBDIRS" ]]; then + echo "Error: --test-suite or --test-subdirs is required" exit 1 fi echo "=== Llama Stack Integration Test Runner ===" echo "Stack Config: $STACK_CONFIG" echo "Provider: $PROVIDER" -echo "Test Subdirs: $TEST_SUBDIRS" -echo "Vision Tests: $RUN_VISION_TESTS" echo "Inference Mode: $INFERENCE_MODE" +echo "Test Suite: $TEST_SUITE" +echo "Test Subdirs: $TEST_SUBDIRS" echo "Test Pattern: $TEST_PATTERN" echo "" @@ -194,84 +201,46 @@ if [[ -n "$TEST_PATTERN" ]]; then PYTEST_PATTERN="${PYTEST_PATTERN} and $TEST_PATTERN" fi -# Run vision tests if specified -if [[ "$RUN_VISION_TESTS" == "true" ]]; then - echo "Running vision tests..." - set +e - pytest -s -v tests/integration/inference/test_vision_inference.py \ - --stack-config="$STACK_CONFIG" \ - -k "$PYTEST_PATTERN" \ - --vision-model=ollama/llama3.2-vision:11b \ - --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ - --color=yes $EXTRA_PARAMS \ - --capture=tee-sys - exit_code=$? - set -e - - if [ $exit_code -eq 0 ]; then - echo "✅ Vision tests completed successfully" - elif [ $exit_code -eq 5 ]; then - echo "⚠️ No vision tests collected (pattern matched no tests)" - else - echo "❌ Vision tests failed" - exit 1 - fi - exit 0 -fi - -# Run regular tests -if [[ -z "$TEST_SUBDIRS" ]]; then - TEST_SUBDIRS=$(find tests/integration -maxdepth 1 -mindepth 1 -type d | - sed 's|tests/integration/||' | - grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" | - sort) -fi echo "Test subdirs to run: $TEST_SUBDIRS" -# Collect all test files for the specified test types -TEST_FILES="" -for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do - # Skip certain test types for vllm provider - if [[ "$PROVIDER" == "vllm" ]]; then - if [[ "$test_subdir" == "safety" ]] || [[ "$test_subdir" == "post_training" ]] || [[ "$test_subdir" == "tool_runtime" ]]; then - echo "Skipping $test_subdir for vllm provider" - continue +if [[ -n "$TEST_SUBDIRS" ]]; then + # Collect all test files for the specified test types + TEST_FILES="" + for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do + if [[ -d "tests/integration/$test_subdir" ]]; then + # Find all Python test files in this directory + test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py") + if [[ -n "$test_files" ]]; then + TEST_FILES="$TEST_FILES $test_files" + echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files" + fi + else + echo "Warning: Directory tests/integration/$test_subdir does not exist" fi + done + + if [[ -z "$TEST_FILES" ]]; then + echo "No test files found for the specified test types" + exit 1 fi - if [[ "$STACK_CONFIG" != *"server:"* ]] && [[ "$test_subdir" == "batches" ]]; then - echo "Skipping $test_subdir for library client until types are supported" - continue - fi + echo "" + echo "=== Running all collected tests in a single pytest command ===" + echo "Total test files: $(echo $TEST_FILES | wc -w)" - if [[ -d "tests/integration/$test_subdir" ]]; then - # Find all Python test files in this directory - test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py") - if [[ -n "$test_files" ]]; then - TEST_FILES="$TEST_FILES $test_files" - echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files" - fi - else - echo "Warning: Directory tests/integration/$test_subdir does not exist" - fi -done - -if [[ -z "$TEST_FILES" ]]; then - echo "No test files found for the specified test types" - exit 1 + PYTEST_TARGET="$TEST_FILES" + EXTRA_PARAMS="$EXTRA_PARAMS --text-model=$TEXT_MODEL --embedding-model=sentence-transformers/all-MiniLM-L6-v2" +else + PYTEST_TARGET="tests/integration/" + EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE" fi -echo "" -echo "=== Running all collected tests in a single pytest command ===" -echo "Total test files: $(echo $TEST_FILES | wc -w)" - set +e -pytest -s -v $TEST_FILES \ +pytest -s -v $PYTEST_TARGET \ --stack-config="$STACK_CONFIG" \ -k "$PYTEST_PATTERN" \ - --text-model="$TEXT_MODEL" \ - --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ - --color=yes $EXTRA_PARAMS \ + $EXTRA_PARAMS \ + --color=yes \ --capture=tee-sys exit_code=$? set -e @@ -294,7 +263,13 @@ df -h # stop server if [[ "$STACK_CONFIG" == *"server:"* ]]; then echo "Stopping Llama Stack Server..." - kill $(lsof -i :8321 | awk 'NR>1 {print $2}') + pids=$(lsof -i :8321 | awk 'NR>1 {print $2}') + if [[ -n "$pids" ]]; then + echo "Killing Llama Stack Server processes: $pids" + kill -9 $pids + else + echo "No Llama Stack Server processes found ?!" + fi echo "Llama Stack Server stopped" fi diff --git a/tests/README.md b/tests/README.md index 81f025f86..c00829d3e 100644 --- a/tests/README.md +++ b/tests/README.md @@ -77,7 +77,7 @@ You must be careful when re-recording. CI workflows assume a specific setup for ./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference" # Record with vision tests enabled -./scripts/github/schedule-record-workflow.sh --test-subdirs "inference" --run-vision-tests +./scripts/github/schedule-record-workflow.sh --test-suite vision # Record with specific provider ./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm diff --git a/tests/integration/README.md b/tests/integration/README.md index d177cbebf..b05beeb98 100644 --- a/tests/integration/README.md +++ b/tests/integration/README.md @@ -42,6 +42,27 @@ Model parameters can be influenced by the following options: Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped if no model is specified. +### Suites (fast selection + sane defaults) + +- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly). +- Available suites: + - `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model. + - `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`. +- Explicit flags always win. For example, `--suite=responses --text-model=` overrides the suite’s text model. + +Examples: + +```bash +# Fast responses run with defaults +pytest -s -v tests/integration --stack-config=server:starter --suite=responses + +# Fast single-file vision run with defaults +pytest -s -v tests/integration --stack-config=server:starter --suite=vision + +# Combine suites and override a default +pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small +``` + ## Examples ### Testing against a Server diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index fd9a54d04..96260fdb7 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -6,15 +6,17 @@ import inspect import itertools import os -import platform import textwrap import time +from pathlib import Path import pytest from dotenv import load_dotenv from llama_stack.log import get_logger +from .suites import SUITE_DEFINITIONS + logger = get_logger(__name__, category="tests") @@ -61,9 +63,22 @@ def pytest_configure(config): key, value = env_var.split("=", 1) os.environ[key] = value - if platform.system() == "Darwin": # Darwin is the system name for macOS - os.environ["DISABLE_CODE_SANDBOX"] = "1" - logger.info("Setting DISABLE_CODE_SANDBOX=1 for macOS") + suites_raw = config.getoption("--suite") + suites: list[str] = [] + if suites_raw: + suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()] + unknown = [p for p in suites if p not in SUITE_DEFINITIONS] + if unknown: + raise pytest.UsageError( + f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}" + ) + for suite in suites: + suite_def = SUITE_DEFINITIONS.get(suite, {}) + defaults: dict = suite_def.get("defaults", {}) + for dest, value in defaults.items(): + current = getattr(config.option, dest, None) + if not current: + setattr(config.option, dest, value) def pytest_addoption(parser): @@ -105,16 +120,21 @@ def pytest_addoption(parser): default=384, help="Output dimensionality of the embedding model to use for testing. Default: 384", ) - parser.addoption( - "--record-responses", - action="store_true", - help="Record new API responses instead of using cached ones.", - ) parser.addoption( "--report", help="Path where the test report should be written, e.g. --report=/path/to/report.md", ) + available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys())) + suite_help = ( + "Comma-separated integration test suites to narrow collection and prefill defaults. " + "Available: " + f"{available_suites}. " + "Explicit CLI flags (e.g., --text-model) override suite defaults. " + "Examples: --suite=responses or --suite=responses,vision." + ) + parser.addoption("--suite", help=suite_help) + MODEL_SHORT_IDS = { "meta-llama/Llama-3.2-3B-Instruct": "3B", @@ -197,3 +217,40 @@ def pytest_generate_tests(metafunc): pytest_plugins = ["tests.integration.fixtures.common"] + + +def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: + """Skip collecting paths outside the selected suite roots for speed.""" + suites_raw = config.getoption("--suite") + if not suites_raw: + return False + + names = [p.strip() for p in str(suites_raw).split(",") if p.strip()] + roots: list[str] = [] + for name in names: + suite_def = SUITE_DEFINITIONS.get(name) + if suite_def: + roots.extend(suite_def.get("roots", [])) + if not roots: + return False + + p = Path(str(path)).resolve() + + # Only constrain within tests/integration to avoid ignoring unrelated tests + integration_root = (Path(str(config.rootpath)) / "tests" / "integration").resolve() + if not p.is_relative_to(integration_root): + return False + + for r in roots: + rp = (Path(str(config.rootpath)) / r).resolve() + if rp.is_file(): + # Allow the exact file and any ancestor directories so pytest can walk into it. + if p == rp: + return False + if p.is_dir() and rp.is_relative_to(p): + return False + else: + # Allow anything inside an allowed directory + if p.is_relative_to(rp): + return False + return True diff --git a/tests/integration/non_ci/responses/__init__.py b/tests/integration/responses/__init__.py similarity index 100% rename from tests/integration/non_ci/responses/__init__.py rename to tests/integration/responses/__init__.py diff --git a/tests/integration/non_ci/responses/fixtures/__init__.py b/tests/integration/responses/fixtures/__init__.py similarity index 100% rename from tests/integration/non_ci/responses/fixtures/__init__.py rename to tests/integration/responses/fixtures/__init__.py diff --git a/tests/integration/non_ci/responses/fixtures/fixtures.py b/tests/integration/responses/fixtures/fixtures.py similarity index 100% rename from tests/integration/non_ci/responses/fixtures/fixtures.py rename to tests/integration/responses/fixtures/fixtures.py diff --git a/tests/integration/non_ci/responses/fixtures/images/vision_test_1.jpg b/tests/integration/responses/fixtures/images/vision_test_1.jpg similarity index 100% rename from tests/integration/non_ci/responses/fixtures/images/vision_test_1.jpg rename to tests/integration/responses/fixtures/images/vision_test_1.jpg diff --git a/tests/integration/non_ci/responses/fixtures/images/vision_test_2.jpg b/tests/integration/responses/fixtures/images/vision_test_2.jpg similarity index 100% rename from tests/integration/non_ci/responses/fixtures/images/vision_test_2.jpg rename to tests/integration/responses/fixtures/images/vision_test_2.jpg diff --git a/tests/integration/non_ci/responses/fixtures/images/vision_test_3.jpg b/tests/integration/responses/fixtures/images/vision_test_3.jpg similarity index 100% rename from tests/integration/non_ci/responses/fixtures/images/vision_test_3.jpg rename to tests/integration/responses/fixtures/images/vision_test_3.jpg diff --git a/tests/integration/non_ci/responses/fixtures/pdfs/llama_stack_and_models.pdf b/tests/integration/responses/fixtures/pdfs/llama_stack_and_models.pdf similarity index 100% rename from tests/integration/non_ci/responses/fixtures/pdfs/llama_stack_and_models.pdf rename to tests/integration/responses/fixtures/pdfs/llama_stack_and_models.pdf diff --git a/tests/integration/non_ci/responses/fixtures/test_cases.py b/tests/integration/responses/fixtures/test_cases.py similarity index 100% rename from tests/integration/non_ci/responses/fixtures/test_cases.py rename to tests/integration/responses/fixtures/test_cases.py diff --git a/tests/integration/non_ci/responses/helpers.py b/tests/integration/responses/helpers.py similarity index 100% rename from tests/integration/non_ci/responses/helpers.py rename to tests/integration/responses/helpers.py diff --git a/tests/integration/non_ci/responses/streaming_assertions.py b/tests/integration/responses/streaming_assertions.py similarity index 100% rename from tests/integration/non_ci/responses/streaming_assertions.py rename to tests/integration/responses/streaming_assertions.py diff --git a/tests/integration/non_ci/responses/test_basic_responses.py b/tests/integration/responses/test_basic_responses.py similarity index 100% rename from tests/integration/non_ci/responses/test_basic_responses.py rename to tests/integration/responses/test_basic_responses.py diff --git a/tests/integration/non_ci/responses/test_file_search.py b/tests/integration/responses/test_file_search.py similarity index 100% rename from tests/integration/non_ci/responses/test_file_search.py rename to tests/integration/responses/test_file_search.py diff --git a/tests/integration/non_ci/responses/test_tool_responses.py b/tests/integration/responses/test_tool_responses.py similarity index 100% rename from tests/integration/non_ci/responses/test_tool_responses.py rename to tests/integration/responses/test_tool_responses.py diff --git a/tests/integration/suites.py b/tests/integration/suites.py new file mode 100644 index 000000000..602855055 --- /dev/null +++ b/tests/integration/suites.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Central definition of integration test suites. You can use these suites by passing --suite=name to pytest. +# For example: +# +# ```bash +# pytest tests/integration/ --suite=vision +# ``` +# +# Each suite can: +# - restrict collection to specific roots (dirs or files) +# - provide default CLI option values (e.g. text_model, embedding_model, etc.) + +from pathlib import Path + +this_dir = Path(__file__).parent +default_roots = [ + str(p) + for p in this_dir.glob("*") + if p.is_dir() + and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training") +] + +SUITE_DEFINITIONS: dict[str, dict] = { + "base": { + "description": "Base suite that includes most tests but runs them with a text Ollama model", + "roots": default_roots, + "defaults": { + "text_model": "ollama/llama3.2:3b-instruct-fp16", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + }, + }, + "responses": { + "description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model", + "roots": ["tests/integration/responses"], + "defaults": { + "text_model": "openai/gpt-4o", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + }, + }, + "vision": { + "description": "Suite that includes only the vision tests", + "roots": ["tests/integration/inference/test_vision_inference.py"], + "defaults": { + "vision_model": "ollama/llama3.2-vision:11b", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + }, + }, +}