diff --git a/src/llama_stack/core/routers/vector_io.py b/src/llama_stack/core/routers/vector_io.py index 5256dda44..64039c870 100644 --- a/src/llama_stack/core/routers/vector_io.py +++ b/src/llama_stack/core/routers/vector_io.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import asyncio +import re import uuid from typing import Annotated, Any @@ -44,6 +45,17 @@ from llama_stack_api import ( logger = get_logger(name=__name__, category="core::routers") +def validate_collection_name(collection_name: str) -> None: + if not collection_name: + raise ValueError("collection_name cannot be empty") + + if not re.match(r"^[a-zA-Z0-9_-]+$", collection_name): + raise ValueError( + f"collection_name '{collection_name}' contains invalid characters. " + "Only alphanumeric characters, hyphens, and underscores are allowed." + ) + + class VectorIORouter(VectorIO): """Routes to an provider based on the vector db identifier""" @@ -160,13 +172,25 @@ class VectorIORouter(VectorIO): else: provider_id = list(self.routing_table.impls_by_provider_id.keys())[0] + # Extract and validate collection_name if provided + collection_name = extra.get("collection_name") + if collection_name is not None: + validate_collection_name(collection_name) + provider_vector_store_id = collection_name + logger.debug(f"Using custom collection name: {collection_name}") + else: + # Fall back to auto-generated UUID for backward compatibility + provider_vector_store_id = f"vs_{uuid.uuid4()}" + + # Always generate a unique vector_store_id for internal routing vector_store_id = f"vs_{uuid.uuid4()}" + registered_vector_store = await self.routing_table.register_vector_store( vector_store_id=vector_store_id, embedding_model=embedding_model, embedding_dimension=embedding_dimension, provider_id=provider_id, - provider_vector_store_id=vector_store_id, + provider_vector_store_id=provider_vector_store_id, vector_store_name=params.name, ) provider = await 
self.routing_table.get_provider_impl(registered_vector_store.identifier) @@ -174,8 +198,14 @@ class VectorIORouter(VectorIO): # Update model_extra with registered values so provider uses the already-registered vector_store if params.model_extra is None: params.model_extra = {} + params.model_extra["vector_store_id"] = vector_store_id # Pass canonical UUID to Provider params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id params.model_extra["provider_id"] = registered_vector_store.provider_id + + # Record the resolved provider_vector_store_id (custom collection name or generated UUID) in metadata so users can see what was used + if params.metadata is None: + params.metadata = {} + params.metadata["provider_vector_store_id"] = provider_vector_store_id if embedding_model is not None: params.model_extra["embedding_model"] = embedding_model if embedding_dimension is not None: diff --git a/src/llama_stack/providers/inline/vector_io/faiss/faiss.py b/src/llama_stack/providers/inline/vector_io/faiss/faiss.py index 91a17058b..2b3665fc4 100644 --- a/src/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/src/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -201,7 +201,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco vector_store = VectorStore.model_validate_json(vector_store_data) index = VectorStoreWithIndex( vector_store, - await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier), + await FaissIndex.create( + vector_store.embedding_dimension, + self.kvstore, + vector_store.provider_resource_id or vector_store.identifier, + ), self.inference_api, ) self.cache[vector_store.identifier] = index @@ -239,7 +243,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco # Store in cache self.cache[vector_store.identifier] = VectorStoreWithIndex( vector_store=vector_store, - index=await 
FaissIndex.create( + vector_store.embedding_dimension, + self.kvstore, + vector_store.provider_resource_id or vector_store.identifier, + ), inference_api=self.inference_api, ) @@ -272,7 +280,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco vector_store = VectorStore.model_validate_json(vector_store_data) index = VectorStoreWithIndex( vector_store=vector_store, - index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier), + index=await FaissIndex.create( + vector_store.embedding_dimension, + self.kvstore, + vector_store.provider_resource_id or vector_store.identifier, + ), inference_api=self.inference_api, ) self.cache[vector_store_id] = index diff --git a/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index a384a33dc..73451936f 100644 --- a/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -401,7 +401,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro for db_json in stored_vector_stores: vector_store = VectorStore.model_validate_json(db_json) index = await SQLiteVecIndex.create( - vector_store.embedding_dimension, self.config.db_path, vector_store.identifier + vector_store.embedding_dimension, + self.config.db_path, + vector_store.provider_resource_id or vector_store.identifier, ) self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api) @@ -425,7 +427,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro # Create and cache the index index = await SQLiteVecIndex.create( - vector_store.embedding_dimension, self.config.db_path, vector_store.identifier + vector_store.embedding_dimension, + self.config.db_path, + vector_store.provider_resource_id or vector_store.identifier, ) 
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api) @@ -448,7 +452,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro index=SQLiteVecIndex( dimension=vector_store.embedding_dimension, db_path=self.config.db_path, - bank_id=vector_store.identifier, + bank_id=vector_store.provider_resource_id or vector_store.identifier, kvstore=self.kvstore, ), inference_api=self.inference_api, diff --git a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index bbfd60e25..3848ec01c 100644 --- a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -360,7 +360,11 @@ class OpenAIVectorStoreMixin(ABC): extra_body = params.model_extra or {} metadata = params.metadata or {} - provider_vector_store_id = extra_body.get("provider_vector_store_id") + # Get the canonical UUID from router (or generate if called directly without router) + vector_store_id = extra_body.get("vector_store_id") or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}") + + # Get the physical storage name (custom collection name or fallback to UUID) + provider_vector_store_id = extra_body.get("provider_vector_store_id") or vector_store_id # Use embedding info from metadata if available, otherwise from extra_body if metadata.get("embedding_model"): @@ -381,8 +385,6 @@ class OpenAIVectorStoreMixin(ABC): # use provider_id set by router; fallback to provider's own ID when used directly via --stack-config provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None) - # Derive the canonical vector_store_id (allow override, else generate) - vector_store_id = provider_vector_store_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}") if embedding_model is None: raise ValueError("embedding_model is required") @@ 
-396,11 +398,11 @@ class OpenAIVectorStoreMixin(ABC): # call to the provider to create any index, etc. vector_store = VectorStore( - identifier=vector_store_id, + identifier=vector_store_id, # Canonical UUID for routing embedding_dimension=embedding_dimension, embedding_model=embedding_model, provider_id=provider_id, - provider_resource_id=vector_store_id, + provider_resource_id=provider_vector_store_id, # Physical storage name (custom or UUID) vector_store_name=params.name, ) await self.register_vector_store(vector_store) diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 102f3f00c..7125b246c 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -1698,3 +1698,83 @@ def test_openai_vector_store_file_contents_with_extra_query( assert with_flags_embedding is not None, "Embeddings should be included when include_embeddings=True" assert len(with_flags_embedding) > 0, "Embedding should be a non-empty list" assert without_flags_embedding is None, "Embeddings should not be included when include_embeddings=False" + + +@vector_provider_wrapper +def test_openai_vector_store_custom_collection_name( + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id +): + """Test creating a vector store with a custom collection name.""" + skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) + client = compat_client_with_empty_stores + + # Create vector store with custom collection name + vector_store = client.vector_stores.create( + name="Test Custom Collection", + extra_body={ + "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, + "collection_name": "my_custom_collection", + }, + ) + + assert vector_store is not None + assert vector_store.id.startswith("vs_") + assert "provider_vector_store_id" in 
vector_store.metadata + assert vector_store.metadata["provider_vector_store_id"] == "my_custom_collection" + + +@vector_provider_wrapper +def test_openai_vector_store_collection_name_validation( + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id +): + """Test that invalid collection names are rejected.""" + skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) + client = compat_client_with_empty_stores + + # Test invalid collection names + invalid_names = ["with spaces", "with/slashes", "with@special", ""] + + for invalid_name in invalid_names: + with pytest.raises((BadRequestError, ValueError)): + client.vector_stores.create( + name="Test Invalid", + extra_body={ + "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, + "collection_name": invalid_name, + }, + ) + + +@vector_provider_wrapper +def test_openai_vector_store_collection_name_with_data( + compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension, vector_io_provider_id +): + """Test that custom collection names work with data insertion and search.""" + skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) + compat_client = compat_client_with_empty_stores + llama_client = client_with_models + + # Create vector store with custom collection name + vector_store = compat_client.vector_stores.create( + name="Data Test Collection", + extra_body={ + "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, + "collection_name": "test_data_collection", + }, + ) + + # Insert and search data + llama_client.vector_io.insert(vector_store_id=vector_store.id, chunks=sample_chunks[:2]) + + search_response = compat_client.vector_stores.search( + vector_store_id=vector_store.id, + query="What is Python?", + max_num_results=2, + ) + + assert search_response is not None + assert len(search_response.data) > 0 + assert 
search_response.data[0].attributes["document_id"] == "doc1"