mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-23 00:27:26 +00:00
chore: Updating how default embedding model is set in stack (#3818)
# What does this PR do? Refactor setting default vector store provider and embedding model to use an optional `vector_stores` config in the `StackRunConfig` and clean up code to do so (had to add back in some pieces of VectorDB). Also added remote Qdrant and Weaviate to starter distro (based on other PR where inference providers were added for UX). New config is simply (default for Starter distro): ```yaml vector_stores: default_provider_id: faiss default_embedding_model: provider_id: sentence-transformers model_id: nomic-ai/nomic-embed-text-v1.5 ``` ## Test Plan CI and Unit tests. --------- Signed-off-by: Francisco Javier Arceo <farceo@redhat.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
2c43285e22
commit
48581bf651
48 changed files with 973 additions and 818 deletions
|
@ -12,11 +12,6 @@ from .config import ChromaVectorIOConfig
|
|||
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .chroma import ChromaVectorIOAdapter
|
||||
|
||||
impl = ChromaVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -12,24 +12,16 @@ import chromadb
|
|||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
|
||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||
|
||||
|
@ -68,19 +60,13 @@ class ChromaIndex(EmbeddingIndex):
|
|||
|
||||
ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
|
||||
await maybe_await(
|
||||
self.collection.add(
|
||||
documents=[chunk.model_dump_json() for chunk in chunks],
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
)
|
||||
self.collection.add(documents=[chunk.model_dump_json() for chunk in chunks], embeddings=embeddings, ids=ids)
|
||||
)
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
results = await maybe_await(
|
||||
self.collection.query(
|
||||
query_embeddings=[embedding.tolist()],
|
||||
n_results=k,
|
||||
include=["documents", "distances"],
|
||||
query_embeddings=[embedding.tolist()], n_results=k, include=["documents", "distances"]
|
||||
)
|
||||
)
|
||||
distances = results["distances"][0]
|
||||
|
@ -108,12 +94,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
async def delete(self):
|
||||
await maybe_await(self.client.delete_collection(self.collection.name))
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Chroma")
|
||||
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
|
@ -137,15 +118,13 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
def __init__(
|
||||
self,
|
||||
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
|
||||
inference_api: Api.inference,
|
||||
models_apis: Api.models,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_apis
|
||||
self.client = None
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
|
@ -172,14 +151,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
collection = await maybe_await(
|
||||
self.client.get_or_create_collection(
|
||||
name=vector_db.identifier,
|
||||
metadata={"vector_db": vector_db.model_dump_json()},
|
||||
name=vector_db.identifier, metadata={"vector_db": vector_db.model_dump_json()}
|
||||
)
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
|
@ -194,12 +169,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
|
||||
|
@ -207,10 +177,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
|
||||
|
|
|
@ -13,12 +13,6 @@ async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, Provide
|
|||
from .milvus import MilvusVectorIOAdapter
|
||||
|
||||
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MilvusVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -14,13 +14,8 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
|
@ -74,46 +69,23 @@ class MilvusIndex(EmbeddingIndex):
|
|||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
# Create schema for vector search
|
||||
schema = self.client.create_schema()
|
||||
schema.add_field(
|
||||
field_name="chunk_id",
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=100,
|
||||
)
|
||||
schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
|
||||
schema.add_field(
|
||||
field_name="content",
|
||||
datatype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
enable_analyzer=True, # Enable text analysis for BM25
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=len(embeddings[0]),
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="chunk_content",
|
||||
datatype=DataType.JSON,
|
||||
)
|
||||
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=len(embeddings[0]))
|
||||
schema.add_field(field_name="chunk_content", datatype=DataType.JSON)
|
||||
# Add sparse vector field for BM25 (required by the function)
|
||||
schema.add_field(
|
||||
field_name="sparse",
|
||||
datatype=DataType.SPARSE_FLOAT_VECTOR,
|
||||
)
|
||||
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
|
||||
|
||||
# Create indexes
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="FLAT",
|
||||
metric_type="COSINE",
|
||||
)
|
||||
index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
|
||||
# Add index for sparse field (required by BM25 function)
|
||||
index_params.add_index(
|
||||
field_name="sparse",
|
||||
index_type="SPARSE_INVERTED_INDEX",
|
||||
metric_type="BM25",
|
||||
)
|
||||
index_params.add_index(field_name="sparse", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25")
|
||||
|
||||
# Add BM25 function for full-text search
|
||||
bm25_function = Function(
|
||||
|
@ -144,11 +116,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
}
|
||||
)
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self.client.insert,
|
||||
self.collection_name,
|
||||
data=data,
|
||||
)
|
||||
await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||
raise e
|
||||
|
@ -167,12 +135,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
scores = [res["distance"] for res in search_res[0]]
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Perform BM25-based keyword search using Milvus's built-in full-text search.
|
||||
"""
|
||||
|
@ -210,12 +173,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
# Fallback to simple text search
|
||||
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||
|
||||
async def _fallback_keyword_search(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def _fallback_keyword_search(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Fallback to simple text search when BM25 search is not available.
|
||||
"""
|
||||
|
@ -308,7 +266,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self,
|
||||
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
|
@ -316,7 +273,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.cache = {}
|
||||
self.client = None
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.vector_db_store = None
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
|
@ -355,10 +311,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
consistency_level = self.config.consistency_level
|
||||
else:
|
||||
|
@ -395,12 +348,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
@ -408,10 +356,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
|
|
|
@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig
|
|||
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .pgvector import PGVectorVectorIOAdapter
|
||||
|
||||
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps[Api.models], deps.get(Api.files, None))
|
||||
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -16,26 +16,15 @@ from pydantic import BaseModel, TypeAdapter
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
|
||||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
@ -205,12 +194,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
|
||||
|
||||
|
@ -317,7 +301,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
"""Remove a chunk from the PostgreSQL table."""
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids))
|
||||
|
||||
def get_pgvector_search_function(self) -> str:
|
||||
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
|
||||
|
@ -341,16 +325,11 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: PGVectorVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None = None,
|
||||
self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.conn = None
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
|
@ -407,11 +386,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
|
||||
)
|
||||
await pgvector_index.initialize()
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=pgvector_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
index = VectorDBWithIndex(vector_db, index=pgvector_index, inference_api=self.inference_api)
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
|
@ -424,20 +399,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
assert self.kvstore is not None
|
||||
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
|
|
@ -12,11 +12,6 @@ from .config import QdrantVectorIOConfig
|
|||
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .qdrant import QdrantVectorIOAdapter
|
||||
|
||||
impl = QdrantVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -16,7 +16,6 @@ from qdrant_client.models import PointStruct
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
|
@ -30,11 +29,7 @@ from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
|||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
|
||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||
|
||||
|
@ -99,8 +94,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
|
||||
try:
|
||||
await self.client.delete(
|
||||
collection_name=self.collection_name,
|
||||
points_selector=models.PointIdsList(points=chunk_ids),
|
||||
collection_name=self.collection_name, points_selector=models.PointIdsList(points=chunk_ids)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
|
||||
|
@ -133,12 +127,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Qdrant")
|
||||
|
||||
async def query_hybrid(
|
||||
|
@ -161,7 +150,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self,
|
||||
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
|
@ -169,7 +157,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.client: AsyncQdrantClient = None
|
||||
self.cache = {}
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.vector_db_store = None
|
||||
self._qdrant_lock = asyncio.Lock()
|
||||
|
||||
|
@ -184,11 +171,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
QdrantIndex(self.client, vector_db.identifier),
|
||||
self.inference_api,
|
||||
)
|
||||
index = VectorDBWithIndex(vector_db, QdrantIndex(self.client, vector_db.identifier), self.inference_api)
|
||||
self.cache[vector_db.identifier] = index
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
|
@ -197,18 +180,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
assert self.kvstore is not None
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=QdrantIndex(self.client, vector_db.identifier),
|
||||
inference_api=self.inference_api,
|
||||
vector_db=vector_db, index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
@ -240,12 +218,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.cache[vector_db_id] = index
|
||||
return index
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
@ -253,10 +226,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
|
|
|
@ -12,11 +12,6 @@ from .config import WeaviateVectorIOConfig
|
|||
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .weaviate import WeaviateVectorIOAdapter
|
||||
|
||||
impl = WeaviateVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -21,11 +21,7 @@ class WeaviateVectorIOConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
__distro_dir__: str,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"weaviate_api_key": None,
|
||||
"weaviate_cluster_url": "${env.WEAVIATE_CLUSTER_URL:=localhost:8080}",
|
||||
|
|
|
@ -16,7 +16,6 @@ from llama_stack.apis.common.content_types import InterleavedContent
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
|
@ -24,9 +23,7 @@ from llama_stack.log import get_logger
|
|||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
|
||||
OpenAIVectorStoreMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
ChunkForDeletion,
|
||||
|
@ -48,12 +45,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
|
|||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(
|
||||
self,
|
||||
client: weaviate.WeaviateClient,
|
||||
collection_name: str,
|
||||
kvstore: KVStore | None = None,
|
||||
):
|
||||
def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None):
|
||||
self.client = client
|
||||
self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
|
||||
self.kvstore = kvstore
|
||||
|
@ -108,9 +100,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
|
||||
try:
|
||||
results = collection.query.near_vector(
|
||||
near_vector=embedding.tolist(),
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(distance=True),
|
||||
near_vector=embedding.tolist(), limit=k, return_metadata=wvc.query.MetadataQuery(distance=True)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Weaviate client vector search failed: {e}")
|
||||
|
@ -153,12 +143,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
collection = self.client.collections.get(sanitized_collection_name)
|
||||
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs BM25-based keyword search using Weaviate's built-in full-text search.
|
||||
Args:
|
||||
|
@ -175,9 +160,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
# Perform BM25 keyword search on chunk_content field
|
||||
try:
|
||||
results = collection.query.bm25(
|
||||
query=query_string,
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(score=True),
|
||||
query=query_string, limit=k, return_metadata=wvc.query.MetadataQuery(score=True)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Weaviate client keyword search failed: {e}")
|
||||
|
@ -274,23 +257,11 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class WeaviateVectorIOAdapter(
|
||||
OpenAIVectorStoreMixin,
|
||||
VectorIO,
|
||||
NeedsRequestProviderData,
|
||||
VectorDBsProtocolPrivate,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: WeaviateVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
|
@ -301,10 +272,7 @@ class WeaviateVectorIOAdapter(
|
|||
log.info("Using Weaviate locally in container")
|
||||
host, port = self.config.weaviate_cluster_url.split(":")
|
||||
key = "local_test"
|
||||
client = weaviate.connect_to_local(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
client = weaviate.connect_to_local(host=host, port=port)
|
||||
else:
|
||||
log.info("Using Weaviate remote cluster with URL")
|
||||
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
|
||||
|
@ -334,15 +302,9 @@ class WeaviateVectorIOAdapter(
|
|||
for raw in stored:
|
||||
vector_db = VectorDB.model_validate_json(raw)
|
||||
client = self._get_client()
|
||||
idx = WeaviateIndex(
|
||||
client=client,
|
||||
collection_name=vector_db.identifier,
|
||||
kvstore=self.kvstore,
|
||||
)
|
||||
idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=idx,
|
||||
inference_api=self.inference_api,
|
||||
vector_db=vector_db, index=idx, inference_api=self.inference_api
|
||||
)
|
||||
|
||||
# Load OpenAI vector stores metadata into cache
|
||||
|
@ -354,10 +316,7 @@ class WeaviateVectorIOAdapter(
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
|
||||
# Create collection if it doesn't exist
|
||||
|
@ -366,17 +325,12 @@ class WeaviateVectorIOAdapter(
|
|||
name=sanitized_collection_name,
|
||||
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
|
||||
properties=[
|
||||
wvc.config.Property(
|
||||
name="chunk_content",
|
||||
data_type=wvc.config.DataType.TEXT,
|
||||
),
|
||||
wvc.config.Property(name="chunk_content", data_type=wvc.config.DataType.TEXT),
|
||||
],
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db,
|
||||
WeaviateIndex(client=client, collection_name=sanitized_collection_name),
|
||||
self.inference_api,
|
||||
vector_db, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
|
||||
)
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
|
@ -412,12 +366,7 @@ class WeaviateVectorIOAdapter(
|
|||
self.cache[vector_db_id] = index
|
||||
return index
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
@ -425,10 +374,7 @@ class WeaviateVectorIOAdapter(
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue