mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00
4: finished rename I think
This commit is contained in:
parent
3d7b463a80
commit
44f104baae
15 changed files with 273 additions and 272 deletions
|
|
@ -4,4 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .vector_dbs import *
|
from .vector_stores import *
|
||||||
|
|
@ -17,7 +17,7 @@ from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.apis.tools import ToolGroup
|
from llama_stack.apis.tools import ToolGroup
|
||||||
from llama_stack.apis.vector_dbs import VectorStore
|
from llama_stack.apis.vector_stores import VectorStore
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from pydantic import TypeAdapter
|
||||||
|
|
||||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||||
from llama_stack.apis.files import Files, OpenAIFileObject
|
from llama_stack.apis.files import Files, OpenAIFileObject
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_stores import VectorStore
|
||||||
from llama_stack.apis.vector_io import (
|
from llama_stack.apis.vector_io import (
|
||||||
Chunk,
|
Chunk,
|
||||||
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||||
|
|
@ -63,7 +63,7 @@ MAX_CONCURRENT_FILES_PER_BATCH = 3 # Maximum concurrent file processing within
|
||||||
FILE_BATCH_CHUNK_SIZE = 10 # Process files in chunks of this size
|
FILE_BATCH_CHUNK_SIZE = 10 # Process files in chunks of this size
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"
|
||||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
|
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
|
||||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
|
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
|
||||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
|
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
|
||||||
|
|
@ -321,19 +321,19 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||||
"""Register a vector database (provider-specific implementation)."""
|
"""Register a vector database (provider-specific implementation)."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||||
"""Unregister a vector database (provider-specific implementation)."""
|
"""Unregister a vector database (provider-specific implementation)."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def insert_chunks(
|
async def insert_chunks(
|
||||||
self,
|
self,
|
||||||
vector_db_id: str,
|
vector_store_id: str,
|
||||||
chunks: list[Chunk],
|
chunks: list[Chunk],
|
||||||
ttl_seconds: int | None = None,
|
ttl_seconds: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -342,7 +342,7 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def query_chunks(
|
async def query_chunks(
|
||||||
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
|
self, vector_store_id: str, query: Any, params: dict[str, Any] | None = None
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
"""Query chunks from a vector database (provider-specific implementation)."""
|
"""Query chunks from a vector database (provider-specific implementation)."""
|
||||||
pass
|
pass
|
||||||
|
|
@ -358,7 +358,7 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
extra_body = params.model_extra or {}
|
extra_body = params.model_extra or {}
|
||||||
metadata = params.metadata or {}
|
metadata = params.metadata or {}
|
||||||
|
|
||||||
provider_vector_db_id = extra_body.get("provider_vector_db_id")
|
provider_vector_store_id = extra_body.get("provider_vector_store_id")
|
||||||
|
|
||||||
# Use embedding info from metadata if available, otherwise from extra_body
|
# Use embedding info from metadata if available, otherwise from extra_body
|
||||||
if metadata.get("embedding_model"):
|
if metadata.get("embedding_model"):
|
||||||
|
|
@ -389,8 +389,8 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
|
|
||||||
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
|
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
|
||||||
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
|
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
|
||||||
# Derive the canonical vector_db_id (allow override, else generate)
|
# Derive the canonical vector_store_id (allow override, else generate)
|
||||||
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
vector_store_id = provider_vector_store_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
||||||
|
|
||||||
if embedding_model is None:
|
if embedding_model is None:
|
||||||
raise ValueError("embedding_model is required")
|
raise ValueError("embedding_model is required")
|
||||||
|
|
@ -398,19 +398,20 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
if embedding_dimension is None:
|
if embedding_dimension is None:
|
||||||
raise ValueError("Embedding dimension is required")
|
raise ValueError("Embedding dimension is required")
|
||||||
|
|
||||||
# Register the VectorDB backing this vector store
|
# Register the VectorStore backing this vector store
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
raise ValueError("Provider ID is required but was not provided")
|
raise ValueError("Provider ID is required but was not provided")
|
||||||
|
|
||||||
vector_db = VectorDB(
|
# call to the provider to create any index, etc.
|
||||||
identifier=vector_db_id,
|
vector_store = VectorStore(
|
||||||
|
identifier=vector_store_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_resource_id=vector_db_id,
|
provider_resource_id=vector_store_id,
|
||||||
vector_db_name=params.name,
|
vector_store_name=params.name,
|
||||||
)
|
)
|
||||||
await self.register_vector_db(vector_db)
|
await self.register_vector_store(vector_store)
|
||||||
|
|
||||||
# Create OpenAI vector store metadata
|
# Create OpenAI vector store metadata
|
||||||
status = "completed"
|
status = "completed"
|
||||||
|
|
@ -424,7 +425,7 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
total=0,
|
total=0,
|
||||||
)
|
)
|
||||||
store_info: dict[str, Any] = {
|
store_info: dict[str, Any] = {
|
||||||
"id": vector_db_id,
|
"id": vector_store_id,
|
||||||
"object": "vector_store",
|
"object": "vector_store",
|
||||||
"created_at": created_at,
|
"created_at": created_at,
|
||||||
"name": params.name,
|
"name": params.name,
|
||||||
|
|
@ -441,23 +442,23 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
# Add provider information to metadata if provided
|
# Add provider information to metadata if provided
|
||||||
if provider_id:
|
if provider_id:
|
||||||
metadata["provider_id"] = provider_id
|
metadata["provider_id"] = provider_id
|
||||||
if provider_vector_db_id:
|
if provider_vector_store_id:
|
||||||
metadata["provider_vector_db_id"] = provider_vector_db_id
|
metadata["provider_vector_store_id"] = provider_vector_store_id
|
||||||
store_info["metadata"] = metadata
|
store_info["metadata"] = metadata
|
||||||
|
|
||||||
# Save to persistent storage (provider-specific)
|
# Save to persistent storage (provider-specific)
|
||||||
await self._save_openai_vector_store(vector_db_id, store_info)
|
await self._save_openai_vector_store(vector_store_id, store_info)
|
||||||
|
|
||||||
# Store in memory cache
|
# Store in memory cache
|
||||||
self.openai_vector_stores[vector_db_id] = store_info
|
self.openai_vector_stores[vector_store_id] = store_info
|
||||||
|
|
||||||
# Now that our vector store is created, attach any files that were provided
|
# Now that our vector store is created, attach any files that were provided
|
||||||
file_ids = params.file_ids or []
|
file_ids = params.file_ids or []
|
||||||
tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids]
|
tasks = [self.openai_attach_file_to_vector_store(vector_store_id, file_id) for file_id in file_ids]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
# Get the updated store info and return it
|
# Get the updated store info and return it
|
||||||
store_info = self.openai_vector_stores[vector_db_id]
|
store_info = self.openai_vector_stores[vector_store_id]
|
||||||
return VectorStoreObject.model_validate(store_info)
|
return VectorStoreObject.model_validate(store_info)
|
||||||
|
|
||||||
async def openai_list_vector_stores(
|
async def openai_list_vector_stores(
|
||||||
|
|
@ -567,7 +568,7 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
|
|
||||||
# Also delete the underlying vector DB
|
# Also delete the underlying vector DB
|
||||||
try:
|
try:
|
||||||
await self.unregister_vector_db(vector_store_id)
|
await self.unregister_vector_store(vector_store_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
|
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
|
||||||
|
|
||||||
|
|
@ -618,7 +619,7 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
# TODO: Add support for ranking_options.ranker
|
# TODO: Add support for ranking_options.ranker
|
||||||
|
|
||||||
response = await self.query_chunks(
|
response = await self.query_chunks(
|
||||||
vector_db_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
query=search_query,
|
query=search_query,
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
@ -812,7 +813,7 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self.insert_chunks(
|
await self.insert_chunks(
|
||||||
vector_db_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
chunks=chunks,
|
chunks=chunks,
|
||||||
)
|
)
|
||||||
vector_store_file_object.status = "completed"
|
vector_store_file_object.status = "completed"
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ from llama_stack.apis.common.content_types import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
|
from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
|
||||||
from llama_stack.apis.tools import RAGDocument
|
from llama_stack.apis.tools import RAGDocument
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_stores import VectorStore
|
||||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
|
|
@ -187,7 +187,7 @@ def make_overlapped_chunks(
|
||||||
updated_timestamp=int(time.time()),
|
updated_timestamp=int(time.time()),
|
||||||
chunk_window=chunk_window,
|
chunk_window=chunk_window,
|
||||||
chunk_tokenizer=default_tokenizer,
|
chunk_tokenizer=default_tokenizer,
|
||||||
chunk_embedding_model=None, # This will be set in `VectorDBWithIndex.insert_chunks`
|
chunk_embedding_model=None, # This will be set in `VectorStoreWithIndex.insert_chunks`
|
||||||
content_token_count=len(toks),
|
content_token_count=len(toks),
|
||||||
metadata_token_count=len(metadata_tokens),
|
metadata_token_count=len(metadata_tokens),
|
||||||
)
|
)
|
||||||
|
|
@ -255,8 +255,8 @@ class EmbeddingIndex(ABC):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VectorDBWithIndex:
|
class VectorStoreWithIndex:
|
||||||
vector_db: VectorDB
|
vector_store: VectorStore
|
||||||
index: EmbeddingIndex
|
index: EmbeddingIndex
|
||||||
inference_api: Api.inference
|
inference_api: Api.inference
|
||||||
|
|
||||||
|
|
@ -269,14 +269,14 @@ class VectorDBWithIndex:
|
||||||
if c.embedding is None:
|
if c.embedding is None:
|
||||||
chunks_to_embed.append(c)
|
chunks_to_embed.append(c)
|
||||||
if c.chunk_metadata:
|
if c.chunk_metadata:
|
||||||
c.chunk_metadata.chunk_embedding_model = self.vector_db.embedding_model
|
c.chunk_metadata.chunk_embedding_model = self.vector_store.embedding_model
|
||||||
c.chunk_metadata.chunk_embedding_dimension = self.vector_db.embedding_dimension
|
c.chunk_metadata.chunk_embedding_dimension = self.vector_store.embedding_dimension
|
||||||
else:
|
else:
|
||||||
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
|
_validate_embedding(c.embedding, i, self.vector_store.embedding_dimension)
|
||||||
|
|
||||||
if chunks_to_embed:
|
if chunks_to_embed:
|
||||||
params = OpenAIEmbeddingsRequestWithExtraBody(
|
params = OpenAIEmbeddingsRequestWithExtraBody(
|
||||||
model=self.vector_db.embedding_model,
|
model=self.vector_store.embedding_model,
|
||||||
input=[c.content for c in chunks_to_embed],
|
input=[c.content for c in chunks_to_embed],
|
||||||
)
|
)
|
||||||
resp = await self.inference_api.openai_embeddings(params)
|
resp = await self.inference_api.openai_embeddings(params)
|
||||||
|
|
@ -319,7 +319,7 @@ class VectorDBWithIndex:
|
||||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||||
|
|
||||||
params = OpenAIEmbeddingsRequestWithExtraBody(
|
params = OpenAIEmbeddingsRequestWithExtraBody(
|
||||||
model=self.vector_db.embedding_model,
|
model=self.vector_store.embedding_model,
|
||||||
input=[query_string],
|
input=[query_string],
|
||||||
)
|
)
|
||||||
embeddings_response = await self.inference_api.openai_embeddings(params)
|
embeddings_response = await self.inference_api.openai_embeddings(params)
|
||||||
|
|
|
||||||
|
|
@ -367,7 +367,7 @@ def test_openai_vector_store_with_chunks(
|
||||||
|
|
||||||
# Insert chunks using the native LlamaStack API (since OpenAI API doesn't have direct chunk insertion)
|
# Insert chunks using the native LlamaStack API (since OpenAI API doesn't have direct chunk insertion)
|
||||||
llama_client.vector_io.insert(
|
llama_client.vector_io.insert(
|
||||||
vector_db_id=vector_store.id,
|
vector_store_id=vector_store.id,
|
||||||
chunks=sample_chunks,
|
chunks=sample_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -434,7 +434,7 @@ def test_openai_vector_store_search_relevance(
|
||||||
|
|
||||||
# Insert chunks using native API
|
# Insert chunks using native API
|
||||||
llama_client.vector_io.insert(
|
llama_client.vector_io.insert(
|
||||||
vector_db_id=vector_store.id,
|
vector_store_id=vector_store.id,
|
||||||
chunks=sample_chunks,
|
chunks=sample_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -484,7 +484,7 @@ def test_openai_vector_store_search_with_ranking_options(
|
||||||
|
|
||||||
# Insert chunks
|
# Insert chunks
|
||||||
llama_client.vector_io.insert(
|
llama_client.vector_io.insert(
|
||||||
vector_db_id=vector_store.id,
|
vector_store_id=vector_store.id,
|
||||||
chunks=sample_chunks,
|
chunks=sample_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -544,7 +544,7 @@ def test_openai_vector_store_search_with_high_score_filter(
|
||||||
|
|
||||||
# Insert chunks
|
# Insert chunks
|
||||||
llama_client.vector_io.insert(
|
llama_client.vector_io.insert(
|
||||||
vector_db_id=vector_store.id,
|
vector_store_id=vector_store.id,
|
||||||
chunks=sample_chunks,
|
chunks=sample_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -610,7 +610,7 @@ def test_openai_vector_store_search_with_max_num_results(
|
||||||
|
|
||||||
# Insert chunks
|
# Insert chunks
|
||||||
llama_client.vector_io.insert(
|
llama_client.vector_io.insert(
|
||||||
vector_db_id=vector_store.id,
|
vector_store_id=vector_store.id,
|
||||||
chunks=sample_chunks,
|
chunks=sample_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1175,7 +1175,7 @@ def test_openai_vector_store_search_modes(
|
||||||
)
|
)
|
||||||
|
|
||||||
client_with_models.vector_io.insert(
|
client_with_models.vector_io.insert(
|
||||||
vector_db_id=vector_store.id,
|
vector_store_id=vector_store.id,
|
||||||
chunks=sample_chunks,
|
chunks=sample_chunks,
|
||||||
)
|
)
|
||||||
query = "Python programming language"
|
query = "Python programming language"
|
||||||
|
|
|
||||||
|
|
@ -49,46 +49,46 @@ def client_with_empty_registry(client_with_models):
|
||||||
|
|
||||||
|
|
||||||
@vector_provider_wrapper
|
@vector_provider_wrapper
|
||||||
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
|
def test_vector_store_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
|
||||||
vector_db_name = "test_vector_db"
|
vector_store_name = "test_vector_store"
|
||||||
create_response = client_with_empty_registry.vector_stores.create(
|
create_response = client_with_empty_registry.vector_stores.create(
|
||||||
name=vector_db_name,
|
name=vector_store_name,
|
||||||
extra_body={
|
extra_body={
|
||||||
"provider_id": vector_io_provider_id,
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
actual_vector_db_id = create_response.id
|
actual_vector_store_id = create_response.id
|
||||||
|
|
||||||
# Retrieve the vector store and validate its properties
|
# Retrieve the vector store and validate its properties
|
||||||
response = client_with_empty_registry.vector_stores.retrieve(vector_store_id=actual_vector_db_id)
|
response = client_with_empty_registry.vector_stores.retrieve(vector_store_id=actual_vector_store_id)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.id == actual_vector_db_id
|
assert response.id == actual_vector_store_id
|
||||||
assert response.name == vector_db_name
|
assert response.name == vector_store_name
|
||||||
assert response.id.startswith("vs_")
|
assert response.id.startswith("vs_")
|
||||||
|
|
||||||
|
|
||||||
@vector_provider_wrapper
|
@vector_provider_wrapper
|
||||||
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
|
def test_vector_store_register(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
|
||||||
vector_db_name = "test_vector_db"
|
vector_store_name = "test_vector_store"
|
||||||
response = client_with_empty_registry.vector_stores.create(
|
response = client_with_empty_registry.vector_stores.create(
|
||||||
name=vector_db_name,
|
name=vector_store_name,
|
||||||
extra_body={
|
extra_body={
|
||||||
"provider_id": vector_io_provider_id,
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
actual_vector_db_id = response.id
|
actual_vector_store_id = response.id
|
||||||
assert actual_vector_db_id.startswith("vs_")
|
assert actual_vector_store_id.startswith("vs_")
|
||||||
assert actual_vector_db_id != vector_db_name
|
assert actual_vector_store_id != vector_store_name
|
||||||
|
|
||||||
vector_stores = client_with_empty_registry.vector_stores.list()
|
vector_stores = client_with_empty_registry.vector_stores.list()
|
||||||
assert len(vector_stores.data) == 1
|
assert len(vector_stores.data) == 1
|
||||||
vector_store = vector_stores.data[0]
|
vector_store = vector_stores.data[0]
|
||||||
assert vector_store.id == actual_vector_db_id
|
assert vector_store.id == actual_vector_store_id
|
||||||
assert vector_store.name == vector_db_name
|
assert vector_store.name == vector_store_name
|
||||||
|
|
||||||
client_with_empty_registry.vector_stores.delete(vector_store_id=actual_vector_db_id)
|
client_with_empty_registry.vector_stores.delete(vector_store_id=actual_vector_store_id)
|
||||||
|
|
||||||
vector_stores = client_with_empty_registry.vector_stores.list()
|
vector_stores = client_with_empty_registry.vector_stores.list()
|
||||||
assert len(vector_stores.data) == 0
|
assert len(vector_stores.data) == 0
|
||||||
|
|
@ -108,23 +108,23 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
|
||||||
def test_insert_chunks(
|
def test_insert_chunks(
|
||||||
client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case, vector_io_provider_id
|
client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case, vector_io_provider_id
|
||||||
):
|
):
|
||||||
vector_db_name = "test_vector_db"
|
vector_store_name = "test_vector_store"
|
||||||
create_response = client_with_empty_registry.vector_stores.create(
|
create_response = client_with_empty_registry.vector_stores.create(
|
||||||
name=vector_db_name,
|
name=vector_store_name,
|
||||||
extra_body={
|
extra_body={
|
||||||
"provider_id": vector_io_provider_id,
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
actual_vector_db_id = create_response.id
|
actual_vector_store_id = create_response.id
|
||||||
|
|
||||||
client_with_empty_registry.vector_io.insert(
|
client_with_empty_registry.vector_io.insert(
|
||||||
vector_db_id=actual_vector_db_id,
|
vector_store_id=actual_vector_store_id,
|
||||||
chunks=sample_chunks,
|
chunks=sample_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=actual_vector_db_id,
|
vector_store_id=actual_vector_store_id,
|
||||||
query="What is the capital of France?",
|
query="What is the capital of France?",
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
|
|
@ -133,7 +133,7 @@ def test_insert_chunks(
|
||||||
|
|
||||||
query, expected_doc_id = test_case
|
query, expected_doc_id = test_case
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=actual_vector_db_id,
|
vector_store_id=actual_vector_store_id,
|
||||||
query=query,
|
query=query,
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
|
|
@ -151,15 +151,15 @@ def test_insert_chunks_with_precomputed_embeddings(
|
||||||
"inline::qdrant": {"score_threshold": -1.0},
|
"inline::qdrant": {"score_threshold": -1.0},
|
||||||
"remote::qdrant": {"score_threshold": -1.0},
|
"remote::qdrant": {"score_threshold": -1.0},
|
||||||
}
|
}
|
||||||
vector_db_name = "test_precomputed_embeddings_db"
|
vector_store_name = "test_precomputed_embeddings_db"
|
||||||
register_response = client_with_empty_registry.vector_stores.create(
|
register_response = client_with_empty_registry.vector_stores.create(
|
||||||
name=vector_db_name,
|
name=vector_store_name,
|
||||||
extra_body={
|
extra_body={
|
||||||
"provider_id": vector_io_provider_id,
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
actual_vector_db_id = register_response.id
|
actual_vector_store_id = register_response.id
|
||||||
|
|
||||||
chunks_with_embeddings = [
|
chunks_with_embeddings = [
|
||||||
Chunk(
|
Chunk(
|
||||||
|
|
@ -170,13 +170,13 @@ def test_insert_chunks_with_precomputed_embeddings(
|
||||||
]
|
]
|
||||||
|
|
||||||
client_with_empty_registry.vector_io.insert(
|
client_with_empty_registry.vector_io.insert(
|
||||||
vector_db_id=actual_vector_db_id,
|
vector_store_id=actual_vector_store_id,
|
||||||
chunks=chunks_with_embeddings,
|
chunks=chunks_with_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=actual_vector_db_id,
|
vector_store_id=actual_vector_store_id,
|
||||||
query="precomputed embedding test",
|
query="precomputed embedding test",
|
||||||
params=vector_io_provider_params_dict.get(provider, None),
|
params=vector_io_provider_params_dict.get(provider, None),
|
||||||
)
|
)
|
||||||
|
|
@ -200,16 +200,16 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
|
||||||
"remote::qdrant": {"score_threshold": 0.0},
|
"remote::qdrant": {"score_threshold": 0.0},
|
||||||
"inline::qdrant": {"score_threshold": 0.0},
|
"inline::qdrant": {"score_threshold": 0.0},
|
||||||
}
|
}
|
||||||
vector_db_name = "test_precomputed_embeddings_db"
|
vector_store_name = "test_precomputed_embeddings_db"
|
||||||
register_response = client_with_empty_registry.vector_stores.create(
|
register_response = client_with_empty_registry.vector_stores.create(
|
||||||
name=vector_db_name,
|
name=vector_store_name,
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
"provider_id": vector_io_provider_id,
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
actual_vector_db_id = register_response.id
|
actual_vector_store_id = register_response.id
|
||||||
|
|
||||||
chunks_with_embeddings = [
|
chunks_with_embeddings = [
|
||||||
Chunk(
|
Chunk(
|
||||||
|
|
@ -220,13 +220,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
|
||||||
]
|
]
|
||||||
|
|
||||||
client_with_empty_registry.vector_io.insert(
|
client_with_empty_registry.vector_io.insert(
|
||||||
vector_db_id=actual_vector_db_id,
|
vector_store_id=actual_vector_store_id,
|
||||||
chunks=chunks_with_embeddings,
|
chunks=chunks_with_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=actual_vector_db_id,
|
vector_store_id=actual_vector_store_id,
|
||||||
query="duplicate",
|
query="duplicate",
|
||||||
params=vector_io_provider_params_dict.get(provider, None),
|
params=vector_io_provider_params_dict.get(provider, None),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ async def test_single_provider_auto_selection():
|
||||||
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
|
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
mock_routing_table.register_vector_db = AsyncMock(
|
mock_routing_table.register_vector_store = AsyncMock(
|
||||||
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
|
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
|
||||||
)
|
)
|
||||||
mock_routing_table.get_provider_impl = AsyncMock(
|
mock_routing_table.get_provider_impl = AsyncMock(
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_stores import VectorStore
|
||||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||||
|
|
@ -31,7 +31,7 @@ def vector_provider(request):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def vector_db_id() -> str:
|
def vector_store_id() -> str:
|
||||||
return f"test-vector-db-{random.randint(1, 100)}"
|
return f"test-vector-db-{random.randint(1, 100)}"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -149,8 +149,8 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
|
||||||
)
|
)
|
||||||
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
|
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
|
||||||
await adapter.initialize()
|
await adapter.initialize()
|
||||||
await adapter.register_vector_db(
|
await adapter.register_vector_store(
|
||||||
VectorDB(
|
VectorStore(
|
||||||
identifier=collection_id,
|
identifier=collection_id,
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
embedding_model="test_model",
|
embedding_model="test_model",
|
||||||
|
|
@ -186,8 +186,8 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
|
||||||
files_api=None,
|
files_api=None,
|
||||||
)
|
)
|
||||||
await adapter.initialize()
|
await adapter.initialize()
|
||||||
await adapter.register_vector_db(
|
await adapter.register_vector_store(
|
||||||
VectorDB(
|
VectorStore(
|
||||||
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
|
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
embedding_model="test_model",
|
embedding_model="test_model",
|
||||||
|
|
@ -215,7 +215,7 @@ def mock_psycopg2_connection():
|
||||||
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
|
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
|
||||||
connection, cursor = mock_psycopg2_connection
|
connection, cursor = mock_psycopg2_connection
|
||||||
|
|
||||||
vector_db = VectorDB(
|
vector_store = VectorStore(
|
||||||
identifier="test-vector-db",
|
identifier="test-vector-db",
|
||||||
embedding_model="test-model",
|
embedding_model="test-model",
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
|
|
@ -225,7 +225,7 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
|
||||||
|
|
||||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
|
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
|
||||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
|
index = PGVectorIndex(vector_store, embedding_dimension, connection, distance_metric="COSINE")
|
||||||
index._test_chunks = []
|
index._test_chunks = []
|
||||||
original_add_chunks = index.add_chunks
|
original_add_chunks = index.add_chunks
|
||||||
|
|
||||||
|
|
@ -281,30 +281,30 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
|
||||||
await adapter.initialize()
|
await adapter.initialize()
|
||||||
adapter.conn = mock_conn
|
adapter.conn = mock_conn
|
||||||
|
|
||||||
async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
|
async def mock_insert_chunks(vector_store_id, chunks, ttl_seconds=None):
|
||||||
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
|
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
|
||||||
if not index:
|
if not index:
|
||||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
raise ValueError(f"Vector DB {vector_store_id} not found")
|
||||||
await index.insert_chunks(chunks)
|
await index.insert_chunks(chunks)
|
||||||
|
|
||||||
adapter.insert_chunks = mock_insert_chunks
|
adapter.insert_chunks = mock_insert_chunks
|
||||||
|
|
||||||
async def mock_query_chunks(vector_db_id, query, params=None):
|
async def mock_query_chunks(vector_store_id, query, params=None):
|
||||||
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
|
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
|
||||||
if not index:
|
if not index:
|
||||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
raise ValueError(f"Vector DB {vector_store_id} not found")
|
||||||
return await index.query_chunks(query, params)
|
return await index.query_chunks(query, params)
|
||||||
|
|
||||||
adapter.query_chunks = mock_query_chunks
|
adapter.query_chunks = mock_query_chunks
|
||||||
|
|
||||||
test_vector_db = VectorDB(
|
test_vector_store = VectorStore(
|
||||||
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
|
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
|
||||||
provider_id="test_provider",
|
provider_id="test_provider",
|
||||||
embedding_model="test_model",
|
embedding_model="test_model",
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
await adapter.register_vector_db(test_vector_db)
|
await adapter.register_vector_store(test_vector_store)
|
||||||
adapter.test_collection_id = test_vector_db.identifier
|
adapter.test_collection_id = test_vector_store.identifier
|
||||||
|
|
||||||
yield adapter
|
yield adapter
|
||||||
await adapter.shutdown()
|
await adapter.shutdown()
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.files import Files
|
from llama_stack.apis.files import Files
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_stores import VectorStore
|
||||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||||
from llama_stack.providers.datatypes import HealthStatus
|
from llama_stack.providers.datatypes import HealthStatus
|
||||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||||
|
|
@ -43,8 +43,8 @@ def embedding_dimension():
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def vector_db_id():
|
def vector_store_id():
|
||||||
return "test_vector_db"
|
return "test_vector_store"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -61,12 +61,12 @@ def sample_embeddings(embedding_dimension):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
|
def mock_vector_store(vector_store_id, embedding_dimension) -> MagicMock:
|
||||||
mock_vector_db = MagicMock(spec=VectorDB)
|
mock_vector_store = MagicMock(spec=VectorStore)
|
||||||
mock_vector_db.embedding_model = "mock_embedding_model"
|
mock_vector_store.embedding_model = "mock_embedding_model"
|
||||||
mock_vector_db.identifier = vector_db_id
|
mock_vector_store.identifier = vector_store_id
|
||||||
mock_vector_db.embedding_dimension = embedding_dimension
|
mock_vector_store.embedding_dimension = embedding_dimension
|
||||||
return mock_vector_db
|
return mock_vector_store
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_stores import VectorStore
|
||||||
from llama_stack.apis.vector_io import (
|
from llama_stack.apis.vector_io import (
|
||||||
Chunk,
|
Chunk,
|
||||||
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||||
|
|
@ -71,7 +71,7 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
|
||||||
|
|
||||||
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
||||||
key = f"{VECTOR_DBS_PREFIX}db1"
|
key = f"{VECTOR_DBS_PREFIX}db1"
|
||||||
dummy = VectorDB(
|
dummy = VectorStore(
|
||||||
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||||
)
|
)
|
||||||
await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
|
await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
|
||||||
|
|
@ -81,10 +81,10 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
||||||
|
|
||||||
async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
||||||
await vector_io_adapter.initialize()
|
await vector_io_adapter.initialize()
|
||||||
dummy = VectorDB(
|
dummy = VectorStore(
|
||||||
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||||
)
|
)
|
||||||
await vector_io_adapter.register_vector_db(dummy)
|
await vector_io_adapter.register_vector_store(dummy)
|
||||||
await vector_io_adapter.shutdown()
|
await vector_io_adapter.shutdown()
|
||||||
|
|
||||||
await vector_io_adapter.initialize()
|
await vector_io_adapter.initialize()
|
||||||
|
|
@ -92,15 +92,15 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
||||||
await vector_io_adapter.shutdown()
|
await vector_io_adapter.shutdown()
|
||||||
|
|
||||||
|
|
||||||
async def test_register_and_unregister_vector_db(vector_io_adapter):
|
async def test_register_and_unregister_vector_store(vector_io_adapter):
|
||||||
unique_id = f"foo_db_{np.random.randint(1e6)}"
|
unique_id = f"foo_db_{np.random.randint(1e6)}"
|
||||||
dummy = VectorDB(
|
dummy = VectorStore(
|
||||||
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||||
)
|
)
|
||||||
|
|
||||||
await vector_io_adapter.register_vector_db(dummy)
|
await vector_io_adapter.register_vector_store(dummy)
|
||||||
assert dummy.identifier in vector_io_adapter.cache
|
assert dummy.identifier in vector_io_adapter.cache
|
||||||
await vector_io_adapter.unregister_vector_db(dummy.identifier)
|
await vector_io_adapter.unregister_vector_store(dummy.identifier)
|
||||||
assert dummy.identifier not in vector_io_adapter.cache
|
assert dummy.identifier not in vector_io_adapter.cache
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -121,7 +121,7 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
||||||
|
|
||||||
|
|
||||||
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
|
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
|
||||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await vector_io_adapter.insert_chunks("db_not_exist", [])
|
await vector_io_adapter.insert_chunks("db_not_exist", [])
|
||||||
|
|
@ -170,7 +170,7 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
|
||||||
|
|
||||||
|
|
||||||
async def test_query_chunks_missing_db_raises(vector_io_adapter):
|
async def test_query_chunks_missing_db_raises(vector_io_adapter):
|
||||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await vector_io_adapter.query_chunks("db_missing", "q", None)
|
await vector_io_adapter.query_chunks("db_missing", "q", None)
|
||||||
|
|
@ -182,7 +182,7 @@ async def test_save_openai_vector_store(vector_io_adapter):
|
||||||
"id": store_id,
|
"id": store_id,
|
||||||
"name": "Test Store",
|
"name": "Test Store",
|
||||||
"description": "A test OpenAI vector store",
|
"description": "A test OpenAI vector store",
|
||||||
"vector_db_id": "test_db",
|
"vector_store_id": "test_db",
|
||||||
"embedding_model": "test_model",
|
"embedding_model": "test_model",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -198,7 +198,7 @@ async def test_update_openai_vector_store(vector_io_adapter):
|
||||||
"id": store_id,
|
"id": store_id,
|
||||||
"name": "Test Store",
|
"name": "Test Store",
|
||||||
"description": "A test OpenAI vector store",
|
"description": "A test OpenAI vector store",
|
||||||
"vector_db_id": "test_db",
|
"vector_store_id": "test_db",
|
||||||
"embedding_model": "test_model",
|
"embedding_model": "test_model",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -214,7 +214,7 @@ async def test_delete_openai_vector_store(vector_io_adapter):
|
||||||
"id": store_id,
|
"id": store_id,
|
||||||
"name": "Test Store",
|
"name": "Test Store",
|
||||||
"description": "A test OpenAI vector store",
|
"description": "A test OpenAI vector store",
|
||||||
"vector_db_id": "test_db",
|
"vector_store_id": "test_db",
|
||||||
"embedding_model": "test_model",
|
"embedding_model": "test_model",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -229,7 +229,7 @@ async def test_load_openai_vector_stores(vector_io_adapter):
|
||||||
"id": store_id,
|
"id": store_id,
|
||||||
"name": "Test Store",
|
"name": "Test Store",
|
||||||
"description": "A test OpenAI vector store",
|
"description": "A test OpenAI vector store",
|
||||||
"vector_db_id": "test_db",
|
"vector_store_id": "test_db",
|
||||||
"embedding_model": "test_model",
|
"embedding_model": "test_model",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -998,8 +998,8 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
|
||||||
async def test_embedding_config_from_metadata(vector_io_adapter):
|
async def test_embedding_config_from_metadata(vector_io_adapter):
|
||||||
"""Test that embedding configuration is correctly extracted from metadata."""
|
"""Test that embedding configuration is correctly extracted from metadata."""
|
||||||
|
|
||||||
# Mock register_vector_db to avoid actual registration
|
# Mock register_vector_store to avoid actual registration
|
||||||
vector_io_adapter.register_vector_db = AsyncMock()
|
vector_io_adapter.register_vector_store = AsyncMock()
|
||||||
# Set provider_id attribute for the adapter
|
# Set provider_id attribute for the adapter
|
||||||
vector_io_adapter.__provider_id__ = "test_provider"
|
vector_io_adapter.__provider_id__ = "test_provider"
|
||||||
|
|
||||||
|
|
@ -1015,9 +1015,9 @@ async def test_embedding_config_from_metadata(vector_io_adapter):
|
||||||
|
|
||||||
await vector_io_adapter.openai_create_vector_store(params)
|
await vector_io_adapter.openai_create_vector_store(params)
|
||||||
|
|
||||||
# Verify VectorDB was registered with correct embedding config from metadata
|
# Verify VectorStore was registered with correct embedding config from metadata
|
||||||
vector_io_adapter.register_vector_db.assert_called_once()
|
vector_io_adapter.register_vector_store.assert_called_once()
|
||||||
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
|
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||||
assert call_args.embedding_model == "test-embedding-model"
|
assert call_args.embedding_model == "test-embedding-model"
|
||||||
assert call_args.embedding_dimension == 512
|
assert call_args.embedding_dimension == 512
|
||||||
|
|
||||||
|
|
@ -1025,8 +1025,8 @@ async def test_embedding_config_from_metadata(vector_io_adapter):
|
||||||
async def test_embedding_config_from_extra_body(vector_io_adapter):
|
async def test_embedding_config_from_extra_body(vector_io_adapter):
|
||||||
"""Test that embedding configuration is correctly extracted from extra_body when metadata is empty."""
|
"""Test that embedding configuration is correctly extracted from extra_body when metadata is empty."""
|
||||||
|
|
||||||
# Mock register_vector_db to avoid actual registration
|
# Mock register_vector_store to avoid actual registration
|
||||||
vector_io_adapter.register_vector_db = AsyncMock()
|
vector_io_adapter.register_vector_store = AsyncMock()
|
||||||
# Set provider_id attribute for the adapter
|
# Set provider_id attribute for the adapter
|
||||||
vector_io_adapter.__provider_id__ = "test_provider"
|
vector_io_adapter.__provider_id__ = "test_provider"
|
||||||
|
|
||||||
|
|
@ -1042,9 +1042,9 @@ async def test_embedding_config_from_extra_body(vector_io_adapter):
|
||||||
|
|
||||||
await vector_io_adapter.openai_create_vector_store(params)
|
await vector_io_adapter.openai_create_vector_store(params)
|
||||||
|
|
||||||
# Verify VectorDB was registered with correct embedding config from extra_body
|
# Verify VectorStore was registered with correct embedding config from extra_body
|
||||||
vector_io_adapter.register_vector_db.assert_called_once()
|
vector_io_adapter.register_vector_store.assert_called_once()
|
||||||
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
|
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||||
assert call_args.embedding_model == "extra-body-model"
|
assert call_args.embedding_model == "extra-body-model"
|
||||||
assert call_args.embedding_dimension == 1024
|
assert call_args.embedding_dimension == 1024
|
||||||
|
|
||||||
|
|
@ -1052,8 +1052,8 @@ async def test_embedding_config_from_extra_body(vector_io_adapter):
|
||||||
async def test_embedding_config_consistency_check_passes(vector_io_adapter):
|
async def test_embedding_config_consistency_check_passes(vector_io_adapter):
|
||||||
"""Test that consistent embedding config in both metadata and extra_body passes validation."""
|
"""Test that consistent embedding config in both metadata and extra_body passes validation."""
|
||||||
|
|
||||||
# Mock register_vector_db to avoid actual registration
|
# Mock register_vector_store to avoid actual registration
|
||||||
vector_io_adapter.register_vector_db = AsyncMock()
|
vector_io_adapter.register_vector_store = AsyncMock()
|
||||||
# Set provider_id attribute for the adapter
|
# Set provider_id attribute for the adapter
|
||||||
vector_io_adapter.__provider_id__ = "test_provider"
|
vector_io_adapter.__provider_id__ = "test_provider"
|
||||||
|
|
||||||
|
|
@ -1073,8 +1073,8 @@ async def test_embedding_config_consistency_check_passes(vector_io_adapter):
|
||||||
await vector_io_adapter.openai_create_vector_store(params)
|
await vector_io_adapter.openai_create_vector_store(params)
|
||||||
|
|
||||||
# Should not raise any error and use metadata config
|
# Should not raise any error and use metadata config
|
||||||
vector_io_adapter.register_vector_db.assert_called_once()
|
vector_io_adapter.register_vector_store.assert_called_once()
|
||||||
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
|
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||||
assert call_args.embedding_model == "consistent-model"
|
assert call_args.embedding_model == "consistent-model"
|
||||||
assert call_args.embedding_dimension == 768
|
assert call_args.embedding_dimension == 768
|
||||||
|
|
||||||
|
|
@ -1082,8 +1082,8 @@ async def test_embedding_config_consistency_check_passes(vector_io_adapter):
|
||||||
async def test_embedding_config_inconsistency_errors(vector_io_adapter):
|
async def test_embedding_config_inconsistency_errors(vector_io_adapter):
|
||||||
"""Test that inconsistent embedding config between metadata and extra_body raises errors."""
|
"""Test that inconsistent embedding config between metadata and extra_body raises errors."""
|
||||||
|
|
||||||
# Mock register_vector_db to avoid actual registration
|
# Mock register_vector_store to avoid actual registration
|
||||||
vector_io_adapter.register_vector_db = AsyncMock()
|
vector_io_adapter.register_vector_store = AsyncMock()
|
||||||
# Set provider_id attribute for the adapter
|
# Set provider_id attribute for the adapter
|
||||||
vector_io_adapter.__provider_id__ = "test_provider"
|
vector_io_adapter.__provider_id__ = "test_provider"
|
||||||
|
|
||||||
|
|
@ -1104,7 +1104,7 @@ async def test_embedding_config_inconsistency_errors(vector_io_adapter):
|
||||||
await vector_io_adapter.openai_create_vector_store(params)
|
await vector_io_adapter.openai_create_vector_store(params)
|
||||||
|
|
||||||
# Reset mock for second test
|
# Reset mock for second test
|
||||||
vector_io_adapter.register_vector_db.reset_mock()
|
vector_io_adapter.register_vector_store.reset_mock()
|
||||||
|
|
||||||
# Test with inconsistent embedding dimension
|
# Test with inconsistent embedding dimension
|
||||||
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||||
|
|
@ -1126,8 +1126,8 @@ async def test_embedding_config_inconsistency_errors(vector_io_adapter):
|
||||||
async def test_embedding_config_defaults_when_missing(vector_io_adapter):
|
async def test_embedding_config_defaults_when_missing(vector_io_adapter):
|
||||||
"""Test that embedding dimension defaults to 768 when not provided."""
|
"""Test that embedding dimension defaults to 768 when not provided."""
|
||||||
|
|
||||||
# Mock register_vector_db to avoid actual registration
|
# Mock register_vector_store to avoid actual registration
|
||||||
vector_io_adapter.register_vector_db = AsyncMock()
|
vector_io_adapter.register_vector_store = AsyncMock()
|
||||||
# Set provider_id attribute for the adapter
|
# Set provider_id attribute for the adapter
|
||||||
vector_io_adapter.__provider_id__ = "test_provider"
|
vector_io_adapter.__provider_id__ = "test_provider"
|
||||||
|
|
||||||
|
|
@ -1143,8 +1143,8 @@ async def test_embedding_config_defaults_when_missing(vector_io_adapter):
|
||||||
await vector_io_adapter.openai_create_vector_store(params)
|
await vector_io_adapter.openai_create_vector_store(params)
|
||||||
|
|
||||||
# Should default to 768 dimensions
|
# Should default to 768 dimensions
|
||||||
vector_io_adapter.register_vector_db.assert_called_once()
|
vector_io_adapter.register_vector_store.assert_called_once()
|
||||||
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
|
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||||
assert call_args.embedding_model == "model-without-dimension"
|
assert call_args.embedding_model == "model-without-dimension"
|
||||||
assert call_args.embedding_dimension == 768
|
assert call_args.embedding_dimension == 768
|
||||||
|
|
||||||
|
|
@ -1152,8 +1152,8 @@ async def test_embedding_config_defaults_when_missing(vector_io_adapter):
|
||||||
async def test_embedding_config_required_model_missing(vector_io_adapter):
|
async def test_embedding_config_required_model_missing(vector_io_adapter):
|
||||||
"""Test that missing embedding model raises error."""
|
"""Test that missing embedding model raises error."""
|
||||||
|
|
||||||
# Mock register_vector_db to avoid actual registration
|
# Mock register_vector_store to avoid actual registration
|
||||||
vector_io_adapter.register_vector_db = AsyncMock()
|
vector_io_adapter.register_vector_store = AsyncMock()
|
||||||
# Set provider_id attribute for the adapter
|
# Set provider_id attribute for the adapter
|
||||||
vector_io_adapter.__provider_id__ = "test_provider"
|
vector_io_adapter.__provider_id__ = "test_provider"
|
||||||
# Mock the default model lookup to return None (no default model available)
|
# Mock the default model lookup to return None (no default model available)
|
||||||
|
|
|
||||||
|
|
@ -18,19 +18,19 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
|
||||||
|
|
||||||
|
|
||||||
class TestRagQuery:
|
class TestRagQuery:
|
||||||
async def test_query_raises_on_empty_vector_db_ids(self):
|
async def test_query_raises_on_empty_vector_store_ids(self):
|
||||||
rag_tool = MemoryToolRuntimeImpl(
|
rag_tool = MemoryToolRuntimeImpl(
|
||||||
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
|
await rag_tool.query(content=MagicMock(), vector_store_ids=[])
|
||||||
|
|
||||||
async def test_query_chunk_metadata_handling(self):
|
async def test_query_chunk_metadata_handling(self):
|
||||||
rag_tool = MemoryToolRuntimeImpl(
|
rag_tool = MemoryToolRuntimeImpl(
|
||||||
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||||
)
|
)
|
||||||
content = "test query content"
|
content = "test query content"
|
||||||
vector_db_ids = ["db1"]
|
vector_store_ids = ["db1"]
|
||||||
|
|
||||||
chunk_metadata = ChunkMetadata(
|
chunk_metadata = ChunkMetadata(
|
||||||
document_id="doc1",
|
document_id="doc1",
|
||||||
|
|
@ -55,7 +55,7 @@ class TestRagQuery:
|
||||||
query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])
|
query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])
|
||||||
|
|
||||||
rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
|
rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
|
||||||
result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)
|
result = await rag_tool.query(content=content, vector_store_ids=vector_store_ids)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
expected_metadata_string = (
|
expected_metadata_string = (
|
||||||
|
|
@ -82,7 +82,7 @@ class TestRagQuery:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
RAGQueryConfig(mode="wrong_mode")
|
RAGQueryConfig(mode="wrong_mode")
|
||||||
|
|
||||||
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
|
async def test_query_adds_vector_store_id_to_chunk_metadata(self):
|
||||||
rag_tool = MemoryToolRuntimeImpl(
|
rag_tool = MemoryToolRuntimeImpl(
|
||||||
config=MagicMock(),
|
config=MagicMock(),
|
||||||
vector_io_api=MagicMock(),
|
vector_io_api=MagicMock(),
|
||||||
|
|
@ -90,7 +90,7 @@ class TestRagQuery:
|
||||||
files_api=MagicMock(),
|
files_api=MagicMock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_db_ids = ["db1", "db2"]
|
vector_store_ids = ["db1", "db2"]
|
||||||
|
|
||||||
# Fake chunks from each DB
|
# Fake chunks from each DB
|
||||||
chunk_metadata1 = ChunkMetadata(
|
chunk_metadata1 = ChunkMetadata(
|
||||||
|
|
@ -101,7 +101,7 @@ class TestRagQuery:
|
||||||
)
|
)
|
||||||
chunk1 = Chunk(
|
chunk1 = Chunk(
|
||||||
content="chunk from db1",
|
content="chunk from db1",
|
||||||
metadata={"vector_db_id": "db1", "document_id": "doc1"},
|
metadata={"vector_store_id": "db1", "document_id": "doc1"},
|
||||||
stored_chunk_id="c1",
|
stored_chunk_id="c1",
|
||||||
chunk_metadata=chunk_metadata1,
|
chunk_metadata=chunk_metadata1,
|
||||||
)
|
)
|
||||||
|
|
@ -114,7 +114,7 @@ class TestRagQuery:
|
||||||
)
|
)
|
||||||
chunk2 = Chunk(
|
chunk2 = Chunk(
|
||||||
content="chunk from db2",
|
content="chunk from db2",
|
||||||
metadata={"vector_db_id": "db2", "document_id": "doc2"},
|
metadata={"vector_store_id": "db2", "document_id": "doc2"},
|
||||||
stored_chunk_id="c2",
|
stored_chunk_id="c2",
|
||||||
chunk_metadata=chunk_metadata2,
|
chunk_metadata=chunk_metadata2,
|
||||||
)
|
)
|
||||||
|
|
@ -126,13 +126,13 @@ class TestRagQuery:
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
|
result = await rag_tool.query(content="test", vector_store_ids=vector_store_ids)
|
||||||
returned_chunks = result.metadata["chunks"]
|
returned_chunks = result.metadata["chunks"]
|
||||||
returned_scores = result.metadata["scores"]
|
returned_scores = result.metadata["scores"]
|
||||||
returned_doc_ids = result.metadata["document_ids"]
|
returned_doc_ids = result.metadata["document_ids"]
|
||||||
returned_vector_db_ids = result.metadata["vector_db_ids"]
|
returned_vector_store_ids = result.metadata["vector_store_ids"]
|
||||||
|
|
||||||
assert returned_chunks == ["chunk from db1", "chunk from db2"]
|
assert returned_chunks == ["chunk from db1", "chunk from db2"]
|
||||||
assert returned_scores == (0.9, 0.8)
|
assert returned_scores == (0.9, 0.8)
|
||||||
assert returned_doc_ids == ["doc1", "doc2"]
|
assert returned_doc_ids == ["doc1", "doc2"]
|
||||||
assert returned_vector_db_ids == ["db1", "db2"]
|
assert returned_vector_store_ids == ["db1", "db2"]
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from llama_stack.apis.tools import RAGDocument
|
||||||
from llama_stack.apis.vector_io import Chunk
|
from llama_stack.apis.vector_io import Chunk
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
URL,
|
URL,
|
||||||
VectorDBWithIndex,
|
VectorStoreWithIndex,
|
||||||
_validate_embedding,
|
_validate_embedding,
|
||||||
content_from_doc,
|
content_from_doc,
|
||||||
make_overlapped_chunks,
|
make_overlapped_chunks,
|
||||||
|
|
@ -206,15 +206,15 @@ class TestVectorStore:
|
||||||
assert str(excinfo.value.__cause__) == "Cannot convert to string"
|
assert str(excinfo.value.__cause__) == "Cannot convert to string"
|
||||||
|
|
||||||
|
|
||||||
class TestVectorDBWithIndex:
|
class TestVectorStoreWithIndex:
|
||||||
async def test_insert_chunks_without_embeddings(self):
|
async def test_insert_chunks_without_embeddings(self):
|
||||||
mock_vector_db = MagicMock()
|
mock_vector_store = MagicMock()
|
||||||
mock_vector_db.embedding_model = "test-model without embeddings"
|
mock_vector_store.embedding_model = "test-model without embeddings"
|
||||||
mock_index = AsyncMock()
|
mock_index = AsyncMock()
|
||||||
mock_inference_api = AsyncMock()
|
mock_inference_api = AsyncMock()
|
||||||
|
|
||||||
vector_db_with_index = VectorDBWithIndex(
|
vector_store_with_index = VectorStoreWithIndex(
|
||||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||||
)
|
)
|
||||||
|
|
||||||
chunks = [
|
chunks = [
|
||||||
|
|
@ -227,7 +227,7 @@ class TestVectorDBWithIndex:
|
||||||
OpenAIEmbeddingData(embedding=[0.4, 0.5, 0.6], index=1),
|
OpenAIEmbeddingData(embedding=[0.4, 0.5, 0.6], index=1),
|
||||||
]
|
]
|
||||||
|
|
||||||
await vector_db_with_index.insert_chunks(chunks)
|
await vector_store_with_index.insert_chunks(chunks)
|
||||||
|
|
||||||
# Verify openai_embeddings was called with correct params
|
# Verify openai_embeddings was called with correct params
|
||||||
mock_inference_api.openai_embeddings.assert_called_once()
|
mock_inference_api.openai_embeddings.assert_called_once()
|
||||||
|
|
@ -243,14 +243,14 @@ class TestVectorDBWithIndex:
|
||||||
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
||||||
|
|
||||||
async def test_insert_chunks_with_valid_embeddings(self):
|
async def test_insert_chunks_with_valid_embeddings(self):
|
||||||
mock_vector_db = MagicMock()
|
mock_vector_store = MagicMock()
|
||||||
mock_vector_db.embedding_model = "test-model with embeddings"
|
mock_vector_store.embedding_model = "test-model with embeddings"
|
||||||
mock_vector_db.embedding_dimension = 3
|
mock_vector_store.embedding_dimension = 3
|
||||||
mock_index = AsyncMock()
|
mock_index = AsyncMock()
|
||||||
mock_inference_api = AsyncMock()
|
mock_inference_api = AsyncMock()
|
||||||
|
|
||||||
vector_db_with_index = VectorDBWithIndex(
|
vector_store_with_index = VectorStoreWithIndex(
|
||||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||||
)
|
)
|
||||||
|
|
||||||
chunks = [
|
chunks = [
|
||||||
|
|
@ -258,7 +258,7 @@ class TestVectorDBWithIndex:
|
||||||
Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
|
Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
|
||||||
]
|
]
|
||||||
|
|
||||||
await vector_db_with_index.insert_chunks(chunks)
|
await vector_store_with_index.insert_chunks(chunks)
|
||||||
|
|
||||||
mock_inference_api.openai_embeddings.assert_not_called()
|
mock_inference_api.openai_embeddings.assert_not_called()
|
||||||
mock_index.add_chunks.assert_called_once()
|
mock_index.add_chunks.assert_called_once()
|
||||||
|
|
@ -267,14 +267,14 @@ class TestVectorDBWithIndex:
|
||||||
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
||||||
|
|
||||||
async def test_insert_chunks_with_invalid_embeddings(self):
|
async def test_insert_chunks_with_invalid_embeddings(self):
|
||||||
mock_vector_db = MagicMock()
|
mock_vector_store = MagicMock()
|
||||||
mock_vector_db.embedding_dimension = 3
|
mock_vector_store.embedding_dimension = 3
|
||||||
mock_vector_db.embedding_model = "test-model with invalid embeddings"
|
mock_vector_store.embedding_model = "test-model with invalid embeddings"
|
||||||
mock_index = AsyncMock()
|
mock_index = AsyncMock()
|
||||||
mock_inference_api = AsyncMock()
|
mock_inference_api = AsyncMock()
|
||||||
|
|
||||||
vector_db_with_index = VectorDBWithIndex(
|
vector_store_with_index = VectorStoreWithIndex(
|
||||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify Chunk raises ValueError for invalid embedding type
|
# Verify Chunk raises ValueError for invalid embedding type
|
||||||
|
|
@ -283,7 +283,7 @@ class TestVectorDBWithIndex:
|
||||||
|
|
||||||
# Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
|
# Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
|
||||||
with pytest.raises(ValueError, match="Input should be a valid list"):
|
with pytest.raises(ValueError, match="Input should be a valid list"):
|
||||||
await vector_db_with_index.insert_chunks(
|
await vector_store_with_index.insert_chunks(
|
||||||
[
|
[
|
||||||
Chunk(content="Test 1", embedding=None, metadata={}),
|
Chunk(content="Test 1", embedding=None, metadata={}),
|
||||||
Chunk(content="Test 2", embedding="invalid_type", metadata={}),
|
Chunk(content="Test 2", embedding="invalid_type", metadata={}),
|
||||||
|
|
@ -292,7 +292,7 @@ class TestVectorDBWithIndex:
|
||||||
|
|
||||||
# Verify Chunk raises ValueError for invalid embedding element type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
|
# Verify Chunk raises ValueError for invalid embedding element type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
|
||||||
with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "):
|
with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "):
|
||||||
await vector_db_with_index.insert_chunks(
|
await vector_store_with_index.insert_chunks(
|
||||||
Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})
|
Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -300,20 +300,20 @@ class TestVectorDBWithIndex:
|
||||||
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
|
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
|
||||||
]
|
]
|
||||||
with pytest.raises(ValueError, match="has dimension 4, expected 3"):
|
with pytest.raises(ValueError, match="has dimension 4, expected 3"):
|
||||||
await vector_db_with_index.insert_chunks(chunks_wrong_dim)
|
await vector_store_with_index.insert_chunks(chunks_wrong_dim)
|
||||||
|
|
||||||
mock_inference_api.openai_embeddings.assert_not_called()
|
mock_inference_api.openai_embeddings.assert_not_called()
|
||||||
mock_index.add_chunks.assert_not_called()
|
mock_index.add_chunks.assert_not_called()
|
||||||
|
|
||||||
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
|
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
|
||||||
mock_vector_db = MagicMock()
|
mock_vector_store = MagicMock()
|
||||||
mock_vector_db.embedding_model = "test-model with partial embeddings"
|
mock_vector_store.embedding_model = "test-model with partial embeddings"
|
||||||
mock_vector_db.embedding_dimension = 3
|
mock_vector_store.embedding_dimension = 3
|
||||||
mock_index = AsyncMock()
|
mock_index = AsyncMock()
|
||||||
mock_inference_api = AsyncMock()
|
mock_inference_api = AsyncMock()
|
||||||
|
|
||||||
vector_db_with_index = VectorDBWithIndex(
|
vector_store_with_index = VectorStoreWithIndex(
|
||||||
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
|
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
|
||||||
)
|
)
|
||||||
|
|
||||||
chunks = [
|
chunks = [
|
||||||
|
|
@ -327,7 +327,7 @@ class TestVectorDBWithIndex:
|
||||||
OpenAIEmbeddingData(embedding=[0.3, 0.3, 0.3], index=1),
|
OpenAIEmbeddingData(embedding=[0.3, 0.3, 0.3], index=1),
|
||||||
]
|
]
|
||||||
|
|
||||||
await vector_db_with_index.insert_chunks(chunks)
|
await vector_store_with_index.insert_chunks(chunks)
|
||||||
|
|
||||||
# Verify openai_embeddings was called with correct params
|
# Verify openai_embeddings was called with correct params
|
||||||
mock_inference_api.openai_embeddings.assert_called_once()
|
mock_inference_api.openai_embeddings.assert_called_once()
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.inference import Model
|
from llama_stack.apis.inference import Model
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_stores import VectorStore
|
||||||
from llama_stack.core.datatypes import VectorDBWithOwner
|
from llama_stack.core.datatypes import VectorStoreWithOwner
|
||||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||||
from llama_stack.core.store.registry import (
|
from llama_stack.core.store.registry import (
|
||||||
KEY_FORMAT,
|
KEY_FORMAT,
|
||||||
|
|
@ -20,12 +20,12 @@ from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_b
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_vector_db():
|
def sample_vector_store():
|
||||||
return VectorDB(
|
return VectorStore(
|
||||||
identifier="test_vector_db",
|
identifier="test_vector_store",
|
||||||
embedding_model="nomic-embed-text-v1.5",
|
embedding_model="nomic-embed-text-v1.5",
|
||||||
embedding_dimension=768,
|
embedding_dimension=768,
|
||||||
provider_resource_id="test_vector_db",
|
provider_resource_id="test_vector_store",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -45,17 +45,17 @@ async def test_registry_initialization(disk_dist_registry):
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
|
async def test_basic_registration(disk_dist_registry, sample_vector_store, sample_model):
|
||||||
print(f"Registering {sample_vector_db}")
|
print(f"Registering {sample_vector_store}")
|
||||||
await disk_dist_registry.register(sample_vector_db)
|
await disk_dist_registry.register(sample_vector_store)
|
||||||
print(f"Registering {sample_model}")
|
print(f"Registering {sample_model}")
|
||||||
await disk_dist_registry.register(sample_model)
|
await disk_dist_registry.register(sample_model)
|
||||||
print("Getting vector_db")
|
print("Getting vector_store")
|
||||||
result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db")
|
result_vector_store = await disk_dist_registry.get("vector_store", "test_vector_store")
|
||||||
assert result_vector_db is not None
|
assert result_vector_store is not None
|
||||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
assert result_vector_store.identifier == sample_vector_store.identifier
|
||||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
|
||||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
assert result_vector_store.provider_id == sample_vector_store.provider_id
|
||||||
|
|
||||||
result_model = await disk_dist_registry.get("model", "test_model")
|
result_model = await disk_dist_registry.get("model", "test_model")
|
||||||
assert result_model is not None
|
assert result_model is not None
|
||||||
|
|
@ -63,11 +63,11 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m
|
||||||
assert result_model.provider_id == sample_model.provider_id
|
assert result_model.provider_id == sample_model.provider_id
|
||||||
|
|
||||||
|
|
||||||
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
|
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_store, sample_model):
|
||||||
# First populate the disk registry
|
# First populate the disk registry
|
||||||
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
|
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||||
await disk_registry.initialize()
|
await disk_registry.initialize()
|
||||||
await disk_registry.register(sample_vector_db)
|
await disk_registry.register(sample_vector_store)
|
||||||
await disk_registry.register(sample_model)
|
await disk_registry.register(sample_model)
|
||||||
|
|
||||||
# Test cached version loads from disk
|
# Test cached version loads from disk
|
||||||
|
|
@ -79,29 +79,29 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db,
|
||||||
)
|
)
|
||||||
await cached_registry.initialize()
|
await cached_registry.initialize()
|
||||||
|
|
||||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
|
result_vector_store = await cached_registry.get("vector_store", "test_vector_store")
|
||||||
assert result_vector_db is not None
|
assert result_vector_store is not None
|
||||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
assert result_vector_store.identifier == sample_vector_store.identifier
|
||||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
|
||||||
assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
|
assert result_vector_store.embedding_dimension == sample_vector_store.embedding_dimension
|
||||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
assert result_vector_store.provider_id == sample_vector_store.provider_id
|
||||||
|
|
||||||
|
|
||||||
async def test_cached_registry_updates(cached_disk_dist_registry):
|
async def test_cached_registry_updates(cached_disk_dist_registry):
|
||||||
new_vector_db = VectorDB(
|
new_vector_store = VectorStore(
|
||||||
identifier="test_vector_db_2",
|
identifier="test_vector_store_2",
|
||||||
embedding_model="nomic-embed-text-v1.5",
|
embedding_model="nomic-embed-text-v1.5",
|
||||||
embedding_dimension=768,
|
embedding_dimension=768,
|
||||||
provider_resource_id="test_vector_db_2",
|
provider_resource_id="test_vector_store_2",
|
||||||
provider_id="baz",
|
provider_id="baz",
|
||||||
)
|
)
|
||||||
await cached_disk_dist_registry.register(new_vector_db)
|
await cached_disk_dist_registry.register(new_vector_store)
|
||||||
|
|
||||||
# Verify in cache
|
# Verify in cache
|
||||||
result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
result_vector_store = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
|
||||||
assert result_vector_db is not None
|
assert result_vector_store is not None
|
||||||
assert result_vector_db.identifier == new_vector_db.identifier
|
assert result_vector_store.identifier == new_vector_store.identifier
|
||||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
assert result_vector_store.provider_id == new_vector_store.provider_id
|
||||||
|
|
||||||
# Verify persisted to disk
|
# Verify persisted to disk
|
||||||
db_path = cached_disk_dist_registry.kvstore.db_path
|
db_path = cached_disk_dist_registry.kvstore.db_path
|
||||||
|
|
@ -111,87 +111,87 @@ async def test_cached_registry_updates(cached_disk_dist_registry):
|
||||||
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
|
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
|
||||||
)
|
)
|
||||||
await new_registry.initialize()
|
await new_registry.initialize()
|
||||||
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
|
result_vector_store = await new_registry.get("vector_store", "test_vector_store_2")
|
||||||
assert result_vector_db is not None
|
assert result_vector_store is not None
|
||||||
assert result_vector_db.identifier == new_vector_db.identifier
|
assert result_vector_store.identifier == new_vector_store.identifier
|
||||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
assert result_vector_store.provider_id == new_vector_store.provider_id
|
||||||
|
|
||||||
|
|
||||||
async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
||||||
original_vector_db = VectorDB(
|
original_vector_store = VectorStore(
|
||||||
identifier="test_vector_db_2",
|
identifier="test_vector_store_2",
|
||||||
embedding_model="nomic-embed-text-v1.5",
|
embedding_model="nomic-embed-text-v1.5",
|
||||||
embedding_dimension=768,
|
embedding_dimension=768,
|
||||||
provider_resource_id="test_vector_db_2",
|
provider_resource_id="test_vector_store_2",
|
||||||
provider_id="baz",
|
provider_id="baz",
|
||||||
)
|
)
|
||||||
assert await cached_disk_dist_registry.register(original_vector_db)
|
assert await cached_disk_dist_registry.register(original_vector_store)
|
||||||
|
|
||||||
duplicate_vector_db = VectorDB(
|
duplicate_vector_store = VectorStore(
|
||||||
identifier="test_vector_db_2",
|
identifier="test_vector_store_2",
|
||||||
embedding_model="different-model",
|
embedding_model="different-model",
|
||||||
embedding_dimension=768,
|
embedding_dimension=768,
|
||||||
provider_resource_id="test_vector_db_2",
|
provider_resource_id="test_vector_store_2",
|
||||||
provider_id="baz", # Same provider_id
|
provider_id="baz", # Same provider_id
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"):
|
with pytest.raises(ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store_2' already exists"):
|
||||||
await cached_disk_dist_registry.register(duplicate_vector_db)
|
await cached_disk_dist_registry.register(duplicate_vector_store)
|
||||||
|
|
||||||
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
result = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
assert result.embedding_model == original_vector_store.embedding_model # Original values preserved
|
||||||
|
|
||||||
|
|
||||||
async def test_get_all_objects(cached_disk_dist_registry):
|
async def test_get_all_objects(cached_disk_dist_registry):
|
||||||
# Create multiple test banks
|
# Create multiple test banks
|
||||||
# Create multiple test banks
|
# Create multiple test banks
|
||||||
test_vector_dbs = [
|
test_vector_stores = [
|
||||||
VectorDB(
|
VectorStore(
|
||||||
identifier=f"test_vector_db_{i}",
|
identifier=f"test_vector_store_{i}",
|
||||||
embedding_model="nomic-embed-text-v1.5",
|
embedding_model="nomic-embed-text-v1.5",
|
||||||
embedding_dimension=768,
|
embedding_dimension=768,
|
||||||
provider_resource_id=f"test_vector_db_{i}",
|
provider_resource_id=f"test_vector_store_{i}",
|
||||||
provider_id=f"provider_{i}",
|
provider_id=f"provider_{i}",
|
||||||
)
|
)
|
||||||
for i in range(3)
|
for i in range(3)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Register all vector_dbs
|
# Register all vector_stores
|
||||||
for vector_db in test_vector_dbs:
|
for vector_store in test_vector_stores:
|
||||||
await cached_disk_dist_registry.register(vector_db)
|
await cached_disk_dist_registry.register(vector_store)
|
||||||
|
|
||||||
# Test get_all retrieval
|
# Test get_all retrieval
|
||||||
all_results = await cached_disk_dist_registry.get_all()
|
all_results = await cached_disk_dist_registry.get_all()
|
||||||
assert len(all_results) == 3
|
assert len(all_results) == 3
|
||||||
|
|
||||||
# Verify each vector_db was stored correctly
|
# Verify each vector_store was stored correctly
|
||||||
for original_vector_db in test_vector_dbs:
|
for original_vector_store in test_vector_stores:
|
||||||
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
|
matching_vector_stores = [v for v in all_results if v.identifier == original_vector_store.identifier]
|
||||||
assert len(matching_vector_dbs) == 1
|
assert len(matching_vector_stores) == 1
|
||||||
stored_vector_db = matching_vector_dbs[0]
|
stored_vector_store = matching_vector_stores[0]
|
||||||
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
|
assert stored_vector_store.embedding_model == original_vector_store.embedding_model
|
||||||
assert stored_vector_db.provider_id == original_vector_db.provider_id
|
assert stored_vector_store.provider_id == original_vector_store.provider_id
|
||||||
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
assert stored_vector_store.embedding_dimension == original_vector_store.embedding_dimension
|
||||||
|
|
||||||
|
|
||||||
async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
||||||
valid_db = VectorDB(
|
valid_db = VectorStore(
|
||||||
identifier="valid_vector_db",
|
identifier="valid_vector_store",
|
||||||
embedding_model="nomic-embed-text-v1.5",
|
embedding_model="nomic-embed-text-v1.5",
|
||||||
embedding_dimension=768,
|
embedding_dimension=768,
|
||||||
provider_resource_id="valid_vector_db",
|
provider_resource_id="valid_vector_store",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
|
KEY_FORMAT.format(type="vector_store", identifier="valid_vector_store"), valid_db.model_dump_json()
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_store", identifier="corrupted_json"), "{not valid json")
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
KEY_FORMAT.format(type="vector_store", identifier="missing_fields"),
|
||||||
'{"type": "vector_db", "identifier": "missing_fields"}',
|
'{"type": "vector_store", "identifier": "missing_fields"}',
|
||||||
)
|
)
|
||||||
|
|
||||||
test_registry = DiskDistributionRegistry(sqlite_kvstore)
|
test_registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||||
|
|
@ -202,18 +202,18 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
||||||
|
|
||||||
# Should have filtered out the invalid entries
|
# Should have filtered out the invalid entries
|
||||||
assert len(all_objects) == 1
|
assert len(all_objects) == 1
|
||||||
assert all_objects[0].identifier == "valid_vector_db"
|
assert all_objects[0].identifier == "valid_vector_store"
|
||||||
|
|
||||||
# Check that the get method also handles errors correctly
|
# Check that the get method also handles errors correctly
|
||||||
invalid_obj = await test_registry.get("vector_db", "corrupted_json")
|
invalid_obj = await test_registry.get("vector_store", "corrupted_json")
|
||||||
assert invalid_obj is None
|
assert invalid_obj is None
|
||||||
|
|
||||||
invalid_obj = await test_registry.get("vector_db", "missing_fields")
|
invalid_obj = await test_registry.get("vector_store", "missing_fields")
|
||||||
assert invalid_obj is None
|
assert invalid_obj is None
|
||||||
|
|
||||||
|
|
||||||
async def test_cached_registry_error_handling(sqlite_kvstore):
|
async def test_cached_registry_error_handling(sqlite_kvstore):
|
||||||
valid_db = VectorDB(
|
valid_db = VectorStore(
|
||||||
identifier="valid_cached_db",
|
identifier="valid_cached_db",
|
||||||
embedding_model="nomic-embed-text-v1.5",
|
embedding_model="nomic-embed-text-v1.5",
|
||||||
embedding_dimension=768,
|
embedding_dimension=768,
|
||||||
|
|
@ -222,12 +222,12 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
|
KEY_FORMAT.format(type="vector_store", identifier="valid_cached_db"), valid_db.model_dump_json()
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
|
KEY_FORMAT.format(type="vector_store", identifier="invalid_cached_db"),
|
||||||
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
|
'{"type": "vector_store", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
|
||||||
)
|
)
|
||||||
|
|
||||||
cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
|
cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
|
||||||
|
|
@ -237,63 +237,63 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
||||||
assert len(all_objects) == 1
|
assert len(all_objects) == 1
|
||||||
assert all_objects[0].identifier == "valid_cached_db"
|
assert all_objects[0].identifier == "valid_cached_db"
|
||||||
|
|
||||||
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
|
invalid_obj = await cached_registry.get("vector_store", "invalid_cached_db")
|
||||||
assert invalid_obj is None
|
assert invalid_obj is None
|
||||||
|
|
||||||
|
|
||||||
async def test_double_registration_identical_objects(disk_dist_registry):
|
async def test_double_registration_identical_objects(disk_dist_registry):
|
||||||
"""Test that registering identical objects succeeds (idempotent)."""
|
"""Test that registering identical objects succeeds (idempotent)."""
|
||||||
vector_db = VectorDBWithOwner(
|
vector_store = VectorStoreWithOwner(
|
||||||
identifier="test_vector_db",
|
identifier="test_vector_store",
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
provider_resource_id="test_vector_db",
|
provider_resource_id="test_vector_store",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
# First registration should succeed
|
# First registration should succeed
|
||||||
result1 = await disk_dist_registry.register(vector_db)
|
result1 = await disk_dist_registry.register(vector_store)
|
||||||
assert result1 is True
|
assert result1 is True
|
||||||
|
|
||||||
# Second registration of identical object should also succeed (idempotent)
|
# Second registration of identical object should also succeed (idempotent)
|
||||||
result2 = await disk_dist_registry.register(vector_db)
|
result2 = await disk_dist_registry.register(vector_store)
|
||||||
assert result2 is True
|
assert result2 is True
|
||||||
|
|
||||||
# Verify object exists and is unchanged
|
# Verify object exists and is unchanged
|
||||||
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
|
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
|
||||||
assert retrieved is not None
|
assert retrieved is not None
|
||||||
assert retrieved.identifier == vector_db.identifier
|
assert retrieved.identifier == vector_store.identifier
|
||||||
assert retrieved.embedding_model == vector_db.embedding_model
|
assert retrieved.embedding_model == vector_store.embedding_model
|
||||||
|
|
||||||
|
|
||||||
async def test_double_registration_different_objects(disk_dist_registry):
|
async def test_double_registration_different_objects(disk_dist_registry):
|
||||||
"""Test that registering different objects with same identifier fails."""
|
"""Test that registering different objects with same identifier fails."""
|
||||||
vector_db1 = VectorDBWithOwner(
|
vector_store1 = VectorStoreWithOwner(
|
||||||
identifier="test_vector_db",
|
identifier="test_vector_store",
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
provider_resource_id="test_vector_db",
|
provider_resource_id="test_vector_store",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_db2 = VectorDBWithOwner(
|
vector_store2 = VectorStoreWithOwner(
|
||||||
identifier="test_vector_db", # Same identifier
|
identifier="test_vector_store", # Same identifier
|
||||||
embedding_model="different-model", # Different embedding model
|
embedding_model="different-model", # Different embedding model
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
provider_resource_id="test_vector_db",
|
provider_resource_id="test_vector_store",
|
||||||
provider_id="test-provider",
|
provider_id="test-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
# First registration should succeed
|
# First registration should succeed
|
||||||
result1 = await disk_dist_registry.register(vector_db1)
|
result1 = await disk_dist_registry.register(vector_store1)
|
||||||
assert result1 is True
|
assert result1 is True
|
||||||
|
|
||||||
# Second registration with different data should fail
|
# Second registration with different data should fail
|
||||||
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"):
|
with pytest.raises(ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store' already exists"):
|
||||||
await disk_dist_registry.register(vector_db2)
|
await disk_dist_registry.register(vector_store2)
|
||||||
|
|
||||||
# Verify original object is unchanged
|
# Verify original object is unchanged
|
||||||
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
|
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
|
||||||
assert retrieved is not None
|
assert retrieved is not None
|
||||||
assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value
|
assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ class TestTranslateException:
|
||||||
self.identifier = identifier
|
self.identifier = identifier
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
|
|
||||||
resource = MockResource("vector_db", "test-db")
|
resource = MockResource("vector_store", "test-db")
|
||||||
|
|
||||||
exc = AccessDeniedError("create", resource, user)
|
exc = AccessDeniedError("create", resource, user)
|
||||||
result = translate_exception(exc)
|
result = translate_exception(exc)
|
||||||
|
|
@ -49,7 +49,7 @@ class TestTranslateException:
|
||||||
assert isinstance(result, HTTPException)
|
assert isinstance(result, HTTPException)
|
||||||
assert result.status_code == 403
|
assert result.status_code == 403
|
||||||
assert "test-user" in result.detail
|
assert "test-user" in result.detail
|
||||||
assert "vector_db::test-db" in result.detail
|
assert "vector_store::test-db" in result.detail
|
||||||
assert "create" in result.detail
|
assert "create" in result.detail
|
||||||
assert "roles=['user']" in result.detail
|
assert "roles=['user']" in result.detail
|
||||||
assert "teams=['dev']" in result.detail
|
assert "teams=['dev']" in result.detail
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue