commit e9609c00e6
Varsha, 2025-06-27 09:47:33 +01:00 (committed by GitHub)
17 changed files with 245 additions and 137 deletions

View file

@@ -11468,6 +11468,32 @@
"ttl_seconds": {
"type": "integer",
"description": "The time to live of the chunks."
},
"params": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Optional parameters for the insertion operation, such as distance_metric for vector databases."
}
},
"additionalProperties": false,

View file

@@ -8095,6 +8095,19 @@ components:
ttl_seconds:
type: integer
description: The time to live of the chunks.
params:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
Optional parameters for the insertion operation, such as distance_metric
for vector databases.
additionalProperties: false
required:
- vector_db_id
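
For reference, a request body exercising the new `params` field could look like the sketch below. This is a hypothetical payload based only on the schema above; the chunk contents and the `distance_metric` value are illustrative.

# Hypothetical insert_chunks request payload matching the schema above.
# `params` is the new optional field; the other fields follow the existing schema.
payload = {
    "vector_db_id": "my-vector-db",
    "chunks": [
        {
            "content": "Llama Stack unifies inference, safety, and vector IO APIs.",
            "metadata": {"document_id": "doc-1"},
        }
    ],
    "ttl_seconds": 3600,
    "params": {"distance_metric": "COSINE"},
}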

View file

@@ -306,6 +306,7 @@ class VectorIO(Protocol):
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
params: dict[str, Any] | None = None,
) -> None:
"""Insert chunks into a vector database.
@@ -315,6 +316,7 @@ class VectorIO(Protocol):
If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
If `embedding` is not provided, it will be computed later.
:param ttl_seconds: The time to live of the chunks.
:param params: Optional parameters for the insertion operation, such as distance_metric for vector databases.
"""
...
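
As a usage sketch of the extended protocol (assuming `vector_io` is any implementation of `VectorIO`; the values shown are illustrative):

from llama_stack.apis.vector_io import Chunk, VectorIO


async def ingest_example(vector_io: VectorIO) -> None:
    # Provider-specific options travel through the new `params` argument.
    chunks = [
        Chunk(
            content="The quick brown fox jumps over the lazy dog.",
            metadata={"document_id": "doc-1"},
        )
    ]
    await vector_io.insert_chunks(
        vector_db_id="my-vector-db",
        chunks=chunks,
        ttl_seconds=3600,
        params={"distance_metric": "COSINE"},  # e.g. honored by the Qdrant provider below
    )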

View file

@@ -97,11 +97,14 @@ class VectorIORouter(VectorIO):
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
params: dict[str, Any] | None = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
vector_db_id, chunks, ttl_seconds, params
)
async def query_chunks(
self,

View file

@@ -96,7 +96,7 @@ class FaissIndex(EmbeddingIndex):
await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None):
# Add dimension check
embedding_dim = embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
if embedding_dim != self.index.d:
@@ -234,6 +234,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
params: dict[str, Any] | None = None,
) -> None:
index = self.cache.get(vector_db_id)
if index is None:

View file

@@ -4,14 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import Api, ProviderSpec
from typing import Any
from llama_stack.providers.datatypes import Api
from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
impl = QdrantVectorIOAdapter(config, deps[Api.inference])
assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
files_api = deps.get(Api.files)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
await impl.initialize()
return impl

View file

@@ -9,15 +9,24 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class QdrantVectorIOConfig(BaseModel):
path: str
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="qdrant_store.db",
),
}
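
For illustration, constructing the inline Qdrant config in code (mirroring the unit-test fixture further down in this diff; the paths are illustrative) might look like:

import os

from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

# Sketch: the inline Qdrant provider now also takes a kvstore, used to persist
# vector DB registrations and OpenAI vector store metadata.
base_dir = os.path.expanduser("~/.llama/distributions/example")
config = QdrantVectorIOConfig(
    path=os.path.join(base_dir, "qdrant.db"),
    kvstore=SqliteKVStoreConfig(db_name=os.path.join(base_dir, "qdrant_store.db")),
)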

View file

@@ -178,7 +178,9 @@ class SQLiteVecIndex(EmbeddingIndex):
await asyncio.to_thread(_drop_tables)
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, batch_size: int = 500):
async def add_chunks(
self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None, batch_size: int = 500
):
"""
Add new chunks along with their embeddings using batch inserts.
For each chunk, we insert its JSON into the metadata table and then insert its
@@ -729,7 +731,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await asyncio.to_thread(_delete)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
params: dict[str, Any] | None = None,
) -> None:
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api

View file

@@ -55,7 +55,7 @@ class ChromaIndex(EmbeddingIndex):
self.client = client
self.collection = collection
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
@@ -178,6 +178,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
params: dict[str, Any] | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)

View file

@@ -53,7 +53,7 @@ class MilvusIndex(EmbeddingIndex):
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
@@ -183,6 +183,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
params: dict[str, Any] | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:

View file

@@ -88,7 +88,7 @@ class PGVectorIndex(EmbeddingIndex):
"""
)
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
@@ -215,6 +215,7 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
params: dict[str, Any] | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
await index.insert_chunks(chunks)

View file

@@ -8,6 +8,10 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.schema_utils import json_schema_type
@@ -23,9 +27,14 @@ class QdrantVectorIOConfig(BaseModel):
prefix: str | None = None
timeout: int | None = None
host: str | None = None
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"api_key": "${env.QDRANT_API_KEY}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="qdrant_store.db",
),
}

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import uuid
from typing import Any
@@ -12,25 +13,18 @@ from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreListFilesResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@@ -41,6 +35,13 @@ from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for OpenAI vector stores
OPENAI_VECTOR_STORES_PREFIX = "openai_vector_stores:"
OPENAI_VECTOR_STORES_FILES_PREFIX = "openai_vector_stores_files:"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = "openai_vector_stores_files_contents:"
VECTOR_DBS_PREFIX = "vector_dbs:"
def convert_id(_id: str) -> str:
"""
@@ -57,17 +58,38 @@ class QdrantIndex(EmbeddingIndex):
def __init__(self, client: AsyncQdrantClient, collection_name: str):
self.client = client
self.collection_name = collection_name
self._distance_metric = None # Will be set when collection is created
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
# Extract distance_metric from metadata if provided, default to COSINE
distance_metric = "COSINE" # Default
if metadata is not None and "distance_metric" in metadata:
distance_metric = metadata["distance_metric"]
if not await self.client.collection_exists(self.collection_name):
# Create collection with the specified distance metric
distance = getattr(models.Distance, distance_metric, models.Distance.COSINE)
self._distance_metric = distance_metric
await self.client.create_collection(
self.collection_name,
vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE),
vectors_config=models.VectorParams(size=len(embeddings[0]), distance=distance),
)
else:
# Collection already exists, warn if different distance metric was requested
if self._distance_metric is None:
# For now, assume COSINE as default since we can't easily extract it from collection info
self._distance_metric = "COSINE"
if self._distance_metric != distance_metric:
log.warning(
f"Collection {self.collection_name} was created with distance metric '{self._distance_metric}', "
f"but '{distance_metric}' was requested. Using existing distance metric."
)
points = []
for _i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
@@ -83,6 +105,7 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points)
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
# Distance metric is set at collection creation and cannot be changed
results = (
await self.client.query_points(
collection_name=self.collection_name,
@@ -132,21 +155,116 @@ class QdrantIndex(EmbeddingIndex):
await self.client.delete_collection(collection_name=self.collection_name)
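
A note on how the distance string is resolved in add_chunks above: the value from `metadata["distance_metric"]` is looked up on qdrant-client's `Distance` enum, and unknown names silently fall back to COSINE. A small sketch (member names such as COSINE, EUCLID, and DOT exist in current qdrant-client releases):

from qdrant_client import models

# A known name resolves to the corresponding Distance member.
assert getattr(models.Distance, "EUCLID", models.Distance.COSINE) is models.Distance.EUCLID
# Unknown names fall back to the COSINE default, mirroring the code above.
assert getattr(models.Distance, "not_a_metric", models.Distance.COSINE) is models.Distance.COSINE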
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
self, config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, inference_api: Api.inference
self,
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
inference_api: Api.inference,
files_api: Files | None,
) -> None:
self.config = config
self.client: AsyncQdrantClient = None
self.cache = {}
self.inference_api = inference_api
self.files_api = files_api
self.vector_db_store = None
self.kvstore: KVStore | None = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None:
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing vector DBs from kvstore
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
QdrantIndex(self.client, vector_db.identifier),
self.inference_api,
)
self.cache[vector_db.identifier] = index
# Load OpenAI vector stores as before
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
await self.client.close()
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from kvstore."""
assert self.kvstore is not None
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
stores = {}
for store_data in stored_openai_stores:
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=key, value=json.dumps(file_info))
content_key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=content_key, value=json.dumps(file_contents))
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
stored_data = await self.kvstore.get(key)
return json.loads(stored_data) if stored_data else {}
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
stored_data = await self.kvstore.get(key)
return json.loads(stored_data) if stored_data else []
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=key, value=json.dumps(file_info))
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.delete(key)
content_key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
await self.kvstore.delete(content_key)
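
For orientation, the kvstore key layout implied by the prefixes and helpers above is roughly the following; the store and file IDs are illustrative:

# openai_vector_stores:<store_id>                          -> JSON store_info
# openai_vector_stores_files:<store_id>:<file_id>          -> JSON file_info
# openai_vector_stores_files_contents:<store_id>:<file_id> -> JSON list of chunk dicts
# vector_dbs:<identifier>                                  -> serialized VectorDB (scanned at initialize)
store_id, file_id = "vs_123", "file_456"  # illustrative IDs
file_info_key = f"openai_vector_stores_files:{store_id}:{file_id}"
contents_key = f"openai_vector_stores_files_contents:{store_id}:{file_id}"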
async def register_vector_db(
self,
vector_db: VectorDB,
@@ -185,12 +303,18 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
params: dict[str, Any] | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
await index.insert_chunks(chunks)
# Extract distance_metric from params if provided
distance_metric = None
if params is not None:
distance_metric = params.get("distance_metric")
await index.insert_chunks(chunks, distance_metric=distance_metric)
async def query_chunks(
self,
@@ -203,108 +327,3 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> VectorStoreListFilesResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@@ -33,7 +33,7 @@ class WeaviateIndex(EmbeddingIndex):
self.client = client
self.collection_name = collection_name
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
@@ -188,6 +188,7 @@ class WeaviateVectorIOAdapter(
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
params: dict[str, Any] | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:

View file

@@ -214,7 +214,7 @@ def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int)
class EmbeddingIndex(ABC):
@abstractmethod
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None):
raise NotImplementedError()
@abstractmethod
@@ -251,6 +251,7 @@ class VectorDBWithIndex:
async def insert_chunks(
self,
chunks: list[Chunk],
distance_metric: str | None = None,
) -> None:
chunks_to_embed = []
for i, c in enumerate(chunks):
@@ -271,7 +272,13 @@ class VectorDBWithIndex:
c.embedding = embedding
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
await self.index.add_chunks(chunks, embeddings)
# Create metadata dict with distance_metric if provided
metadata = None
if distance_metric is not None:
metadata = {"distance_metric": distance_metric}
await self.index.add_chunks(chunks, embeddings, metadata=metadata)
async def query_chunks(
self,
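
Since `EmbeddingIndex.add_chunks` now takes an optional `metadata` dict, every provider index implements the widened signature, as seen throughout this diff. A hypothetical provider-side sketch of how such an index might consume it (the names here are illustrative and not part of the change):

from typing import Any

from numpy.typing import NDArray

from llama_stack.apis.vector_io import Chunk


class MyIndex:  # a real provider would subclass EmbeddingIndex
    async def add_chunks(
        self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None
    ):
        # Pick up provider-specific hints; COSINE mirrors the default used by Qdrant above.
        distance_metric = (metadata or {}).get("distance_metric", "COSINE")
        # ... create the backing collection with distance_metric, then store each
        # (chunk, embedding) pair ...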

View file

@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::qdrant"]:
return
pytest.skip("OpenAI vector stores are not supported by any provider")

View file

@@ -24,6 +24,7 @@ from llama_stack.providers.inline.vector_io.qdrant.config import (
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
QdrantVectorIOAdapter,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
@@ -37,7 +38,9 @@ from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
@pytest.fixture
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"))
kvstore_config = SqliteKVStoreConfig(db_name=os.path.join(tmp_path, "test_kvstore.db"))
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"), kvstore=kvstore_config)
@pytest.fixture(scope="session")
@@ -70,7 +73,7 @@ def mock_api_service(sample_embeddings):
@pytest_asyncio.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
adapter.vector_db_store = mock_vector_db_store
await adapter.initialize()
yield adapter