feat: Qdrant Vector index support (#221)

This PR adds support for Qdrant - https://qdrant.tech/ to be used as a vector memory.

I've unit-tested the methods to confirm that they work as intended.

To run Qdrant

```
docker run -p 6333:6333 qdrant/qdrant
```
This commit is contained in:
Anush 2024-10-23 01:20:19 +05:30 committed by GitHub
parent 668a495aba
commit 4c3d33e6f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 242 additions and 7 deletions

1
.gitignore vendored
View file

@ -13,5 +13,6 @@ xcuserdata/
Package.resolved
*.pte
*.ipynb_checkpoints*
.venv/
.idea
_build

View file

@ -38,7 +38,9 @@ class ChromaIndex(EmbeddingIndex):
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
results = await self.collection.query(
query_embeddings=[embedding.tolist()],
n_results=k,

View file

@ -91,7 +91,9 @@ class PGVectorIndex(EmbeddingIndex):
)
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
self.cursor.execute(
f"""
SELECT document, embedding <-> %s::vector AS distance

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import QdrantConfig
async def get_adapter_impl(config: QdrantConfig, _deps):
from .qdrant import QdrantVectorMemoryAdapter
impl = QdrantVectorMemoryAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class QdrantConfig(BaseModel):
location: Optional[str] = None
url: Optional[str] = None
port: Optional[int] = 6333
grpc_port: int = 6334
prefer_grpc: bool = False
https: Optional[bool] = None
api_key: Optional[str] = None
prefix: Optional[str] = None
timeout: Optional[int] = None
host: Optional[str] = None
path: Optional[str] = None

View file

@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import traceback
import uuid
from typing import Any, Dict, List
from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.adapters.memory.qdrant.config import QdrantConfig
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
CHUNK_ID_KEY = "_chunk_id"
def convert_id(_id: str) -> str:
"""
Converts any string into a UUID string based on a seed.
Qdrant accepts UUID strings and unsigned integers as point ID.
We use a seed to convert each string into a UUID string deterministically.
This allows us to overwrite the same point with the original ID.
"""
return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
class QdrantIndex(EmbeddingIndex):
def __init__(self, client: AsyncQdrantClient, collection_name: str):
self.client = client
self.collection_name = collection_name
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
if not await self.client.collection_exists(self.collection_name):
await self.client.create_collection(
self.collection_name,
vectors_config=models.VectorParams(
size=len(embeddings[0]), distance=models.Distance.COSINE
),
)
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
chunk_id = f"{chunk.document_id}:chunk-{i}"
points.append(
PointStruct(
id=convert_id(chunk_id),
vector=embedding,
payload={"chunk_content": chunk.model_dump()}
| {CHUNK_ID_KEY: chunk_id},
)
)
await self.client.upsert(collection_name=self.collection_name, points=points)
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
results = (
await self.client.query_points(
collection_name=self.collection_name,
query=embedding.tolist(),
limit=k,
with_payload=True,
score_threshold=score_threshold,
)
).points
chunks, scores = [], []
for point in results:
assert isinstance(point, models.ScoredPoint)
assert point.payload is not None
try:
chunk = Chunk(**point.payload["chunk_content"])
except Exception:
traceback.print_exc()
continue
chunks.append(chunk)
scores.append(point.score)
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: QdrantConfig) -> None:
self.config = config
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
self.cache = {}
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
self.client.close()
async def register_memory_bank(
self,
memory_bank: MemoryBankDef,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
index = BankWithIndex(
bank=memory_bank,
index=QdrantIndex(self.client, memory_bank.identifier),
)
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBankDef]:
# Qdrant doesn't have collection level metadata to store the bank properties
# So we only return from the cache value
return [i.bank for i in self.cache.values()]
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
index = BankWithIndex(
bank=bank,
index=QdrantIndex(client=self.client, collection_name=bank_id),
)
self.cache[bank_id] = index
return index
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)

View file

@ -50,7 +50,9 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector(

View file

@ -75,4 +75,13 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.memory.sample.SampleConfig",
),
),
remote_provider_spec(
Api.memory,
AdapterSpec(
adapter_type="qdrant",
pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
module="llama_stack.providers.adapters.memory.qdrant",
config_class="llama_stack.providers.adapters.memory.qdrant.QdrantConfig",
),
),
]

View file

@ -15,6 +15,11 @@ providers:
- provider_id: test-weaviate
provider_type: remote::weaviate
config: {}
- provider_id: test-qdrant
provider_type: remote::qdrant
config:
host: localhost
port: 6333
# if a provider needs private keys from the client, they use the
# "get_request_provider_data" function (see distribution/request_headers.py)
# this is a place to provide such data.

View file

@ -144,10 +144,11 @@ async def test_query_documents(memory_settings, sample_documents):
# Test case 5: Query with threshold on similarity score
query5 = "quantum computing" # Not directly related to any document
params5 = {"score_threshold": 0.5}
params5 = {"score_threshold": 0.2}
response5 = await memory_impl.query_documents("test_bank", query5, params5)
assert_valid_response(response5)
assert all(score >= 0.5 for score in response5.scores)
print("The scores are:", response5.scores)
assert all(score >= 0.2 for score in response5.scores)
def assert_valid_response(response: QueryDocumentsResponse):

View file

@ -140,7 +140,9 @@ class EmbeddingIndex(ABC):
raise NotImplementedError()
@abstractmethod
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
raise NotImplementedError()
@ -177,6 +179,7 @@ class BankWithIndex:
if params is None:
params = {}
k = params.get("max_chunks", 3)
score_threshold = params.get("score_threshold", 0.0)
def _process(c) -> str:
if isinstance(c, str):
@ -191,4 +194,4 @@ class BankWithIndex:
model = get_embedding_model(self.bank.embedding_model)
query_vector = model.encode([query_str])[0].astype(np.float32)
return await self.index.query(query_vector, k)
return await self.index.query(query_vector, k, score_threshold)