update resolver to only pass vector_stores section of run config

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

Using Router only from VectorDBs

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

removing model_api from vector store providers

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

fix test

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

updating integration tests

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

special handling for replay mode for available providers

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-10-16 10:59:01 -04:00
parent 24a1430c8b
commit accc4c437e
46 changed files with 397 additions and 702 deletions

View file

@ -4,27 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import ChromaVectorIOConfig
async def get_adapter_impl(
config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
):
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .chroma import ChromaVectorIOAdapter
vector_stores_config = None
if run_config and run_config.vector_stores:
vector_stores_config = run_config.vector_stores
impl = ChromaVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
vector_stores_config,
)
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -13,25 +13,15 @@ from numpy.typing import NDArray
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
@ -70,19 +60,13 @@ class ChromaIndex(EmbeddingIndex):
ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
await maybe_await(
self.collection.add(
documents=[chunk.model_dump_json() for chunk in chunks],
embeddings=embeddings,
ids=ids,
)
self.collection.add(documents=[chunk.model_dump_json() for chunk in chunks], embeddings=embeddings, ids=ids)
)
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
n_results=k,
include=["documents", "distances"],
query_embeddings=[embedding.tolist()], n_results=k, include=["documents", "distances"]
)
)
distances = results["distances"][0]
@ -110,12 +94,7 @@ class ChromaIndex(EmbeddingIndex):
async def delete(self):
await maybe_await(self.client.delete_collection(self.collection.name))
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
@ -140,16 +119,12 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self,
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
inference_api: Inference,
models_apis: Models,
files_api: Files | None,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config
self.inference_api = inference_api
self.models_api = models_apis
self.vector_stores_config = vector_stores_config
self.client = None
self.cache = {}
self.vector_db_store = None
@ -176,14 +151,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_db(self, vector_db: VectorDB) -> None:
collection = await maybe_await(
self.client.get_or_create_collection(
name=vector_db.identifier,
metadata={"vector_db": vector_db.model_dump_json()},
name=vector_db.identifier, metadata={"vector_db": vector_db.model_dump_json()}
)
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
@ -198,12 +169,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
@ -211,10 +177,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)

View file

@ -22,7 +22,6 @@ class ChromaVectorIOConfig(BaseModel):
return {
"url": url,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="chroma_remote_registry.db",
__distro_dir__=__distro_dir__, db_name="chroma_remote_registry.db"
),
}

View file

@ -4,28 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import MilvusVectorIOConfig
async def get_adapter_impl(
config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
):
async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .milvus import MilvusVectorIOAdapter
vector_stores_config = None
if run_config and run_config.vector_stores:
vector_stores_config = run_config.vector_stores
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = MilvusVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
vector_stores_config,
)
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -29,7 +29,6 @@ class MilvusVectorIOConfig(BaseModel):
"uri": "${env.MILVUS_ENDPOINT}",
"token": "${env.MILVUS_TOKEN}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="milvus_remote_registry.db",
__distro_dir__=__distro_dir__, db_name="milvus_remote_registry.db"
),
}

View file

@ -14,14 +14,8 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
@ -75,46 +69,23 @@ class MilvusIndex(EmbeddingIndex):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
schema.add_field(
field_name="chunk_id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=100,
)
schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True, # Enable text analysis for BM25
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=len(embeddings[0]),
)
schema.add_field(
field_name="chunk_content",
datatype=DataType.JSON,
)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=len(embeddings[0]))
schema.add_field(field_name="chunk_content", datatype=DataType.JSON)
# Add sparse vector field for BM25 (required by the function)
schema.add_field(
field_name="sparse",
datatype=DataType.SPARSE_FLOAT_VECTOR,
)
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
# Create indexes
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="FLAT",
metric_type="COSINE",
)
index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
# Add index for sparse field (required by BM25 function)
index_params.add_index(
field_name="sparse",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
index_params.add_index(field_name="sparse", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25")
# Add BM25 function for full-text search
bm25_function = Function(
@ -145,11 +116,7 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
await asyncio.to_thread(
self.client.insert,
self.collection_name,
data=data,
)
await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
except Exception as e:
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
@ -168,12 +135,7 @@ class MilvusIndex(EmbeddingIndex):
scores = [res["distance"] for res in search_res[0]]
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Perform BM25-based keyword search using Milvus's built-in full-text search.
"""
@ -211,12 +173,7 @@ class MilvusIndex(EmbeddingIndex):
# Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold)
async def _fallback_keyword_search(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def _fallback_keyword_search(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Fallback to simple text search when BM25 search is not available.
"""
@ -309,17 +266,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self,
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
inference_api: Inference,
models_api: Models | None,
files_api: Files | None,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.cache = {}
self.client = None
self.inference_api = inference_api
self.models_api = models_api
self.vector_stores_config = vector_stores_config
self.vector_db_store = None
self.metadata_collection_name = "openai_vector_stores_metadata"
@ -358,10 +311,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_db(self, vector_db: VectorDB) -> None:
if isinstance(self.config, RemoteMilvusVectorIOConfig):
consistency_level = self.config.consistency_level
else:
@ -398,12 +348,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
@ -411,10 +356,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:

View file

@ -4,26 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import PGVectorVectorIOConfig
async def get_adapter_impl(
config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
):
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorIOAdapter
vector_stores_config = None
if run_config and run_config.vector_stores:
vector_stores_config = run_config.vector_stores
impl = PGVectorVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
vector_stores_config,
)
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,10 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@ -42,7 +39,6 @@ class PGVectorVectorIOConfig(BaseModel):
"user": user,
"password": password,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="pgvector_registry.db",
__distro_dir__=__distro_dir__, db_name="pgvector_registry.db"
),
}

View file

@ -16,27 +16,15 @@ from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
from .config import PGVectorVectorIOConfig
@ -206,12 +194,7 @@ class PGVectorIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
@ -318,7 +301,7 @@ class PGVectorIndex(EmbeddingIndex):
"""Remove a chunk from the PostgreSQL table."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids))
def get_pgvector_search_function(self) -> str:
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
@ -342,18 +325,11 @@ class PGVectorIndex(EmbeddingIndex):
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
self,
config: PGVectorVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None = None,
vector_stores_config: VectorStoresConfig | None = None,
self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.vector_stores_config = vector_stores_config
self.conn = None
self.cache = {}
self.vector_db_store = None
@ -410,11 +386,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
)
await pgvector_index.initialize()
index = VectorDBWithIndex(
vector_db,
index=pgvector_index,
inference_api=self.inference_api,
)
index = VectorDBWithIndex(vector_db, index=pgvector_index, inference_api=self.inference_api)
self.cache[vector_db.identifier] = index
async def unregister_vector_db(self, vector_db_id: str) -> None:
@ -427,20 +399,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
assert self.kvstore is not None
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
return await index.query_chunks(query, params)

View file

@ -4,27 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import QdrantVectorIOConfig
async def get_adapter_impl(
config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
):
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorIOAdapter
vector_stores_config = None
if run_config and run_config.vector_stores:
vector_stores_config = run_config.vector_stores
impl = QdrantVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
vector_stores_config,
)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,10 +8,7 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@ -34,7 +31,6 @@ class QdrantVectorIOConfig(BaseModel):
return {
"api_key": "${env.QDRANT_API_KEY:=}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="qdrant_registry.db",
__distro_dir__=__distro_dir__, db_name="qdrant_registry.db"
),
}

View file

@ -16,7 +16,6 @@ from qdrant_client.models import PointStruct
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -25,17 +24,12 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreFileObject,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
@ -100,8 +94,7 @@ class QdrantIndex(EmbeddingIndex):
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
try:
await self.client.delete(
collection_name=self.collection_name,
points_selector=models.PointIdsList(points=chunk_ids),
collection_name=self.collection_name, points_selector=models.PointIdsList(points=chunk_ids)
)
except Exception as e:
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
@ -134,12 +127,7 @@ class QdrantIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Qdrant")
async def query_hybrid(
@ -162,17 +150,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self,
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None = None,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.client: AsyncQdrantClient = None
self.cache = {}
self.inference_api = inference_api
self.models_api = models_api
self.vector_stores_config = vector_stores_config
self.vector_db_store = None
self._qdrant_lock = asyncio.Lock()
@ -187,11 +171,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
QdrantIndex(self.client, vector_db.identifier),
self.inference_api,
)
index = VectorDBWithIndex(vector_db, QdrantIndex(self.client, vector_db.identifier), self.inference_api)
self.cache[vector_db.identifier] = index
self.openai_vector_stores = await self._load_openai_vector_stores()
@ -200,18 +180,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_db(self, vector_db: VectorDB) -> None:
assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(self.client, vector_db.identifier),
inference_api=self.inference_api,
vector_db=vector_db, index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api
)
self.cache[vector_db.identifier] = index
@ -243,12 +218,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache[vector_db_id] = index
return index
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
@ -256,10 +226,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:

View file

@ -4,27 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import WeaviateVectorIOConfig
async def get_adapter_impl(
config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
):
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .weaviate import WeaviateVectorIOAdapter
vector_stores_config = None
if run_config and run_config.vector_stores:
vector_stores_config = run_config.vector_stores
impl = WeaviateVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
vector_stores_config,
)
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,10 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@ -22,16 +19,11 @@ class WeaviateVectorIOConfig(BaseModel):
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
@classmethod
def sample_run_config(
cls,
__distro_dir__: str,
**kwargs: Any,
) -> dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"weaviate_api_key": None,
"weaviate_cluster_url": "${env.WEAVIATE_CLUSTER_URL:=localhost:8080}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="weaviate_registry.db",
__distro_dir__=__distro_dir__, db_name="weaviate_registry.db"
),
}

View file

@ -16,18 +16,14 @@ from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
ChunkForDeletion,
@ -49,12 +45,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class WeaviateIndex(EmbeddingIndex):
def __init__(
self,
client: weaviate.WeaviateClient,
collection_name: str,
kvstore: KVStore | None = None,
):
def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None):
self.client = client
self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
self.kvstore = kvstore
@ -109,9 +100,7 @@ class WeaviateIndex(EmbeddingIndex):
try:
results = collection.query.near_vector(
near_vector=embedding.tolist(),
limit=k,
return_metadata=wvc.query.MetadataQuery(distance=True),
near_vector=embedding.tolist(), limit=k, return_metadata=wvc.query.MetadataQuery(distance=True)
)
except Exception as e:
log.error(f"Weaviate client vector search failed: {e}")
@ -154,12 +143,7 @@ class WeaviateIndex(EmbeddingIndex):
collection = self.client.collections.get(sanitized_collection_name)
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs BM25-based keyword search using Weaviate's built-in full-text search.
Args:
@ -176,9 +160,7 @@ class WeaviateIndex(EmbeddingIndex):
# Perform BM25 keyword search on chunk_content field
try:
results = collection.query.bm25(
query=query_string,
limit=k,
return_metadata=wvc.query.MetadataQuery(score=True),
query=query_string, limit=k, return_metadata=wvc.query.MetadataQuery(score=True)
)
except Exception as e:
log.error(f"Weaviate client keyword search failed: {e}")
@ -275,25 +257,11 @@ class WeaviateIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
class WeaviateVectorIOAdapter(
OpenAIVectorStoreMixin,
VectorIO,
NeedsRequestProviderData,
VectorDBsProtocolPrivate,
):
def __init__(
self,
config: WeaviateVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorDBsProtocolPrivate):
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.vector_stores_config = vector_stores_config
self.client_cache = {}
self.cache = {}
self.vector_db_store = None
@ -304,10 +272,7 @@ class WeaviateVectorIOAdapter(
log.info("Using Weaviate locally in container")
host, port = self.config.weaviate_cluster_url.split(":")
key = "local_test"
client = weaviate.connect_to_local(
host=host,
port=port,
)
client = weaviate.connect_to_local(host=host, port=port)
else:
log.info("Using Weaviate remote cluster with URL")
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
@ -337,15 +302,9 @@ class WeaviateVectorIOAdapter(
for raw in stored:
vector_db = VectorDB.model_validate_json(raw)
client = self._get_client()
idx = WeaviateIndex(
client=client,
collection_name=vector_db.identifier,
kvstore=self.kvstore,
)
idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db=vector_db,
index=idx,
inference_api=self.inference_api,
vector_db=vector_db, index=idx, inference_api=self.inference_api
)
# Load OpenAI vector stores metadata into cache
@ -357,10 +316,7 @@ class WeaviateVectorIOAdapter(
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_db(self, vector_db: VectorDB) -> None:
client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
# Create collection if it doesn't exist
@ -369,17 +325,12 @@ class WeaviateVectorIOAdapter(
name=sanitized_collection_name,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
wvc.config.Property(name="chunk_content", data_type=wvc.config.DataType.TEXT),
],
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db,
WeaviateIndex(client=client, collection_name=sanitized_collection_name),
self.inference_api,
vector_db, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
)
async def unregister_vector_db(self, vector_db_id: str) -> None:
@ -415,12 +366,7 @@ class WeaviateVectorIOAdapter(
self.cache[vector_db_id] = index
return index
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
@ -428,10 +374,7 @@ class WeaviateVectorIOAdapter(
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index: