feat(vector_io): add custom collection names support for vector stores (#4135)

This commit is contained in:
r-bit-rry 2025-11-20 19:28:33 +02:00
parent 91f1b352b4
commit 6e6ddd3c69
5 changed files with 140 additions and 12 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import re
import uuid import uuid
from typing import Annotated, Any from typing import Annotated, Any
@ -44,6 +45,17 @@ from llama_stack_api import (
logger = get_logger(name=__name__, category="core::routers") logger = get_logger(name=__name__, category="core::routers")
def validate_collection_name(collection_name: str) -> None:
    """Validate a user-supplied collection name for use as a physical store identifier.

    Only alphanumeric characters, hyphens, and underscores are permitted, which keeps
    the name safe to embed in provider-side table/index identifiers.

    :param collection_name: the custom collection name supplied via ``extra_body``.
    :raises ValueError: if the name is empty or contains disallowed characters.
    """
    if not collection_name:
        raise ValueError("collection_name cannot be empty")
    # fullmatch (rather than match with ^...$) rejects a trailing newline,
    # which "$" would otherwise silently accept.
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", collection_name):
        raise ValueError(
            f"collection_name '{collection_name}' contains invalid characters. "
            "Only alphanumeric characters, hyphens, and underscores are allowed."
        )
class VectorIORouter(VectorIO): class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier""" """Routes to an provider based on the vector db identifier"""
@ -160,13 +172,25 @@ class VectorIORouter(VectorIO):
else: else:
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0] provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
# Extract and validate collection_name if provided
collection_name = extra.get("collection_name")
if collection_name:
validate_collection_name(collection_name)
provider_vector_store_id = collection_name
logger.debug(f"Using custom collection name: {collection_name}")
else:
# Fall back to auto-generated UUID for backward compatibility
provider_vector_store_id = f"vs_{uuid.uuid4()}"
# Always generate a unique vector_store_id for internal routing
vector_store_id = f"vs_{uuid.uuid4()}" vector_store_id = f"vs_{uuid.uuid4()}"
registered_vector_store = await self.routing_table.register_vector_store( registered_vector_store = await self.routing_table.register_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
embedding_model=embedding_model, embedding_model=embedding_model,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
provider_id=provider_id, provider_id=provider_id,
provider_vector_store_id=vector_store_id, provider_vector_store_id=provider_vector_store_id,
vector_store_name=params.name, vector_store_name=params.name,
) )
provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier) provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier)
@ -174,8 +198,14 @@ class VectorIORouter(VectorIO):
# Update model_extra with registered values so provider uses the already-registered vector_store # Update model_extra with registered values so provider uses the already-registered vector_store
if params.model_extra is None: if params.model_extra is None:
params.model_extra = {} params.model_extra = {}
params.model_extra["vector_store_id"] = vector_store_id # Pass canonical UUID to Provider
params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id
params.model_extra["provider_id"] = registered_vector_store.provider_id params.model_extra["provider_id"] = registered_vector_store.provider_id
# Add collection_name to metadata so users can see what was used
if params.metadata is None:
params.metadata = {}
params.metadata["provider_vector_store_id"] = provider_vector_store_id
if embedding_model is not None: if embedding_model is not None:
params.model_extra["embedding_model"] = embedding_model params.model_extra["embedding_model"] = embedding_model
if embedding_dimension is not None: if embedding_dimension is not None:

View file

@ -201,7 +201,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
vector_store = VectorStore.model_validate_json(vector_store_data) vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex( index = VectorStoreWithIndex(
vector_store, vector_store,
await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier), await FaissIndex.create(
vector_store.embedding_dimension,
self.kvstore,
vector_store.provider_resource_id or vector_store.identifier,
),
self.inference_api, self.inference_api,
) )
self.cache[vector_store.identifier] = index self.cache[vector_store.identifier] = index
@ -239,7 +243,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
# Store in cache # Store in cache
self.cache[vector_store.identifier] = VectorStoreWithIndex( self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store=vector_store, vector_store=vector_store,
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier), index=await FaissIndex.create(
vector_store.embedding_dimension,
self.kvstore,
vector_store.provider_resource_id or vector_store.identifier,
),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
@ -272,7 +280,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
vector_store = VectorStore.model_validate_json(vector_store_data) vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex( index = VectorStoreWithIndex(
vector_store=vector_store, vector_store=vector_store,
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier), index=await FaissIndex.create(
vector_store.embedding_dimension,
self.kvstore,
vector_store.provider_resource_id or vector_store.identifier,
),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[vector_store_id] = index self.cache[vector_store_id] = index

View file

@ -401,7 +401,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
for db_json in stored_vector_stores: for db_json in stored_vector_stores:
vector_store = VectorStore.model_validate_json(db_json) vector_store = VectorStore.model_validate_json(db_json)
index = await SQLiteVecIndex.create( index = await SQLiteVecIndex.create(
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier vector_store.embedding_dimension,
self.config.db_path,
vector_store.provider_resource_id or vector_store.identifier,
) )
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api) self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
@ -425,7 +427,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
# Create and cache the index # Create and cache the index
index = await SQLiteVecIndex.create( index = await SQLiteVecIndex.create(
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier vector_store.embedding_dimension,
self.config.db_path,
vector_store.provider_resource_id or vector_store.identifier,
) )
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api) self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
@ -448,7 +452,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
index=SQLiteVecIndex( index=SQLiteVecIndex(
dimension=vector_store.embedding_dimension, dimension=vector_store.embedding_dimension,
db_path=self.config.db_path, db_path=self.config.db_path,
bank_id=vector_store.identifier, bank_id=vector_store.provider_resource_id or vector_store.identifier,
kvstore=self.kvstore, kvstore=self.kvstore,
), ),
inference_api=self.inference_api, inference_api=self.inference_api,

View file

@ -360,7 +360,11 @@ class OpenAIVectorStoreMixin(ABC):
extra_body = params.model_extra or {} extra_body = params.model_extra or {}
metadata = params.metadata or {} metadata = params.metadata or {}
provider_vector_store_id = extra_body.get("provider_vector_store_id") # Get the canonical UUID from router (or generate if called directly without router)
vector_store_id = extra_body.get("vector_store_id") or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
# Get the physical storage name (custom collection name or fallback to UUID)
provider_vector_store_id = extra_body.get("provider_vector_store_id") or vector_store_id
# Use embedding info from metadata if available, otherwise from extra_body # Use embedding info from metadata if available, otherwise from extra_body
if metadata.get("embedding_model"): if metadata.get("embedding_model"):
@ -381,8 +385,6 @@ class OpenAIVectorStoreMixin(ABC):
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config # use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None) provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
# Derive the canonical vector_store_id (allow override, else generate)
vector_store_id = provider_vector_store_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
if embedding_model is None: if embedding_model is None:
raise ValueError("embedding_model is required") raise ValueError("embedding_model is required")
@ -396,11 +398,11 @@ class OpenAIVectorStoreMixin(ABC):
# call to the provider to create any index, etc. # call to the provider to create any index, etc.
vector_store = VectorStore( vector_store = VectorStore(
identifier=vector_store_id, identifier=vector_store_id, # Canonical UUID for routing
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
embedding_model=embedding_model, embedding_model=embedding_model,
provider_id=provider_id, provider_id=provider_id,
provider_resource_id=vector_store_id, provider_resource_id=provider_vector_store_id, # Physical storage name (custom or UUID)
vector_store_name=params.name, vector_store_name=params.name,
) )
await self.register_vector_store(vector_store) await self.register_vector_store(vector_store)

View file

@ -1698,3 +1698,83 @@ def test_openai_vector_store_file_contents_with_extra_query(
assert with_flags_embedding is not None, "Embeddings should be included when include_embeddings=True" assert with_flags_embedding is not None, "Embeddings should be included when include_embeddings=True"
assert len(with_flags_embedding) > 0, "Embedding should be a non-empty list" assert len(with_flags_embedding) > 0, "Embedding should be a non-empty list"
assert without_flags_embedding is None, "Embeddings should not be included when include_embeddings=False" assert without_flags_embedding is None, "Embeddings should not be included when include_embeddings=False"
@vector_provider_wrapper
def test_openai_vector_store_custom_collection_name(
    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
    """A vector store created with a custom collection name should surface that name in its metadata."""
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

    store = compat_client_with_empty_stores.vector_stores.create(
        name="Test Custom Collection",
        extra_body={
            "embedding_model": embedding_model_id,
            "provider_id": vector_io_provider_id,
            "collection_name": "my_custom_collection",
        },
    )

    assert store is not None
    # The routing identifier stays an auto-generated "vs_" id; the custom name
    # is recorded in metadata for user visibility.
    assert store.id.startswith("vs_")
    assert "provider_vector_store_id" in store.metadata
    assert store.metadata["provider_vector_store_id"] == "my_custom_collection"
@vector_provider_wrapper
def test_openai_vector_store_collection_name_validation(
    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
    """Creation must fail for every malformed collection name."""
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

    # Spaces, path separators, punctuation, and the empty string are all disallowed.
    for bad_name in ("with spaces", "with/slashes", "with@special", ""):
        with pytest.raises((BadRequestError, ValueError)):
            compat_client_with_empty_stores.vector_stores.create(
                name="Test Invalid",
                extra_body={
                    "embedding_model": embedding_model_id,
                    "provider_id": vector_io_provider_id,
                    "collection_name": bad_name,
                },
            )
@vector_provider_wrapper
def test_openai_vector_store_collection_name_with_data(
    compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension, vector_io_provider_id
):
    """Insertion and search should work end-to-end against a custom-named collection."""
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

    store = compat_client_with_empty_stores.vector_stores.create(
        name="Data Test Collection",
        extra_body={
            "embedding_model": embedding_model_id,
            "provider_id": vector_io_provider_id,
            "collection_name": "test_data_collection",
        },
    )

    # Write through the llama-stack client, then query through the OpenAI-compat client.
    client_with_models.vector_io.insert(vector_store_id=store.id, chunks=sample_chunks[:2])
    results = compat_client_with_empty_stores.vector_stores.search(
        vector_store_id=store.id,
        query="What is Python?",
        max_num_results=2,
    )

    assert results is not None
    assert len(results.data) > 0
    assert results.data[0].attributes["document_id"] == "doc1"