Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-05 02:17:31 +00:00)

Merge branch 'main' into chroma
commit f856e53323
1881 changed files with 886579 additions and 84028 deletions
@@ -22,16 +22,22 @@ from llama_stack.apis.vector_io import (
 )
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
-from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
+from llama_stack.providers.inline.vector_io.chroma import (
+    ChromaVectorIOConfig as InlineChromaVectorIOConfig,
+)
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
-from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
+    OpenAIVectorStoreMixin,
+)
 from llama_stack.providers.utils.memory.vector_store import (
     ChunkForDeletion,
     EmbeddingIndex,
     VectorDBWithIndex,
 )
-from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
+from llama_stack.providers.utils.vector_io.vector_utils import (
+    WeightedInMemoryAggregator,
+)
 
 from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
 
@@ -223,14 +229,13 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         inference_api: Api.inference,
         files_api: Files | None,
     ) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
         self.config = config
         self.inference_api = inference_api
         self.client = None
         self.cache = {}
-        self.kvstore: KVStore | None = None
         self.vector_db_store = None
-        self.files_api = files_api
 
     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
@@ -251,7 +256,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self.openai_vector_stores = await self._load_openai_vector_stores()
 
     async def shutdown(self) -> None:
-        pass
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
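The shutdown changes in this and the following adapters all follow the same cooperative-cleanup shape: do the provider-specific teardown first, then delegate to the mixin so it can stop its background file-batch work. A minimal sketch of that pattern, with hypothetical task-tracking details standing in for the real OpenAIVectorStoreMixin internals:

import asyncio


class MixinSketch:
    """Stand-in for OpenAIVectorStoreMixin; only the shutdown contract is sketched."""

    def __init__(self) -> None:
        self._file_batch_tasks: list[asyncio.Task] = []  # hypothetical bookkeeping

    async def shutdown(self) -> None:
        # Cancel and drain any in-flight background work (e.g. file batch processing).
        for task in self._file_batch_tasks:
            task.cancel()
        await asyncio.gather(*self._file_batch_tasks, return_exceptions=True)


class AdapterSketch(MixinSketch):
    def __init__(self, client) -> None:
        super().__init__()
        self.client = client

    async def shutdown(self) -> None:
        self.client.close()       # provider-specific cleanup first
        await super().shutdown()  # then let the mixin cancel its file batch tasks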
@@ -309,14 +309,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         inference_api: Inference,
         files_api: Files | None,
     ) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.cache = {}
         self.client = None
         self.inference_api = inference_api
-        self.files_api = files_api
-        self.kvstore: KVStore | None = None
         self.vector_db_store = None
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.metadata_collection_name = "openai_vector_stores_metadata"
 
     async def initialize(self) -> None:
@@ -351,6 +349,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
 
     async def shutdown(self) -> None:
         self.client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
@@ -345,14 +345,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
         inference_api: Api.inference,
         files_api: Files | None = None,
     ) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
         self.conn = None
         self.cache = {}
-        self.files_api = files_api
-        self.kvstore: KVStore | None = None
         self.vector_db_store = None
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.metadata_collection_name = "openai_vector_stores_metadata"
 
     async def initialize(self) -> None:
@@ -392,6 +390,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
         if self.conn is not None:
             self.conn.close()
             log.info("Connection to PGVector database server closed")
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(self, vector_db: VectorDB) -> None:
         # Persist vector DB metadata in the KV store
@@ -27,7 +27,7 @@ from llama_stack.apis.vector_io import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
-from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
+from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     ChunkForDeletion,
@@ -162,14 +162,12 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         inference_api: Api.inference,
         files_api: Files | None = None,
     ) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.client: AsyncQdrantClient = None
         self.cache = {}
         self.inference_api = inference_api
-        self.files_api = files_api
         self.vector_db_store = None
-        self.kvstore: KVStore | None = None
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self._qdrant_lock = asyncio.Lock()
 
     async def initialize(self) -> None:
@@ -193,6 +191,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
 
     async def shutdown(self) -> None:
         await self.client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
@@ -10,7 +10,7 @@ import weaviate
 import weaviate.classes as wvc
 from numpy.typing import NDArray
 from weaviate.classes.init import Auth
-from weaviate.classes.query import Filter
+from weaviate.classes.query import Filter, HybridFusion
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
@@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
     OpenAIVectorStoreMixin,
 )
 from llama_stack.providers.utils.memory.vector_store import (
+    RERANKER_TYPE_RRF,
     ChunkForDeletion,
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -47,7 +48,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
 class WeaviateIndex(EmbeddingIndex):
     def __init__(
         self,
-        client: weaviate.Client,
+        client: weaviate.WeaviateClient,
         collection_name: str,
         kvstore: KVStore | None = None,
     ):
@@ -64,14 +65,14 @@ class WeaviateIndex(EmbeddingIndex):
         )
 
         data_objects = []
-        for i, chunk in enumerate(chunks):
+        for chunk, embedding in zip(chunks, embeddings, strict=False):
             data_objects.append(
                 wvc.data.DataObject(
                     properties={
                         "chunk_id": chunk.chunk_id,
                         "chunk_content": chunk.model_dump_json(),
                     },
-                    vector=embeddings[i].tolist(),
+                    vector=embedding.tolist(),
                 )
             )
 
@@ -88,14 +89,30 @@ class WeaviateIndex(EmbeddingIndex):
         collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
 
     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+        """
+        Performs vector search using Weaviate's built-in vector search.
+        Args:
+            embedding: The query embedding vector
+            k: Limit of number of results to return
+            score_threshold: Minimum similarity score threshold
+        Returns:
+            QueryChunksResponse with chunks and scores.
+        """
+        log.debug(
+            f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
+        )
         sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
         collection = self.client.collections.get(sanitized_collection_name)
 
-        results = collection.query.near_vector(
-            near_vector=embedding.tolist(),
-            limit=k,
-            return_metadata=wvc.query.MetadataQuery(distance=True),
-        )
+        try:
+            results = collection.query.near_vector(
+                near_vector=embedding.tolist(),
+                limit=k,
+                return_metadata=wvc.query.MetadataQuery(distance=True),
+            )
+        except Exception as e:
+            log.error(f"Weaviate client vector search failed: {e}")
+            raise
 
         chunks = []
         scores = []
@@ -108,13 +125,17 @@ class WeaviateIndex(EmbeddingIndex):
                 log.exception(f"Failed to parse document: {chunk_json}")
                 continue
 
-            score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
+            if doc.metadata.distance is None:
+                continue
+            # Convert cosine distance ∈ [0,2] -> normalized cosine similarity ∈ [0,1]
+            score = 1.0 - (float(doc.metadata.distance) / 2.0)
             if score < score_threshold:
                 continue
 
             chunks.append(chunk)
             scores.append(score)
 
+        log.debug(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
     async def delete(self, chunk_ids: list[str] | None = None) -> None:
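The replaced scoring line is worth spelling out: Weaviate reports cosine distance in [0, 2], and the new code maps it linearly onto a similarity score in [0, 1] instead of the old 1/distance heuristic, which blew up near zero. A quick sanity check of that mapping, assuming cosine distance semantics:

def distance_to_score(distance: float) -> float:
    # Cosine distance d in [0, 2] -> similarity score 1 - d/2 in [0, 1]
    return 1.0 - (float(distance) / 2.0)


assert distance_to_score(0.0) == 1.0  # identical direction
assert distance_to_score(1.0) == 0.5  # orthogonal vectors
assert distance_to_score(2.0) == 0.0  # opposite direction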
@@ -136,7 +157,50 @@ class WeaviateIndex(EmbeddingIndex):
         k: int,
         score_threshold: float,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Keyword search is not supported in Weaviate")
+        """
+        Performs BM25-based keyword search using Weaviate's built-in full-text search.
+        Args:
+            query_string: The text query for keyword search
+            k: Limit of number of results to return
+            score_threshold: Minimum similarity score threshold
+        Returns:
+            QueryChunksResponse with chunks and scores
+        """
+        log.debug(f"WEAVIATE KEYWORD SEARCH CALLED: query='{query_string}', k={k}, threshold={score_threshold}")
+        sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
+        collection = self.client.collections.get(sanitized_collection_name)
+
+        # Perform BM25 keyword search on chunk_content field
+        try:
+            results = collection.query.bm25(
+                query=query_string,
+                limit=k,
+                return_metadata=wvc.query.MetadataQuery(score=True),
+            )
+        except Exception as e:
+            log.error(f"Weaviate client keyword search failed: {e}")
+            raise
+
+        chunks = []
+        scores = []
+        for doc in results.objects:
+            chunk_json = doc.properties["chunk_content"]
+            try:
+                chunk_dict = json.loads(chunk_json)
+                chunk = Chunk(**chunk_dict)
+            except Exception:
+                log.exception(f"Failed to parse document: {chunk_json}")
+                continue
+
+            score = doc.metadata.score if doc.metadata.score is not None else 0.0
+            if score < score_threshold:
+                continue
+
+            chunks.append(chunk)
+            scores.append(score)
+
+        log.debug(f"WEAVIATE KEYWORD SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}.")
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 
     async def query_hybrid(
         self,
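With query_keyword implemented instead of raising NotImplementedError, the index can be exercised directly. A hedged usage sketch, assuming a local Weaviate container and a placeholder collection name (the real adapter builds both from its config):

import weaviate


async def demo_keyword_search() -> None:
    # Placeholder connection and collection name for illustration only.
    client = weaviate.connect_to_local(host="localhost", port=8080)
    index = WeaviateIndex(client=client, collection_name="vs_demo_collection")

    # BM25 keyword search: top-5 chunks matching the query, filtered by score threshold.
    response = await index.query_keyword(query_string="inference provider setup", k=5, score_threshold=0.0)
    for chunk, score in zip(response.chunks, response.scores, strict=False):
        print(score, chunk.chunk_id)
    client.close()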
@@ -147,7 +211,65 @@ class WeaviateIndex(EmbeddingIndex):
         reranker_type: str,
         reranker_params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Hybrid search is not supported in Weaviate")
+        """
+        Hybrid search combining vector similarity and keyword search using Weaviate's native hybrid search.
+        Args:
+            embedding: The query embedding vector
+            query_string: The text query for keyword search
+            k: Limit of number of results to return
+            score_threshold: Minimum similarity score threshold
+            reranker_type: Type of reranker to use ("rrf" or "normalized")
+            reranker_params: Parameters for the reranker
+        Returns:
+            QueryChunksResponse with combined results
+        """
+        log.debug(
+            f"WEAVIATE HYBRID SEARCH CALLED: query='{query_string}', embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}, reranker={reranker_type}"
+        )
+        sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
+        collection = self.client.collections.get(sanitized_collection_name)
+
+        # Ranked (RRF) reranker fusion type
+        if reranker_type == RERANKER_TYPE_RRF:
+            rerank = HybridFusion.RANKED
+        # Relative score (Normalized) reranker fusion type
+        else:
+            rerank = HybridFusion.RELATIVE_SCORE
+
+        # Perform hybrid search using Weaviate's native hybrid search
+        try:
+            results = collection.query.hybrid(
+                query=query_string,
+                alpha=0.5,  # Range <0, 1>, where 0.5 will equally favor vector and keyword search
+                vector=embedding.tolist(),
+                limit=k,
+                fusion_type=rerank,
+                return_metadata=wvc.query.MetadataQuery(score=True),
+            )
+        except Exception as e:
+            log.error(f"Weaviate client hybrid search failed: {e}")
+            raise
+
+        chunks = []
+        scores = []
+        for doc in results.objects:
+            chunk_json = doc.properties["chunk_content"]
+            try:
+                chunk_dict = json.loads(chunk_json)
+                chunk = Chunk(**chunk_dict)
+            except Exception:
+                log.exception(f"Failed to parse document: {chunk_json}")
+                continue
+
+            score = doc.metadata.score if doc.metadata.score is not None else 0.0
+            if score < score_threshold:
+                continue
+
+            chunks.append(chunk)
+            scores.append(score)
+
+        log.debug(f"WEAVIATE HYBRID SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 
 
 class WeaviateVectorIOAdapter(
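For readers unfamiliar with the two fusion modes: HybridFusion.RANKED corresponds to reciprocal-rank fusion (documents are credited by their rank in each result list), while HybridFusion.RELATIVE_SCORE normalizes and blends the raw scores; alpha=0.5 weights the vector and keyword legs equally. A small illustration of the reciprocal-rank idea (the constant 60 is the commonly used default, not necessarily Weaviate's exact value):

def reciprocal_rank_fusion(rankings: list[list[str]], k: int = 60) -> dict[str, float]:
    # Each ranked list of chunk ids contributes 1 / (k + rank) to a chunk's fused score.
    fused: dict[str, float] = {}
    for ranked_ids in rankings:
        for rank, chunk_id in enumerate(ranked_ids, start=1):
            fused[chunk_id] = fused.get(chunk_id, 0.0) + 1.0 / (k + rank)
    return dict(sorted(fused.items(), key=lambda kv: kv[1], reverse=True))


# "c2" ranks highly in both the vector and the keyword list, so it fuses to the top.
print(reciprocal_rank_fusion([["c1", "c2", "c3"], ["c2", "c3", "c1"]]))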
@@ -162,19 +284,17 @@ class WeaviateVectorIOAdapter(
         inference_api: Api.inference,
         files_api: Files | None,
     ) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
-        self.files_api = files_api
-        self.kvstore: KVStore | None = None
         self.vector_db_store = None
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.metadata_collection_name = "openai_vector_stores_metadata"
 
-    def _get_client(self) -> weaviate.Client:
+    def _get_client(self) -> weaviate.WeaviateClient:
         if "localhost" in self.config.weaviate_cluster_url:
-            log.info("using Weaviate locally in container")
+            log.info("Using Weaviate locally in container")
             host, port = self.config.weaviate_cluster_url.split(":")
             key = "local_test"
             client = weaviate.connect_to_local(
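The _get_client change only touches the local branch shown above; the return type is now the v4 WeaviateClient. For orientation, a hedged sketch of the two connection modes the adapter distinguishes (the cloud branch is an assumption based on the Auth import, not code shown in this hunk):

import weaviate
from weaviate.classes.init import Auth


def connect(weaviate_cluster_url: str, api_key: str | None) -> weaviate.WeaviateClient:
    if "localhost" in weaviate_cluster_url:
        # Local container: the URL is treated as "host:port".
        host, port = weaviate_cluster_url.split(":")
        return weaviate.connect_to_local(host=host, port=int(port))
    # Hosted cluster: authenticate with an API key (assumed flow).
    return weaviate.connect_to_weaviate_cloud(
        cluster_url=weaviate_cluster_url,
        auth_credentials=Auth.api_key(api_key),
    )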
@@ -227,6 +347,8 @@ class WeaviateVectorIOAdapter(
     async def shutdown(self) -> None:
         for client in self.client_cache.values():
             client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
@@ -247,7 +369,7 @@ class WeaviateVectorIOAdapter(
             ],
         )
 
-        self.cache[sanitized_collection_name] = VectorDBWithIndex(
+        self.cache[vector_db.identifier] = VectorDBWithIndex(
             vector_db,
             WeaviateIndex(client=client, collection_name=sanitized_collection_name),
             self.inference_api,
@@ -256,32 +378,34 @@ class WeaviateVectorIOAdapter(
     async def unregister_vector_db(self, vector_db_id: str) -> None:
         client = self._get_client()
         sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
-        if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
-            log.warning(f"Vector DB {sanitized_collection_name} not found")
+        if vector_db_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
             return
         client.collections.delete(sanitized_collection_name)
-        await self.cache[sanitized_collection_name].index.delete()
-        del self.cache[sanitized_collection_name]
+        await self.cache[vector_db_id].index.delete()
+        del self.cache[vector_db_id]
 
     async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
-        sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
-        if sanitized_collection_name in self.cache:
-            return self.cache[sanitized_collection_name]
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]
 
-        vector_db = await self.vector_db_store.get_vector_db(sanitized_collection_name)
+        if self.vector_db_store is None:
+            raise VectorStoreNotFoundError(vector_db_id)
+
+        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
         if not vector_db:
             raise VectorStoreNotFoundError(vector_db_id)
 
         client = self._get_client()
-        if not client.collections.exists(vector_db.identifier):
+        sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
+        if not client.collections.exists(sanitized_collection_name):
             raise ValueError(f"Collection with name `{sanitized_collection_name}` not found")
 
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=WeaviateIndex(client=client, collection_name=sanitized_collection_name),
+            index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
             inference_api=self.inference_api,
         )
-        self.cache[sanitized_collection_name] = index
+        self.cache[vector_db_id] = index
         return index
 
     async def insert_chunks(
@@ -290,8 +414,7 @@ class WeaviateVectorIOAdapter(
         chunks: list[Chunk],
         ttl_seconds: int | None = None,
     ) -> None:
-        sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
-        index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)
 
@@ -303,17 +426,15 @@ class WeaviateVectorIOAdapter(
         query: InterleavedContent,
         params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
-        sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
-        index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)
 
         return await index.query_chunks(query, params)
 
     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
-        sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True)
-        index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
+        index = await self._get_and_cache_vector_db_index(store_id)
         if not index:
-            raise ValueError(f"Vector DB {sanitized_collection_name} not found")
+            raise ValueError(f"Vector DB {store_id} not found")
 
         await index.index.delete_chunks(chunks_for_deletion)