mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat(vector_io): add custom collection names support for vector stores (#4135)
This commit is contained in:
parent
91f1b352b4
commit
6e6ddd3c69
5 changed files with 140 additions and 12 deletions
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import uuid
|
||||
from typing import Annotated, Any
|
||||
|
||||
|
|
@ -44,6 +45,17 @@ from llama_stack_api import (
|
|||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
def validate_collection_name(collection_name: str) -> None:
    """Validate a user-supplied vector store collection name.

    A valid name is a non-empty string consisting solely of ASCII letters,
    digits, hyphens, and underscores.

    Args:
        collection_name: The candidate collection name.

    Raises:
        ValueError: If the name is empty, is not a string, or contains any
            character other than alphanumerics, ``-`` and ``_``.
    """
    if not collection_name:
        raise ValueError("collection_name cannot be empty")

    # Reject non-strings explicitly so callers always see ValueError instead of
    # a TypeError leaking out of the regex engine.
    if not isinstance(collection_name, str):
        raise ValueError(f"collection_name must be a string, got {type(collection_name).__name__}")

    # fullmatch() is required here: with match() and a trailing "$", a name
    # ending in "\n" would be accepted because "$" also matches just before a
    # final newline.
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", collection_name):
        raise ValueError(
            f"collection_name '{collection_name}' contains invalid characters. "
            "Only alphanumeric characters, hyphens, and underscores are allowed."
        )
|
||||
|
||||
|
||||
class VectorIORouter(VectorIO):
|
||||
"""Routes to an provider based on the vector db identifier"""
|
||||
|
||||
|
|
@ -160,13 +172,25 @@ class VectorIORouter(VectorIO):
|
|||
else:
|
||||
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
|
||||
|
||||
# Extract and validate collection_name if provided
|
||||
collection_name = extra.get("collection_name")
|
||||
if collection_name:
|
||||
validate_collection_name(collection_name)
|
||||
provider_vector_store_id = collection_name
|
||||
logger.debug(f"Using custom collection name: {collection_name}")
|
||||
else:
|
||||
# Fall back to auto-generated UUID for backward compatibility
|
||||
provider_vector_store_id = f"vs_{uuid.uuid4()}"
|
||||
|
||||
# Always generate a unique vector_store_id for internal routing
|
||||
vector_store_id = f"vs_{uuid.uuid4()}"
|
||||
|
||||
registered_vector_store = await self.routing_table.register_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=provider_id,
|
||||
provider_vector_store_id=vector_store_id,
|
||||
provider_vector_store_id=provider_vector_store_id,
|
||||
vector_store_name=params.name,
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier)
|
||||
|
|
@ -174,8 +198,14 @@ class VectorIORouter(VectorIO):
|
|||
# Update model_extra with registered values so provider uses the already-registered vector_store
|
||||
if params.model_extra is None:
|
||||
params.model_extra = {}
|
||||
params.model_extra["vector_store_id"] = vector_store_id # Pass canonical UUID to Provider
|
||||
params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id
|
||||
params.model_extra["provider_id"] = registered_vector_store.provider_id
|
||||
|
||||
# Add collection_name to metadata so users can see what was used
|
||||
if params.metadata is None:
|
||||
params.metadata = {}
|
||||
params.metadata["provider_vector_store_id"] = provider_vector_store_id
|
||||
if embedding_model is not None:
|
||||
params.model_extra["embedding_model"] = embedding_model
|
||||
if embedding_dimension is not None:
|
||||
|
|
|
|||
|
|
@ -201,7 +201,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
|
|||
vector_store = VectorStore.model_validate_json(vector_store_data)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store,
|
||||
await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
await FaissIndex.create(
|
||||
vector_store.embedding_dimension,
|
||||
self.kvstore,
|
||||
vector_store.provider_resource_id or vector_store.identifier,
|
||||
),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
|
@ -239,7 +243,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
|
|||
# Store in cache
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
index=await FaissIndex.create(
|
||||
vector_store.embedding_dimension,
|
||||
self.kvstore,
|
||||
vector_store.provider_resource_id or vector_store.identifier,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
|
|
@ -272,7 +280,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
|
|||
vector_store = VectorStore.model_validate_json(vector_store_data)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
index=await FaissIndex.create(
|
||||
vector_store.embedding_dimension,
|
||||
self.kvstore,
|
||||
vector_store.provider_resource_id or vector_store.identifier,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_store_id] = index
|
||||
|
|
|
|||
|
|
@ -401,7 +401,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
|
|||
for db_json in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(db_json)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
|
||||
vector_store.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_store.provider_resource_id or vector_store.identifier,
|
||||
)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
|
||||
|
|
@ -425,7 +427,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
|
|||
|
||||
# Create and cache the index
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
|
||||
vector_store.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_store.provider_resource_id or vector_store.identifier,
|
||||
)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
|
||||
|
|
@ -448,7 +452,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
|
|||
index=SQLiteVecIndex(
|
||||
dimension=vector_store.embedding_dimension,
|
||||
db_path=self.config.db_path,
|
||||
bank_id=vector_store.identifier,
|
||||
bank_id=vector_store.provider_resource_id or vector_store.identifier,
|
||||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
|
|
|
|||
|
|
@ -360,7 +360,11 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
extra_body = params.model_extra or {}
|
||||
metadata = params.metadata or {}
|
||||
|
||||
provider_vector_store_id = extra_body.get("provider_vector_store_id")
|
||||
# Get the canonical UUID from router (or generate if called directly without router)
|
||||
vector_store_id = extra_body.get("vector_store_id") or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
||||
|
||||
# Get the physical storage name (custom collection name or fallback to UUID)
|
||||
provider_vector_store_id = extra_body.get("provider_vector_store_id") or vector_store_id
|
||||
|
||||
# Use embedding info from metadata if available, otherwise from extra_body
|
||||
if metadata.get("embedding_model"):
|
||||
|
|
@ -381,8 +385,6 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
|
||||
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
|
||||
# Derive the canonical vector_store_id (allow override, else generate)
|
||||
vector_store_id = provider_vector_store_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
||||
|
||||
if embedding_model is None:
|
||||
raise ValueError("embedding_model is required")
|
||||
|
|
@ -396,11 +398,11 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
# call to the provider to create any index, etc.
|
||||
vector_store = VectorStore(
|
||||
identifier=vector_store_id,
|
||||
identifier=vector_store_id, # Canonical UUID for routing
|
||||
embedding_dimension=embedding_dimension,
|
||||
embedding_model=embedding_model,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=vector_store_id,
|
||||
provider_resource_id=provider_vector_store_id, # Physical storage name (custom or UUID)
|
||||
vector_store_name=params.name,
|
||||
)
|
||||
await self.register_vector_store(vector_store)
|
||||
|
|
|
|||
|
|
@ -1698,3 +1698,83 @@ def test_openai_vector_store_file_contents_with_extra_query(
|
|||
assert with_flags_embedding is not None, "Embeddings should be included when include_embeddings=True"
|
||||
assert len(with_flags_embedding) > 0, "Embedding should be a non-empty list"
|
||||
assert without_flags_embedding is None, "Embeddings should not be included when include_embeddings=False"
|
||||
|
||||
|
||||
@vector_provider_wrapper
def test_openai_vector_store_custom_collection_name(
    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
    """Check that a requested collection name is reported back in the store metadata."""
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

    # Request a store whose backing collection uses a caller-chosen name.
    creation_options = {
        "embedding_model": embedding_model_id,
        "provider_id": vector_io_provider_id,
        "collection_name": "my_custom_collection",
    }
    created_store = compat_client_with_empty_stores.vector_stores.create(
        name="Test Custom Collection",
        extra_body=creation_options,
    )

    assert created_store is not None
    # The public id keeps the "vs_" prefix even when a custom name is given.
    assert created_store.id.startswith("vs_")
    # The custom name is surfaced under the "provider_vector_store_id" metadata key.
    assert "provider_vector_store_id" in created_store.metadata
    assert created_store.metadata["provider_vector_store_id"] == "my_custom_collection"
|
||||
|
||||
|
||||
@vector_provider_wrapper
def test_openai_vector_store_collection_name_validation(
    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
    """Check that creating a store with a malformed collection name raises an error."""
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
    stores_api = compat_client_with_empty_stores.vector_stores

    # Names containing whitespace, path separators, or symbols, plus the empty string.
    # NOTE(review): the "" case only raises if the server validates falsy names;
    # confirm the router does not silently fall back to an auto-generated id.
    rejected_names = ["with spaces", "with/slashes", "with@special", ""]

    for candidate in rejected_names:
        request_extras = {
            "embedding_model": embedding_model_id,
            "provider_id": vector_io_provider_id,
            "collection_name": candidate,
        }
        # Accept either exception type listed here as a valid rejection.
        with pytest.raises((BadRequestError, ValueError)):
            stores_api.create(name="Test Invalid", extra_body=request_extras)
|
||||
|
||||
|
||||
@vector_provider_wrapper
def test_openai_vector_store_collection_name_with_data(
    compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension, vector_io_provider_id
):
    """Check that a store created with a custom collection name supports insert and search."""
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

    # Create the store through the OpenAI-compatible client, naming its collection explicitly.
    store = compat_client_with_empty_stores.vector_stores.create(
        name="Data Test Collection",
        extra_body={
            "embedding_model": embedding_model_id,
            "provider_id": vector_io_provider_id,
            "collection_name": "test_data_collection",
        },
    )

    # Insert the first two sample chunks via the native vector_io API, addressed by store id.
    first_two_chunks = sample_chunks[:2]
    client_with_models.vector_io.insert(vector_store_id=store.id, chunks=first_two_chunks)

    # Query the same store id back through the OpenAI-compatible search endpoint.
    results = compat_client_with_empty_stores.vector_stores.search(
        vector_store_id=store.id,
        query="What is Python?",
        max_num_results=2,
    )

    assert results is not None
    assert len(results.data) > 0
    top_hit = results.data[0]
    assert top_hit.attributes["document_id"] == "doc1"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue