diff --git a/.gitignore b/.gitignore index 2465d2d4e..7b8321844 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ xcuserdata/ Package.resolved *.pte *.ipynb_checkpoints* +.venv/ diff --git a/llama_stack/providers/adapters/memory/qdrant/__init__.py b/llama_stack/providers/adapters/memory/qdrant/__init__.py new file mode 100644 index 000000000..9f54babad --- /dev/null +++ b/llama_stack/providers/adapters/memory/qdrant/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import QdrantConfig + + +async def get_adapter_impl(config: QdrantConfig, _deps): + from .qdrant import QdrantVectorMemoryAdapter + + impl = QdrantVectorMemoryAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/adapters/memory/qdrant/config.py b/llama_stack/providers/adapters/memory/qdrant/config.py new file mode 100644 index 000000000..a6a5a6ff6 --- /dev/null +++ b/llama_stack/providers/adapters/memory/qdrant/config.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel + + +@json_schema_type +class QdrantConfig(BaseModel): + location: Optional[str] = None + url: Optional[str] = None + port: Optional[int] = 6333 + grpc_port: int = 6334 + prefer_grpc: bool = False + https: Optional[bool] = None + api_key: Optional[str] = None + prefix: Optional[str] = None + timeout: Optional[int] = None + host: Optional[str] = None + path: Optional[str] = None diff --git a/llama_stack/providers/adapters/memory/qdrant/qdrant.py b/llama_stack/providers/adapters/memory/qdrant/qdrant.py new file mode 100644 index 000000000..a7aa0b15e --- /dev/null +++ b/llama_stack/providers/adapters/memory/qdrant/qdrant.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import traceback +import uuid +from typing import List + +from numpy.typing import NDArray +from qdrant_client import AsyncQdrantClient, models +from qdrant_client.models import PointStruct + +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.distribution.datatypes import RoutableProvider + +from llama_stack.providers.adapters.memory.qdrant.config import QdrantConfig +from llama_stack.providers.utils.memory.vector_store import ( + BankWithIndex, + EmbeddingIndex, +) + +CHUNK_ID_KEY = "_chunk_id" +METADATA_COLLECTION_NAME = "metadata_store" + + +def convert_id(_id: str) -> str: + """ + Converts any string into a UUID string based on a seed. + + Qdrant accepts UUID strings and unsigned integers as point ID. + We use a seed to convert each string into a UUID string deterministically. + This allows us to overwrite the same point with the original ID. + """ + return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id)) + + +class QdrantIndex(EmbeddingIndex): + def __init__(self, client: AsyncQdrantClient, bank: MemoryBank): + self.client = client + self.collection_name = bank.name + + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + assert len(chunks) == len( + embeddings + ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" + + if not await self.client.collection_exists(self.collection_name): + await self.client.create_collection( + self.collection_name, + vectors_config=models.VectorParams( + size=len(embeddings[0]), distance=models.Distance.COSINE + ), + ) + + points = [] + for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): + chunk_id = f"{chunk.document_id}:chunk-{i}" + points.append( + PointStruct( + id=convert_id(chunk_id), + vector=embedding, + payload=chunk.model_dump() | {CHUNK_ID_KEY: chunk_id}, + ) + ) + + await self.client.upsert(collection_name=self.collection_name, points=points) + + async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: + results = ( + await self.client.query_points( + collection_name=self.collection_name, query=embedding.tolist(), limit=k + ) + ).points + + chunks, scores = [], [] + for point in results: + assert isinstance(point, models.ScoredPoint) + assert point.payload is not None + + try: + point.payload.pop(CHUNK_ID_KEY, None) + chunk = Chunk(**point.payload) + except Exception: + traceback.print_exc() + continue + + chunks.append(chunk) + scores.append(point.score) + + return QueryDocumentsResponse(chunks=chunks, scores=scores) + + +class QdrantVectorMemoryAdapter(Memory, RoutableProvider): + def __init__(self, config: QdrantConfig) -> None: + self.config = config + self.client = None + self.cache = {} + + async def initialize(self) -> None: + try: + self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) + + if not await self.client.collection_exists(METADATA_COLLECTION_NAME): + await self.client.create_collection( + METADATA_COLLECTION_NAME, vectors_config={} + ) + except Exception as e: + import traceback + + traceback.print_exc() + raise RuntimeError(f"Could not connect to Qdrant: {e}") from e + + async def shutdown(self) -> None: + pass + + async def validate_routing_keys(self, routing_keys: List[str]) -> None: + print(f"[qdrant] Registering memory bank routing keys: {routing_keys}") + pass + + async def create_memory_bank( + self, + name: str, + config: MemoryBankConfig, + url: Optional[URL] = None, + ) -> MemoryBank: + bank_id = str(uuid.uuid4()) + bank = MemoryBank( + bank_id=bank_id, + name=name, + config=config, + url=url, + ) + + await self.client.upsert( + METADATA_COLLECTION_NAME, + points=[ + PointStruct( + id=convert_id(bank_id), vector={}, payload=bank.model_dump() + ) + ], + ) + + index = BankWithIndex( + bank=bank, + index=QdrantIndex(self.client, bank), + ) + self.cache[bank_id] = index + return bank + + async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: + bank_index = await self._get_and_cache_bank_index(bank_id) + if bank_index is None: + return None + return bank_index.bank + + async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: + if bank_id in self.cache: + return self.cache[bank_id] + + bank_point = await self.client.retrieve( + METADATA_COLLECTION_NAME, ids=[convert_id(bank_id)], with_payload=True + ) + + if not bank_point: + return None + + bank = MemoryBank(**bank_point[0].payload) + index = BankWithIndex( + bank=bank, + index=QdrantIndex(self.client, bank), + ) + self.cache[bank_id] = index + return index + + async def insert_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + index = await self._get_and_cache_bank_index(bank_id) + if not index: + raise ValueError(f"Bank {bank_id} not found") + + await index.insert_documents(documents) + + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + index = await self._get_and_cache_bank_index(bank_id) + if not index: + raise ValueError(f"Bank {bank_id} not found") + + return await index.query_documents(query, params)