Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-03 19:57:35 +00:00)
feat: implement keyword and hybrid search for Weaviate provider (#3264)
# What does this PR do?

- Implements keyword and hybrid search for the Weaviate provider, using Weaviate's built-in query functions.
- Adds Weaviate fixtures to conftest.py.
- Enables integration tests for remote Weaviate on all three search modes.

Closes #3010

## Test Plan

Unit tests and integration tests should pass on this PR.
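As a usage sketch (not part of this diff): once a Weaviate vector DB is registered, the search mode is selected through the query `params`. The `mode` and `ranker` shapes below mirror the `VectorDBWithIndex` parsing touched later in this diff; the client class, base URL, and DB identifier are placeholders, and the exact client method signature is an assumption.

```python
# Hypothetical usage sketch: the "mode"/"ranker" shapes follow the
# VectorDBWithIndex parsing in this PR; client, URL, and DB id are placeholders.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local stack

for params in (
    {"mode": "vector", "score_threshold": 0.5},
    {"mode": "keyword"},
    {"mode": "hybrid", "ranker": {"strategy": "rrf", "params": {"k": 60.0}}},
):
    response = client.vector_io.query(
        vector_db_id="my_weaviate_db",  # placeholder id
        query="what is hybrid search?",
        params=params,
    )
    print(params["mode"], response.scores)
```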
This commit is contained in:
parent 52c8df2322
commit bcdbb53be3

6 changed files with 242 additions and 48 deletions
```diff
@@ -500,7 +500,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details.
         api=Api.vector_io,
         adapter_type="weaviate",
         provider_type="remote::weaviate",
-        pip_packages=["weaviate-client"],
+        pip_packages=["weaviate-client>=4.16.5"],
         module="llama_stack.providers.remote.vector_io.weaviate",
         config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
         provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
```
```diff
@@ -10,7 +10,7 @@ import weaviate
 import weaviate.classes as wvc
 from numpy.typing import NDArray
 from weaviate.classes.init import Auth
-from weaviate.classes.query import Filter
+from weaviate.classes.query import Filter, HybridFusion

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
```
```diff
@@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
     OpenAIVectorStoreMixin,
 )
 from llama_stack.providers.utils.memory.vector_store import (
+    RERANKER_TYPE_RRF,
     ChunkForDeletion,
     EmbeddingIndex,
     VectorDBWithIndex,
```
```diff
@@ -47,7 +48,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
 class WeaviateIndex(EmbeddingIndex):
     def __init__(
         self,
-        client: weaviate.Client,
+        client: weaviate.WeaviateClient,
         collection_name: str,
         kvstore: KVStore | None = None,
     ):
```
```diff
@@ -64,14 +65,14 @@ class WeaviateIndex(EmbeddingIndex):
         )

         data_objects = []
-        for i, chunk in enumerate(chunks):
+        for chunk, embedding in zip(chunks, embeddings, strict=False):
             data_objects.append(
                 wvc.data.DataObject(
                     properties={
                         "chunk_id": chunk.chunk_id,
                         "chunk_content": chunk.model_dump_json(),
                     },
-                    vector=embeddings[i].tolist(),
+                    vector=embedding.tolist(),
                 )
             )

```
```diff
@@ -88,14 +89,30 @@ class WeaviateIndex(EmbeddingIndex):
         collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))

     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+        """
+        Performs vector search using Weaviate's built-in vector search.
+        Args:
+            embedding: The query embedding vector
+            k: Limit of number of results to return
+            score_threshold: Minimum similarity score threshold
+        Returns:
+            QueryChunksResponse with chunks and scores.
+        """
+        log.debug(
+            f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
+        )
         sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
         collection = self.client.collections.get(sanitized_collection_name)

+        try:
             results = collection.query.near_vector(
                 near_vector=embedding.tolist(),
                 limit=k,
                 return_metadata=wvc.query.MetadataQuery(distance=True),
             )
+        except Exception as e:
+            log.error(f"Weaviate client vector search failed: {e}")
+            raise

         chunks = []
         scores = []
```
```diff
@@ -108,13 +125,17 @@ class WeaviateIndex(EmbeddingIndex):
                 log.exception(f"Failed to parse document: {chunk_json}")
                 continue

-            score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
+            if doc.metadata.distance is None:
+                continue
+            # Convert cosine distance ∈ [0,2] -> normalized cosine similarity ∈ [0,1]
+            score = 1.0 - (float(doc.metadata.distance) / 2.0)
             if score < score_threshold:
                 continue

             chunks.append(chunk)
             scores.append(score)

+        log.debug(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
         return QueryChunksResponse(chunks=chunks, scores=scores)

     async def delete(self, chunk_ids: list[str] | None = None) -> None:
```
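The replaced scoring line above maps Weaviate's cosine distance (range [0, 2]) onto a similarity-like score in [0, 1], instead of the old `1/distance` formula that blew up as the distance approached zero. A quick standalone check of the endpoints:

```python
def cosine_distance_to_score(distance: float) -> float:
    # Same mapping as the hunk above: cosine distance in [0, 2] -> score in [0, 1].
    return 1.0 - (distance / 2.0)

assert cosine_distance_to_score(0.0) == 1.0  # identical direction
assert cosine_distance_to_score(1.0) == 0.5  # orthogonal vectors
assert cosine_distance_to_score(2.0) == 0.0  # opposite direction
```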
```diff
@@ -136,7 +157,50 @@ class WeaviateIndex(EmbeddingIndex):
         k: int,
         score_threshold: float,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Keyword search is not supported in Weaviate")
+        """
+        Performs BM25-based keyword search using Weaviate's built-in full-text search.
+        Args:
+            query_string: The text query for keyword search
+            k: Limit of number of results to return
+            score_threshold: Minimum similarity score threshold
+        Returns:
+            QueryChunksResponse with chunks and scores
+        """
+        log.debug(f"WEAVIATE KEYWORD SEARCH CALLED: query='{query_string}', k={k}, threshold={score_threshold}")
+        sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
+        collection = self.client.collections.get(sanitized_collection_name)
+
+        # Perform BM25 keyword search on chunk_content field
+        try:
+            results = collection.query.bm25(
+                query=query_string,
+                limit=k,
+                return_metadata=wvc.query.MetadataQuery(score=True),
+            )
+        except Exception as e:
+            log.error(f"Weaviate client keyword search failed: {e}")
+            raise
+
+        chunks = []
+        scores = []
+        for doc in results.objects:
+            chunk_json = doc.properties["chunk_content"]
+            try:
+                chunk_dict = json.loads(chunk_json)
+                chunk = Chunk(**chunk_dict)
+            except Exception:
+                log.exception(f"Failed to parse document: {chunk_json}")
+                continue
+
+            score = doc.metadata.score if doc.metadata.score is not None else 0.0
+            if score < score_threshold:
+                continue
+
+            chunks.append(chunk)
+            scores.append(score)
+
+        log.debug(f"WEAVIATE KEYWORD SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}.")
+        return QueryChunksResponse(chunks=chunks, scores=scores)

     async def query_hybrid(
         self,
```
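For reference, a minimal standalone sketch of the same weaviate-client v4 BM25 call path, assuming a local Weaviate instance and a collection named `Testcollection`. Note that BM25 scores are unbounded relevance scores, so a `score_threshold` tuned for cosine similarity does not carry over to keyword mode.

```python
import weaviate
import weaviate.classes as wvc

client = weaviate.connect_to_local()  # assumes Weaviate on localhost
collection = client.collections.get("Testcollection")  # assumed collection name
results = collection.query.bm25(
    query="keyword search",
    limit=5,
    return_metadata=wvc.query.MetadataQuery(score=True),
)
for obj in results.objects:
    # BM25 relevance score: higher is better, not bounded to [0, 1]
    print(obj.metadata.score, obj.properties.get("chunk_id"))
client.close()
```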
```diff
@@ -147,7 +211,65 @@ class WeaviateIndex(EmbeddingIndex):
         reranker_type: str,
         reranker_params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Hybrid search is not supported in Weaviate")
+        """
+        Hybrid search combining vector similarity and keyword search using Weaviate's native hybrid search.
+        Args:
+            embedding: The query embedding vector
+            query_string: The text query for keyword search
+            k: Limit of number of results to return
+            score_threshold: Minimum similarity score threshold
+            reranker_type: Type of reranker to use ("rrf" or "normalized")
+            reranker_params: Parameters for the reranker
+        Returns:
+            QueryChunksResponse with combined results
+        """
+        log.debug(
+            f"WEAVIATE HYBRID SEARCH CALLED: query='{query_string}', embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}, reranker={reranker_type}"
+        )
+        sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
+        collection = self.client.collections.get(sanitized_collection_name)
+
+        # Ranked (RRF) reranker fusion type
+        if reranker_type == RERANKER_TYPE_RRF:
+            rerank = HybridFusion.RANKED
+        # Relative score (Normalized) reranker fusion type
+        else:
+            rerank = HybridFusion.RELATIVE_SCORE
+
+        # Perform hybrid search using Weaviate's native hybrid search
+        try:
+            results = collection.query.hybrid(
+                query=query_string,
+                alpha=0.5,  # Range <0, 1>, where 0.5 will equally favor vector and keyword search
+                vector=embedding.tolist(),
+                limit=k,
+                fusion_type=rerank,
+                return_metadata=wvc.query.MetadataQuery(score=True),
+            )
+        except Exception as e:
+            log.error(f"Weaviate client hybrid search failed: {e}")
+            raise
+
+        chunks = []
+        scores = []
+        for doc in results.objects:
+            chunk_json = doc.properties["chunk_content"]
+            try:
+                chunk_dict = json.loads(chunk_json)
+                chunk = Chunk(**chunk_dict)
+            except Exception:
+                log.exception(f"Failed to parse document: {chunk_json}")
+                continue
+
+            score = doc.metadata.score if doc.metadata.score is not None else 0.0
+            if score < score_threshold:
+                continue
+
+            chunks.append(chunk)
+            scores.append(score)
+
+        log.debug(f"WEAVIATE HYBRID SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
+        return QueryChunksResponse(chunks=chunks, scores=scores)


 class WeaviateVectorIOAdapter(
```
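Weaviate performs the fusion server-side: `HybridFusion.RANKED` corresponds to reciprocal rank fusion (RRF) and `HybridFusion.RELATIVE_SCORE` to normalized relative-score fusion. A toy sketch of what ranked fusion computes; the constant 60 is the conventional RRF default, not a value taken from this PR:

```python
def rrf_fuse(vector_ids: list[str], keyword_ids: list[str], k: int = 60) -> dict[str, float]:
    # Each document scores 1 / (k + rank) per result list it appears in.
    fused: dict[str, float] = {}
    for ranked_ids in (vector_ids, keyword_ids):
        for rank, doc_id in enumerate(ranked_ids, start=1):
            fused[doc_id] = fused.get(doc_id, 0.0) + 1.0 / (k + rank)
    return fused

# "b" and "c" appear in both lists, so they outrank "a" and "d".
print(rrf_fuse(["a", "b", "c"], ["b", "c", "d"]))
```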
```diff
@@ -172,9 +294,9 @@ class WeaviateVectorIOAdapter(
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.metadata_collection_name = "openai_vector_stores_metadata"

-    def _get_client(self) -> weaviate.Client:
+    def _get_client(self) -> weaviate.WeaviateClient:
         if "localhost" in self.config.weaviate_cluster_url:
-            log.info("using Weaviate locally in container")
+            log.info("Using Weaviate locally in container")
             host, port = self.config.weaviate_cluster_url.split(":")
             key = "local_test"
             client = weaviate.connect_to_local(
```
```diff
@@ -247,7 +369,7 @@ class WeaviateVectorIOAdapter(
             ],
         )

-        self.cache[sanitized_collection_name] = VectorDBWithIndex(
+        self.cache[vector_db.identifier] = VectorDBWithIndex(
             vector_db,
             WeaviateIndex(client=client, collection_name=sanitized_collection_name),
             self.inference_api,
```
```diff
@@ -256,32 +378,34 @@ class WeaviateVectorIOAdapter(
     async def unregister_vector_db(self, vector_db_id: str) -> None:
         client = self._get_client()
         sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
-        if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
-            log.warning(f"Vector DB {sanitized_collection_name} not found")
+        if vector_db_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
             return
         client.collections.delete(sanitized_collection_name)
-        await self.cache[sanitized_collection_name].index.delete()
-        del self.cache[sanitized_collection_name]
+        await self.cache[vector_db_id].index.delete()
+        del self.cache[vector_db_id]

     async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
-        sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
-        if sanitized_collection_name in self.cache:
-            return self.cache[sanitized_collection_name]
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]

-        vector_db = await self.vector_db_store.get_vector_db(sanitized_collection_name)
+        if self.vector_db_store is None:
+            raise VectorStoreNotFoundError(vector_db_id)
+
+        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
         if not vector_db:
             raise VectorStoreNotFoundError(vector_db_id)

         client = self._get_client()
-        if not client.collections.exists(vector_db.identifier):
+        sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
+        if not client.collections.exists(sanitized_collection_name):
             raise ValueError(f"Collection with name `{sanitized_collection_name}` not found")

         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=WeaviateIndex(client=client, collection_name=sanitized_collection_name),
+            index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
             inference_api=self.inference_api,
         )
-        self.cache[sanitized_collection_name] = index
+        self.cache[vector_db_id] = index
         return index

     async def insert_chunks(
```
```diff
@@ -290,8 +414,7 @@ class WeaviateVectorIOAdapter(
         chunks: list[Chunk],
         ttl_seconds: int | None = None,
     ) -> None:
-        sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
-        index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)

```
```diff
@@ -303,17 +426,15 @@ class WeaviateVectorIOAdapter(
         query: InterleavedContent,
         params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
-        sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
-        index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)

         return await index.query_chunks(query, params)

     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
-        sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True)
-        index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
+        index = await self._get_and_cache_vector_db_index(store_id)
         if not index:
-            raise ValueError(f"Vector DB {sanitized_collection_name} not found")
+            raise ValueError(f"Vector DB {store_id} not found")

         await index.index.delete_chunks(chunks_for_deletion)
```
```diff
@@ -50,6 +50,7 @@ class ChunkForDeletion(BaseModel):
 # Constants for reranker types
 RERANKER_TYPE_RRF = "rrf"
 RERANKER_TYPE_WEIGHTED = "weighted"
+RERANKER_TYPE_NORMALIZED = "normalized"


 def parse_pdf(data: bytes) -> str:
```
```diff
@@ -325,6 +326,8 @@ class VectorDBWithIndex:
             weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
             reranker_type = RERANKER_TYPE_WEIGHTED
             reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
+        elif strategy == "normalized":
+            reranker_type = RERANKER_TYPE_NORMALIZED
         else:
             reranker_type = RERANKER_TYPE_RRF
             k_value = ranker.get("params", {}).get("k", 60.0)
```
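For reference, the `ranker` dicts this parsing accepts now look like the following; the key names mirror the hunk above, the numeric values are arbitrary examples:

```python
ranker_rrf = {"strategy": "rrf", "params": {"k": 60.0}}  # reciprocal rank fusion
ranker_weighted = {"strategy": "weighted", "params": {"weights": [0.7, 0.3]}}  # alpha = weights[0]
ranker_normalized = {"strategy": "normalized"}  # new in this PR
```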
```diff
@@ -22,16 +22,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
         if p.provider_type in [
-            "inline::faiss",
-            "inline::sqlite-vec",
-            "inline::milvus",
             "inline::chromadb",
-            "remote::pgvector",
-            "remote::chromadb",
-            "remote::qdrant",
+            "inline::faiss",
+            "inline::milvus",
             "inline::qdrant",
-            "remote::weaviate",
+            "inline::sqlite-vec",
+            "remote::chromadb",
             "remote::milvus",
+            "remote::pgvector",
+            "remote::qdrant",
+            "remote::weaviate",
         ]:
             return

```
```diff
@@ -47,23 +47,25 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_models
             "inline::milvus",
             "inline::chromadb",
             "inline::qdrant",
-            "remote::pgvector",
             "remote::chromadb",
-            "remote::weaviate",
-            "remote::qdrant",
             "remote::milvus",
+            "remote::pgvector",
+            "remote::qdrant",
+            "remote::weaviate",
         ],
         "keyword": [
+            "inline::milvus",
             "inline::sqlite-vec",
             "remote::milvus",
-            "inline::milvus",
             "remote::pgvector",
+            "remote::weaviate",
         ],
         "hybrid": [
-            "inline::sqlite-vec",
             "inline::milvus",
+            "inline::sqlite-vec",
             "remote::milvus",
             "remote::pgvector",
+            "remote::weaviate",
         ],
     }
     supported_providers = search_mode_support.get(search_mode, [])
```
```diff
@@ -138,8 +138,8 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension):
 def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension):
     vector_io_provider_params_dict = {
         "inline::milvus": {"score_threshold": -1.0},
-        "remote::qdrant": {"score_threshold": -1.0},
         "inline::qdrant": {"score_threshold": -1.0},
+        "remote::qdrant": {"score_threshold": -1.0},
     }
     vector_db_name = "test_precomputed_embeddings_db"
     register_response = client_with_empty_registry.vector_dbs.register(
```
```diff
@@ -26,13 +26,15 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
 from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
 from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
+from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
+from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter

 EMBEDDING_DIMENSION = 384
 COLLECTION_PREFIX = "test_collection"
 MILVUS_ALIAS = "test_milvus"


-@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
+@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
 def vector_provider(request):
     return request.param

```
```diff
@@ -448,6 +450,71 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
     await adapter.shutdown()


+@pytest.fixture(scope="session")
+def weaviate_vec_db_path(tmp_path_factory):
+    db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
+    return db_path
+
+
+@pytest.fixture
+async def weaviate_vec_index(weaviate_vec_db_path):
+    import pytest_socket
+    import weaviate
+
+    pytest_socket.enable_socket()
+    client = weaviate.connect_to_embedded(
+        hostname="localhost",
+        port=8080,
+        grpc_port=50051,
+        persistence_data_path=weaviate_vec_db_path,
+    )
+    index = WeaviateIndex(client=client, collection_name="Testcollection")
+    await index.initialize()
+    yield index
+    await index.delete()
+    client.close()
+
+
+@pytest.fixture
+async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
+    import pytest_socket
+    import weaviate
+
+    pytest_socket.enable_socket()
+
+    client = weaviate.connect_to_embedded(
+        hostname="localhost",
+        port=8080,
+        grpc_port=50051,
+        persistence_data_path=weaviate_vec_db_path,
+    )
+
+    config = WeaviateVectorIOConfig(
+        weaviate_cluster_url="localhost:8080",
+        weaviate_api_key=None,
+        kvstore=SqliteKVStoreConfig(),
+    )
+    adapter = WeaviateVectorIOAdapter(
+        config=config,
+        inference_api=mock_inference_api,
+        files_api=None,
+    )
+    collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
+    await adapter.initialize()
+    await adapter.register_vector_db(
+        VectorDB(
+            identifier=collection_id,
+            provider_id="test_provider",
+            embedding_model="test_model",
+            embedding_dimension=embedding_dimension,
+        )
+    )
+    adapter.test_collection_id = collection_id
+    yield adapter
+    await adapter.shutdown()
+    client.close()
+
+
 @pytest.fixture
 def vector_io_adapter(vector_provider, request):
     vector_provider_dict = {
```
```diff
@@ -457,6 +524,7 @@ def vector_io_adapter(vector_provider, request):
         "chroma": "chroma_vec_adapter",
         "qdrant": "qdrant_vec_adapter",
         "pgvector": "pgvector_vec_adapter",
+        "weaviate": "weaviate_vec_adapter",
     }
     return request.getfixturevalue(vector_provider_dict[vector_provider])

```