feat: rebase and implement file API methods

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-06-25 16:59:29 -07:00
parent 918e68548f
commit dfafa5bbae
15 changed files with 212 additions and 214 deletions

View file

@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector"] vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::qdrant"]
python-version: ["3.12", "3.13"] python-version: ["3.12", "3.13"]
fail-fast: false # we want to run all tests regardless of failure fail-fast: false # we want to run all tests regardless of failure
@ -76,6 +76,29 @@ jobs:
PGPASSWORD=llamastack psql -h localhost -U llamastack -d llamastack \ PGPASSWORD=llamastack psql -h localhost -U llamastack -d llamastack \
-c "CREATE EXTENSION IF NOT EXISTS vector;" -c "CREATE EXTENSION IF NOT EXISTS vector;"
- name: Setup Qdrant
if: matrix.vector-io-provider == 'remote::qdrant'
run: |
docker run --rm -d --pull always \
--name qdrant \
-p 6333:6333 \
qdrant/qdrant
- name: Wait for Qdrant to be ready
if: matrix.vector-io-provider == 'remote::qdrant'
run: |
echo "Waiting for Qdrant to be ready..."
for i in {1..30}; do
if curl -s http://localhost:6333/collections | grep -q '"status":"ok"'; then
echo "Qdrant is ready!"
exit 0
fi
sleep 2
done
echo "Qdrant failed to start"
docker logs qdrant
exit 1
- name: Wait for ChromaDB to be ready - name: Wait for ChromaDB to be ready
if: matrix.vector-io-provider == 'remote::chromadb' if: matrix.vector-io-provider == 'remote::chromadb'
run: | run: |
@ -111,6 +134,8 @@ jobs:
PGVECTOR_DB: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }} PGVECTOR_DB: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
PGVECTOR_USER: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }} PGVECTOR_USER: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
PGVECTOR_PASSWORD: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }} PGVECTOR_PASSWORD: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
ENABLE_QDRANT: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'true' || '' }}
QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
run: | run: |
uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
tests/integration/vector_io \ tests/integration/vector_io \
@ -132,6 +157,11 @@ jobs:
run: | run: |
docker logs chromadb > chromadb.log docker logs chromadb > chromadb.log
- name: Write Qdrant logs to file
if: ${{ always() && matrix.vector-io-provider == 'remote::qdrant' }}
run: |
docker logs qdrant > qdrant.log
- name: Upload all logs to artifacts - name: Upload all logs to artifacts
if: ${{ always() }} if: ${{ always() }}
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2

View file

@ -51,11 +51,15 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
| Field | Type | Required | Default | Description | | Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------| |-------|------|----------|---------|-------------|
| `path` | `<class 'str'>` | No | PydanticUndefined | | | `path` | `<class 'str'>` | No | PydanticUndefined | |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
## Sample Configuration ## Sample Configuration
```yaml ```yaml
path: ${env.QDRANT_PATH:=~/.llama/~/.llama/dummy}/qdrant.db path: ${env.QDRANT_PATH:=~/.llama/~/.llama/dummy}/qdrant.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/qdrant_registry.db
``` ```

View file

@ -20,11 +20,14 @@ Please refer to the inline provider documentation.
| `prefix` | `str \| None` | No | | | | `prefix` | `str \| None` | No | | |
| `timeout` | `int \| None` | No | | | | `timeout` | `int \| None` | No | | |
| `host` | `str \| None` | No | | | | `host` | `str \| None` | No | | |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
## Sample Configuration ## Sample Configuration
```yaml ```yaml
api_key: ${env.QDRANT_API_KEY} kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/qdrant_registry.db
``` ```

View file

@ -5,20 +5,28 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Literal from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class QdrantVectorIOConfig(BaseModel): class QdrantVectorIOConfig(BaseModel):
path: str path: str
distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE" kvstore: KVStoreConfig
@classmethod @classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return { return {
"path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db", "path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="qdrant_registry.db",
),
} }

View file

@ -192,7 +192,9 @@ class SQLiteVecIndex(EmbeddingIndex):
await asyncio.to_thread(_drop_tables) await asyncio.to_thread(_drop_tables)
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, batch_size: int = 500): async def add_chunks(
self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None, batch_size: int = 500
):
""" """
Add new chunks along with their embeddings using batch inserts. Add new chunks along with their embeddings using batch inserts.
For each chunk, we insert its JSON into the metadata table and then insert its For each chunk, we insert its JSON into the metadata table and then insert its

View file

@ -459,6 +459,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
module="llama_stack.providers.inline.vector_io.qdrant", module="llama_stack.providers.inline.vector_io.qdrant",
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig", config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference], api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description=r""" description=r"""
[Qdrant](https://qdrant.tech/documentation/) is an inline and remote vector database provider for Llama Stack. It [Qdrant](https://qdrant.tech/documentation/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory. allows you to store and query vectors directly in memory.

View file

@ -12,6 +12,7 @@ from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]): async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorIOAdapter from .qdrant import QdrantVectorIOAdapter
impl = QdrantVectorIOAdapter(config, deps[Api.inference]) files_api = deps.get(Api.files)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -4,10 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Literal from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -23,10 +27,13 @@ class QdrantVectorIOConfig(BaseModel):
prefix: str | None = None prefix: str | None = None
timeout: int | None = None timeout: int | None = None
host: str | None = None host: str | None = None
distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE" kvstore: KVStoreConfig
@classmethod @classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return { return {
"api_key": "${env.QDRANT_API_KEY}", "kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="qdrant_registry.db",
),
} }

View file

@ -18,20 +18,11 @@ from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import ( from llama_stack.apis.vector_io import (
Chunk, Chunk,
QueryChunksResponse, QueryChunksResponse,
SearchRankingOptions,
VectorIO, VectorIO,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreListFilesResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
) )
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex, EmbeddingIndex,
@ -42,7 +33,10 @@ from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id" CHUNK_ID_KEY = "_chunk_id"
OPENAI_VECTOR_STORES_METADATA_COLLECTION = "openai_vector_stores_metadata"
# KV store prefixes for vector databases
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"
def convert_id(_id: str) -> str: def convert_id(_id: str) -> str:
@ -57,10 +51,14 @@ def convert_id(_id: str) -> str:
class QdrantIndex(EmbeddingIndex): class QdrantIndex(EmbeddingIndex):
def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"): def __init__(self, client: AsyncQdrantClient, collection_name: str):
self.client = client self.client = client
self.collection_name = collection_name self.collection_name = collection_name
self.distance_metric = distance_metric
async def initialize(self) -> None:
# Qdrant collections are created on-demand in add_chunks
# If the collection does not exist, it will be created in add_chunks.
pass
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), ( assert len(chunks) == len(embeddings), (
@ -68,12 +66,9 @@ class QdrantIndex(EmbeddingIndex):
) )
if not await self.client.collection_exists(self.collection_name): if not await self.client.collection_exists(self.collection_name):
# Get distance metric, defaulting to COSINE
distance = getattr(models.Distance, self.distance_metric, models.Distance.COSINE)
await self.client.create_collection( await self.client.create_collection(
self.collection_name, self.collection_name,
vectors_config=models.VectorParams(size=len(embeddings[0]), distance=distance), vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE),
) )
points = [] points = []
@ -152,87 +147,55 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.inference_api = inference_api self.inference_api = inference_api
self.files_api = files_api self.files_api = files_api
self.vector_db_store = None self.vector_db_store = None
self.kvstore: KVStore | None = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {} self.openai_vector_stores: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None: async def initialize(self) -> None:
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) # Close existing client if it exists
# Load existing OpenAI vector stores using the mixin method # Qdrant doesn't allow multiple clients to access the same storage path simultaneously
# This prevents "Storage folder is already accessed by another instance" errors during re-initialization
if self.client is not None:
await self.client.close()
self.client = None
# Create client config excluding kvstore (which is used for metadata storage, not Qdrant client connection)
client_config = self.config.model_dump(exclude_none=True, exclude={"kvstore"})
self.client = AsyncQdrantClient(**client_config)
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing vector DBs from kvstore
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
QdrantIndex(self.client, vector_db.identifier),
self.inference_api,
)
self.cache[vector_db.identifier] = index
# Load OpenAI vector stores as before
self.openai_vector_stores = await self._load_openai_vector_stores() self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None: async def shutdown(self) -> None:
await self.client.close() await self.client.close()
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to Qdrant collection metadata."""
# Store metadata in a special collection for vector store metadata
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
# Create metadata collection if it doesn't exist
if not await self.client.collection_exists(metadata_collection):
# Get distance metric from config, defaulting to COSINE for backward compatibility
distance_metric = getattr(self.config, "distance_metric", "COSINE")
distance = getattr(models.Distance, distance_metric, models.Distance.COSINE)
await self.client.create_collection(
collection_name=metadata_collection,
vectors_config=models.VectorParams(size=1, distance=distance),
)
# Store metadata as a point with dummy vector
await self.client.upsert(
collection_name=metadata_collection,
points=[
models.PointStruct(
id=convert_id(store_id),
vector=[0.0], # Dummy vector
payload={"metadata": store_info},
)
],
)
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from Qdrant."""
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
if not await self.client.collection_exists(metadata_collection):
return {}
# Get all points from metadata collection
points = await self.client.scroll(
collection_name=metadata_collection,
limit=1000, # Reasonable limit for metadata
with_payload=True,
)
stores = {}
for point in points[0]: # points[0] contains the actual points
if point.payload and "metadata" in point.payload:
store_info = point.payload["metadata"]
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in Qdrant."""
await self._save_openai_vector_store(store_id, store_info)
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from Qdrant."""
metadata_collection = OPENAI_VECTOR_STORES_METADATA_COLLECTION
if await self.client.collection_exists(metadata_collection):
await self.client.delete(
collection_name=metadata_collection, points_selector=models.PointIdsList(points=[convert_id(store_id)])
)
async def register_vector_db( async def register_vector_db(
self, self,
vector_db: VectorDB, vector_db: VectorDB,
) -> None: ) -> None:
# Save to kvstore
assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
# Store in cache
index = VectorDBWithIndex( index = VectorDBWithIndex(
vector_db=vector_db, vector_db=vector_db,
index=QdrantIndex(self.client, vector_db.identifier, self.config.distance_metric), index=QdrantIndex(self.client, vector_db.identifier),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
@ -243,19 +206,24 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await self.cache[vector_db_id].index.delete() await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id] del self.cache[vector_db_id]
# Remove from kvstore
assert self.kvstore is not None
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
if vector_db_id in self.cache: if vector_db_id in self.cache:
return self.cache[vector_db_id] return self.cache[vector_db_id]
if self.vector_db_store is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
vector_db = await self.vector_db_store.get_vector_db(vector_db_id) vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db: if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_db_id} not found")
index = VectorDBWithIndex( index = VectorDBWithIndex(
vector_db=vector_db, vector_db=vector_db,
index=QdrantIndex( index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
client=self.client, collection_name=vector_db.identifier, distance_metric=self.config.distance_metric
),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[vector_db_id] = index self.cache[vector_db_id] = index
@ -270,7 +238,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index: if not index:
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_db_id} not found")
await index.insert_chunks(chunks) await index.insert_chunks(chunks)
async def query_chunks( async def query_chunks(
@ -284,107 +251,3 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> VectorStoreListFilesResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import json import json
import logging import logging
import mimetypes import mimetypes
@ -259,8 +258,9 @@ class OpenAIVectorStoreMixin(ABC):
# Now that our vector store is created, attach any files that were provided # Now that our vector store is created, attach any files that were provided
file_ids = file_ids or [] file_ids = file_ids or []
tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids] # Process files sequentially to avoid concurrency issues with some vector store providers like qdrant.
await asyncio.gather(*tasks) for file_id in file_ids:
await self.openai_attach_file_to_vector_store(vector_db_id, file_id)
# Get the updated store info and return it # Get the updated store info and return it
store_info = self.openai_vector_stores[vector_db_id] store_info = self.openai_vector_stores[vector_db_id]

View file

@ -22,7 +22,14 @@ logger = logging.getLogger(__name__)
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers: for p in vector_io_providers:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::qdrant", "inline::chromadb"]: if p.provider_type in [
"inline::faiss",
"inline::sqlite-vec",
"inline::milvus",
"inline::qdrant",
"inline::chromadb",
"remote::qdrant",
]:
return return
pytest.skip("OpenAI vector stores are not supported by any provider") pytest.skip("OpenAI vector stores are not supported by any provider")
@ -31,7 +38,14 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models): def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models):
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers: for p in vector_io_providers:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::qdrant", "remote::pgvector"]: if p.provider_type in [
"inline::faiss",
"inline::sqlite-vec",
"inline::milvus",
"inline::qdrant",
"remote::pgvector",
"remote::qdrant",
]:
return return
pytest.skip("OpenAI vector stores are not supported by any provider") pytest.skip("OpenAI vector stores are not supported by any provider")

View file

@ -125,6 +125,8 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension): def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_io_provider_params_dict = { vector_io_provider_params_dict = {
"inline::milvus": {"score_threshold": -1.0}, "inline::milvus": {"score_threshold": -1.0},
"remote::qdrant": {"score_threshold": -1.0},
"inline::qdrant": {"score_threshold": -1.0},
} }
vector_db_id = "test_precomputed_embeddings_db" vector_db_id = "test_precomputed_embeddings_db"
client_with_empty_registry.vector_dbs.register( client_with_empty_registry.vector_dbs.register(
@ -168,6 +170,8 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
): ):
vector_io_provider_params_dict = { vector_io_provider_params_dict = {
"inline::milvus": {"score_threshold": 0.0}, "inline::milvus": {"score_threshold": 0.0},
"remote::qdrant": {"score_threshold": 0.0},
"inline::qdrant": {"score_threshold": 0.0},
} }
vector_db_id = "test_precomputed_embeddings_db" vector_db_id = "test_precomputed_embeddings_db"
client_with_empty_registry.vector_dbs.register( client_with_empty_registry.vector_dbs.register(

View file

@ -16,10 +16,12 @@ from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOC
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
EMBEDDING_DIMENSION = 384 EMBEDDING_DIMENSION = 384
COLLECTION_PREFIX = "test_collection" COLLECTION_PREFIX = "test_collection"
@ -94,7 +96,7 @@ def sample_embeddings_with_metadata(sample_chunks_with_metadata):
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata]) return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss"]) @pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "qdrant"])
def vector_provider(request): def vector_provider(request):
return request.param return request.param
@ -133,7 +135,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
await index.initialize() await index.initialize()
index.db_path = db_path index.db_path = db_path
yield index yield index
index.delete() await index.delete()
@pytest.fixture @pytest.fixture
@ -276,14 +278,66 @@ async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_d
await adapter.shutdown() await adapter.shutdown()
@pytest.fixture
def qdrant_vec_db_path(tmp_path_factory):
import uuid
db_path = str(tmp_path_factory.getbasetemp() / f"test_qdrant_{uuid.uuid4()}.db")
return db_path
@pytest.fixture
async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
import uuid
config = QdrantVectorIOConfig(
path=qdrant_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = QdrantVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
@pytest.fixture
async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
import uuid
from qdrant_client import AsyncQdrantClient
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex
client = AsyncQdrantClient(path=qdrant_vec_db_path)
collection_name = f"qdrant_test_collection_{uuid.uuid4()}"
index = QdrantIndex(client, collection_name)
yield index
await index.delete()
@pytest.fixture @pytest.fixture
def vector_io_adapter(vector_provider, request): def vector_io_adapter(vector_provider, request):
"""Returns the appropriate vector IO adapter based on the provider parameter.""" """Returns the appropriate vector IO adapter based on the provider parameter."""
vector_provider_dict = { vector_provider_dict = {
"milvus": "milvus_vec_adapter", "milvus": "milvus_vec_adapter",
"faiss": "faiss_vec_adapter", "faiss": "faiss_vec_adapter",
"sqlite_vec": "sqlite_vec_adapter", "qdrant": "qdrant_vec_adapter",
"chroma": "chroma_vec_adapter", "chroma": "chroma_vec_adapter",
"sqlite_vec": "sqlite_vec_adapter",
} }
return request.getfixturevalue(vector_provider_dict[vector_provider]) return request.getfixturevalue(vector_provider_dict[vector_provider])

View file

@ -23,6 +23,7 @@ from llama_stack.providers.inline.vector_io.qdrant.config import (
from llama_stack.providers.remote.vector_io.qdrant.qdrant import ( from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
QdrantVectorIOAdapter, QdrantVectorIOAdapter,
) )
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain # This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in # tests which are specific to this class. More general (API-level) tests should be placed in
@ -36,7 +37,9 @@ from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
@pytest.fixture @pytest.fixture
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig: def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db")) kvstore_config = SqliteKVStoreConfig(db_name=os.path.join(tmp_path, "test_kvstore.db"))
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"), kvstore=kvstore_config)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -50,6 +53,10 @@ def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db.embedding_model = "embedding_model" mock_vector_db.embedding_model = "embedding_model"
mock_vector_db.identifier = vector_db_id mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = 384 mock_vector_db.embedding_dimension = 384
# Mock model_dump_json to return a proper JSON string for kvstore persistence
mock_vector_db.model_dump_json.return_value = (
'{"identifier": "' + vector_db_id + '", "embedding_model": "embedding_model", "embedding_dimension": 384}'
)
return mock_vector_db return mock_vector_db

View file

@ -30,12 +30,12 @@ async def test_initialize_index(vector_index):
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings): async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
vector_index.delete() await vector_index.delete()
vector_index.initialize() await vector_index.initialize()
await vector_index.add_chunks(sample_chunks, sample_embeddings) await vector_index.add_chunks(sample_chunks, sample_embeddings)
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1) resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
assert resp.chunks[0].content == sample_chunks[0].content assert resp.chunks[0].content == sample_chunks[0].content
vector_index.delete() await vector_index.delete()
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension): async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):