llama-stack/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
Bill Murdock 3856927ee8
fix: Update Qdrant support post-refactor (#1022)
# What does this PR do?

I tried running the Qdrant provider and found some bugs. See #1021 for
details. @terrytangyuan wrote there:

> Please feel free to submit your changes in a PR. I fixed similar
issues for pgvector provider. This might be an issue introduced from a
refactoring.

So I am submitting this PR.

Closes #1021

## Test Plan

Here are the highlights of what I did to test this:

References:
- https://llama-stack.readthedocs.io/en/latest/getting_started/index.html
- https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/rag_with_vector_db.py
- https://github.com/meta-llama/llama-stack/blob/main/docs/zero_to_hero_guide/README.md#build-configure-and-run-llama-stack

Install and run Qdrant server:

```
podman pull qdrant/qdrant
mkdir qdrant-data
podman run -p 6333:6333 -v $(pwd)/qdrant-data:/qdrant/storage qdrant/qdrant
```
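
To sanity-check that the server came up before moving on, you can hit Qdrant's REST API and list the (initially empty) collections:

```
curl http://localhost:6333/collections
```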

Install and run Llama Stack from the venv-support PR (mainly because I
didn't want to install conda):

```
brew install cmake # Should just need this once

git clone https://github.com/meta-llama/llama-models.git
gh repo clone cdoern/llama-stack
cd llama-stack
gh pr checkout 1018 # This is the PR that introduces venv support for build/run.  Otherwise you have to use conda.  Eventually this will be part of main, hopefully.

uv sync --extra dev
uv pip install -e .
source .venv/bin/activate
uv pip install qdrant_client

LLAMA_STACK_DIR=$(pwd) LLAMA_MODELS_DIR=../llama-models llama stack build --template ollama --image-type venv
```

Then edit `llama_stack/templates/ollama/run.yaml`. In that file, under:
```
  vector_io:
```
add:
```
  - provider_id: qdrant
    provider_type: remote::qdrant
    config: {}
```

See
https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/vector_io/qdrant/config.py#L14
for the available config options (I didn't need any).
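
If you do need to point at a non-default Qdrant instance, the config block is passed straight into `AsyncQdrantClient(...)`, so an entry along these lines should work (illustrative only; check config.py for the exact field names):

```
  - provider_id: qdrant
    provider_type: remote::qdrant
    config:
      url: http://localhost:6333
```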

```
LLAMA_STACK_DIR=$(pwd) LLAMA_MODELS_DIR=../llama-models llama stack run ollama --image-type venv \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env SAFETY_MODEL=$SAFETY_MODEL \
   --env OLLAMA_URL=$OLLAMA_URL
```
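
For reference, that run command assumes these environment variables are already exported. The values below are just illustrative ones for a local Ollama setup; use whatever model, port, and URL you actually run:

```
export LLAMA_STACK_PORT=8321
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
export OLLAMA_URL="http://localhost:11434"
```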

Then I tested it out in a notebook.  Key highlights included:

```
# `client` here is the LlamaStackClient instance from the getting started guide
import uuid

qdrant_provider = None
for provider in client.providers.list():
    if provider.api == "vector_io" and provider.provider_id == "qdrant":
        qdrant_provider = provider
assert qdrant_provider is not None, "Qdrant is not a provider. You need to edit the run.yaml file you use in your `llama stack run` call"

vector_db_id = f"test-vector-db-{uuid.uuid4().hex}"
client.vector_dbs.register(
    vector_db_id=vector_db_id,
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
    provider_id=qdrant_provider.provider_id,
)
```
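
The rest of the round trip (inserting a document and querying it back through the new vector DB) is the same as in the rag_with_vector_db.py example linked above. Roughly, with the client types and argument names taken from that example (treat this as a sketch rather than verified API):

```
from llama_stack_client.types import Document

documents = [
    Document(
        document_id="doc-0",
        content="Qdrant is a vector database.",
        mime_type="text/plain",
        metadata={},
    ),
]
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=512,
)

result = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="What is Qdrant?",
)
print(result)
```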

Other than that, I just followed what was in
https://llama-stack.readthedocs.io/en/latest/getting_started/index.html

It would be good to have automated tests for this in the future, but
that would be a big undertaking.
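
As a starting point, a test that exercises `QdrantIndex` directly against qdrant_client's in-process `:memory:` mode would at least not need a running server. A rough sketch (assumes pytest-asyncio, the `Chunk(content=..., metadata=...)` constructor, and that the in-memory client behaves like the real one for these calls):

```
import numpy as np
import pytest
from qdrant_client import AsyncQdrantClient

from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex


@pytest.mark.asyncio
async def test_qdrant_index_roundtrip():
    # In-memory Qdrant; no server required.
    index = QdrantIndex(AsyncQdrantClient(":memory:"), "test-collection")

    chunks = [
        Chunk(content="apple pie recipe", metadata={"document_id": "doc-1"}),
        Chunk(content="gradient descent", metadata={"document_id": "doc-2"}),
    ]
    embeddings = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
    await index.add_chunks(chunks, embeddings)

    # Query close to the first embedding; expect doc-1 back.
    response = await index.query(np.array([0.9, 0.1], dtype=np.float32), k=1, score_threshold=0.0)
    assert response.chunks[0].metadata["document_id"] == "doc-1"
```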

Signed-off-by: Bill Murdock <bmurdock@redhat.com>
2025-02-10 18:08:33 -05:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import uuid
from typing import Any, Dict, List, Optional

from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct

from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
    EmbeddingIndex,
    VectorDBWithIndex,
)

from .config import QdrantConfig

log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"


def convert_id(_id: str) -> str:
    """
    Converts any string into a UUID string based on a seed.

    Qdrant accepts UUID strings and unsigned integers as point ID.
    We use a seed to convert each string into a UUID string deterministically.
    This allows us to overwrite the same point with the original ID.
    """
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))


class QdrantIndex(EmbeddingIndex):
    def __init__(self, client: AsyncQdrantClient, collection_name: str):
        self.client = client
        self.collection_name = collection_name

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        assert len(chunks) == len(embeddings), (
            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
        )

        if not await self.client.collection_exists(self.collection_name):
            await self.client.create_collection(
                self.collection_name,
                vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE),
            )

        points = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
            points.append(
                PointStruct(
                    id=convert_id(chunk_id),
                    vector=embedding,
                    payload={"chunk_content": chunk.model_dump()} | {CHUNK_ID_KEY: chunk_id},
                )
            )

        await self.client.upsert(collection_name=self.collection_name, points=points)

    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        results = (
            await self.client.query_points(
                collection_name=self.collection_name,
                query=embedding.tolist(),
                limit=k,
                with_payload=True,
                score_threshold=score_threshold,
            )
        ).points

        chunks, scores = [], []
        for point in results:
            assert isinstance(point, models.ScoredPoint)
            assert point.payload is not None

            try:
                chunk = Chunk(**point.payload["chunk_content"])
            except Exception:
                log.exception("Failed to parse chunk")
                continue

            chunks.append(chunk)
            scores.append(point.score)

        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def delete(self):
        await self.client.delete_collection(collection_name=self.collection_name)


class QdrantVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
    def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
        self.config = config
        self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
        self.cache = {}
        self.inference_api = inference_api

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        await self.client.close()

    async def register_vector_db(
        self,
        vector_db: VectorDB,
    ) -> None:
        index = VectorDBWithIndex(
            vector_db=vector_db,
            index=QdrantIndex(self.client, vector_db.identifier),
            inference_api=self.inference_api,
        )

        self.cache[vector_db.identifier] = index

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
        if not vector_db:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        index = VectorDBWithIndex(
            vector_db=vector_db,
            index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
            inference_api=self.inference_api,
        )
        self.cache[vector_db_id] = index
        return index

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: List[Chunk],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryChunksResponse:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        return await index.query_chunks(query, params)