# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import uuid
from typing import Any, Dict, List, Optional

from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct

from llama_stack.apis.memory_banks import *  # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate

from llama_stack.apis.memory import *  # noqa: F403

from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
from llama_stack.providers.utils.memory.vector_store import (
    BankWithIndex,
    EmbeddingIndex,
)

log = logging.getLogger(__name__)

# Payload key under which the human-readable chunk id is stored next to the
# serialized chunk, since the point id itself is a derived UUID.
CHUNK_ID_KEY = "_chunk_id"


def convert_id(_id: str) -> str:
    """Map an arbitrary string to a deterministic UUID string.

    Qdrant accepts only UUID strings and unsigned integers as point IDs.
    Hashing the original ID with ``uuid5`` under a fixed namespace yields the
    same UUID for the same input every time, so upserting a chunk with the
    same original ID overwrites the existing point.

    Args:
        _id: The original identifier.

    Returns:
        The UUID (version 5, DNS namespace) derived from ``_id``, as a string.
    """
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))


class QdrantIndex(EmbeddingIndex):
    """Embedding index that stores and queries chunks in one Qdrant collection."""

    def __init__(self, client: AsyncQdrantClient, collection_name: str):
        """Bind the index to ``collection_name`` on the given async client."""
        self.client = client
        self.collection_name = collection_name

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        """Upsert ``chunks`` and their ``embeddings`` into the collection.

        The collection is created on first use, with cosine distance and a
        vector size taken from the first embedding.

        Args:
            chunks: Chunks to store; each is serialized into the point payload.
            embeddings: One embedding per chunk, in the same order.
        """
        assert len(chunks) == len(
            embeddings
        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"

        # Nothing to store. Returning here also avoids indexing
        # ``embeddings[0]`` below when the collection does not exist yet.
        if not chunks:
            return

        if not await self.client.collection_exists(self.collection_name):
            await self.client.create_collection(
                self.collection_name,
                vectors_config=models.VectorParams(
                    size=len(embeddings[0]), distance=models.Distance.COSINE
                ),
            )

        points = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            # The readable id is kept in the payload; the point id is its
            # deterministic UUID so re-inserting the same id overwrites.
            chunk_id = f"{chunk.document_id}:chunk-{i}"
            points.append(
                PointStruct(
                    id=convert_id(chunk_id),
                    vector=embedding,
                    payload={"chunk_content": chunk.model_dump()}
                    | {CHUNK_ID_KEY: chunk_id},
                )
            )

        await self.client.upsert(collection_name=self.collection_name, points=points)

    async def query(
        self, embedding: NDArray, k: int, score_threshold: float
    ) -> QueryDocumentsResponse:
        """Return up to ``k`` stored chunks matching ``embedding``.

        Args:
            embedding: Query vector.
            k: Maximum number of points to request.
            score_threshold: Passed through to Qdrant as ``score_threshold``.

        Returns:
            A ``QueryDocumentsResponse`` with the parsed chunks and their
            scores, in the order Qdrant returned them. Points whose payload
            cannot be parsed into a ``Chunk`` are logged and skipped.
        """
        results = (
            await self.client.query_points(
                collection_name=self.collection_name,
                query=embedding.tolist(),
                limit=k,
                with_payload=True,
                score_threshold=score_threshold,
            )
        ).points

        chunks, scores = [], []
        for point in results:
            assert isinstance(point, models.ScoredPoint)
            assert point.payload is not None

            try:
                chunk = Chunk(**point.payload["chunk_content"])
            except Exception:
                # Best effort: one malformed payload should not fail the query.
                log.exception("Failed to parse chunk")
                continue

            chunks.append(chunk)
            scores.append(point.score)

        return QueryDocumentsResponse(chunks=chunks, scores=scores)


class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
    """Memory provider that keeps each vector memory bank in a Qdrant collection."""

    def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
        """Create the async Qdrant client from the non-None fields of ``config``."""
        self.config = config
        self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
        # bank identifier -> BankWithIndex
        self.cache = {}
        self.inference_api = inference_api

    async def initialize(self) -> None:
        """No-op; the client is already constructed in ``__init__``."""
        pass

    async def shutdown(self) -> None:
        """Close the Qdrant client."""
        # AsyncQdrantClient.close is a coroutine function; without ``await``
        # the coroutine is never run and the client is never closed.
        await self.client.close()

    async def register_memory_bank(
        self,
        memory_bank: MemoryBank,
    ) -> None:
        """Create and cache an index for ``memory_bank``.

        Only vector banks are accepted. The bank identifier doubles as the
        Qdrant collection name.
        """
        assert (
            memory_bank.memory_bank_type == MemoryBankType.vector
        ), f"Only vector banks are supported {memory_bank.memory_bank_type}"

        index = BankWithIndex(
            bank=memory_bank,
            index=QdrantIndex(self.client, memory_bank.identifier),
            inference_api=self.inference_api,
        )

        self.cache[memory_bank.identifier] = index

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        """Return the index for ``bank_id``, building and caching it if needed.

        On a cache miss the bank is looked up in ``self.memory_bank_store``.

        Raises:
            ValueError: If the store has no bank with this id.
        """
        if bank_id in self.cache:
            return self.cache[bank_id]

        bank = await self.memory_bank_store.get_memory_bank(bank_id)
        if not bank:
            raise ValueError(f"Bank {bank_id} not found")

        index = BankWithIndex(
            bank=bank,
            index=QdrantIndex(client=self.client, collection_name=bank_id),
            inference_api=self.inference_api,
        )
        self.cache[bank_id] = index
        return index

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """Insert ``documents`` into the bank identified by ``bank_id``.

        ``ttl_seconds`` is accepted for interface compatibility but is not
        used by this method.

        Raises:
            ValueError: If the bank cannot be found.
        """
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Query the bank identified by ``bank_id``.

        ``query`` and ``params`` are forwarded unchanged to the bank's index.

        Raises:
            ValueError: If the bank cannot be found.
        """
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        return await index.query_documents(query, params)