diff --git a/.gitignore b/.gitignore index f6ef5d9ca..a6c204131 100644 --- a/.gitignore +++ b/.gitignore @@ -13,5 +13,6 @@ xcuserdata/ Package.resolved *.pte *.ipynb_checkpoints* +.venv/ .idea _build diff --git a/llama_stack/providers/adapters/memory/chroma/chroma.py b/llama_stack/providers/adapters/memory/chroma/chroma.py index 954acc09b..7c206d531 100644 --- a/llama_stack/providers/adapters/memory/chroma/chroma.py +++ b/llama_stack/providers/adapters/memory/chroma/chroma.py @@ -38,7 +38,9 @@ class ChromaIndex(EmbeddingIndex): ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)], ) - async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: + async def query( + self, embedding: NDArray, k: int, score_threshold: float + ) -> QueryDocumentsResponse: results = await self.collection.query( query_embeddings=[embedding.tolist()], n_results=k, diff --git a/llama_stack/providers/adapters/memory/pgvector/pgvector.py b/llama_stack/providers/adapters/memory/pgvector/pgvector.py index 251402b46..87d6dbdab 100644 --- a/llama_stack/providers/adapters/memory/pgvector/pgvector.py +++ b/llama_stack/providers/adapters/memory/pgvector/pgvector.py @@ -91,7 +91,9 @@ class PGVectorIndex(EmbeddingIndex): ) execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)") - async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: + async def query( + self, embedding: NDArray, k: int, score_threshold: float + ) -> QueryDocumentsResponse: self.cursor.execute( f""" SELECT document, embedding <-> %s::vector AS distance diff --git a/llama_stack/providers/adapters/memory/qdrant/__init__.py b/llama_stack/providers/adapters/memory/qdrant/__init__.py new file mode 100644 index 000000000..9f54babad --- /dev/null +++ b/llama_stack/providers/adapters/memory/qdrant/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import QdrantConfig + + +async def get_adapter_impl(config: QdrantConfig, _deps): + from .qdrant import QdrantVectorMemoryAdapter + + impl = QdrantVectorMemoryAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/adapters/memory/qdrant/config.py b/llama_stack/providers/adapters/memory/qdrant/config.py new file mode 100644 index 000000000..a6a5a6ff6 --- /dev/null +++ b/llama_stack/providers/adapters/memory/qdrant/config.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel + + +@json_schema_type +class QdrantConfig(BaseModel): + location: Optional[str] = None + url: Optional[str] = None + port: Optional[int] = 6333 + grpc_port: int = 6334 + prefer_grpc: bool = False + https: Optional[bool] = None + api_key: Optional[str] = None + prefix: Optional[str] = None + timeout: Optional[int] = None + host: Optional[str] = None + path: Optional[str] = None diff --git a/llama_stack/providers/adapters/memory/qdrant/qdrant.py b/llama_stack/providers/adapters/memory/qdrant/qdrant.py new file mode 100644 index 000000000..45a8024ac --- /dev/null +++ b/llama_stack/providers/adapters/memory/qdrant/qdrant.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import traceback +import uuid +from typing import Any, Dict, List + +from numpy.typing import NDArray +from qdrant_client import AsyncQdrantClient, models +from qdrant_client.models import PointStruct + +from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate + +from llama_stack.apis.memory import * # noqa: F403 + +from llama_stack.providers.adapters.memory.qdrant.config import QdrantConfig +from llama_stack.providers.utils.memory.vector_store import ( + BankWithIndex, + EmbeddingIndex, +) + +CHUNK_ID_KEY = "_chunk_id" + + +def convert_id(_id: str) -> str: + """ + Converts any string into a UUID string based on a seed. + + Qdrant accepts UUID strings and unsigned integers as point ID. + We use a seed to convert each string into a UUID string deterministically. + This allows us to overwrite the same point with the original ID. + """ + return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id)) + + +class QdrantIndex(EmbeddingIndex): + def __init__(self, client: AsyncQdrantClient, collection_name: str): + self.client = client + self.collection_name = collection_name + + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + assert len(chunks) == len( + embeddings + ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" + + if not await self.client.collection_exists(self.collection_name): + await self.client.create_collection( + self.collection_name, + vectors_config=models.VectorParams( + size=len(embeddings[0]), distance=models.Distance.COSINE + ), + ) + + points = [] + for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): + chunk_id = f"{chunk.document_id}:chunk-{i}" + points.append( + PointStruct( + id=convert_id(chunk_id), + vector=embedding, + payload={"chunk_content": chunk.model_dump()} + | {CHUNK_ID_KEY: chunk_id}, + ) + ) + + await self.client.upsert(collection_name=self.collection_name, points=points) + + async def query( + self, embedding: NDArray, k: int, score_threshold: float + ) -> 
QueryDocumentsResponse: + results = ( + await self.client.query_points( + collection_name=self.collection_name, + query=embedding.tolist(), + limit=k, + with_payload=True, + score_threshold=score_threshold, + ) + ).points + + chunks, scores = [], [] + for point in results: + assert isinstance(point, models.ScoredPoint) + assert point.payload is not None + + try: + chunk = Chunk(**point.payload["chunk_content"]) + except Exception: + traceback.print_exc() + continue + + chunks.append(chunk) + scores.append(point.score) + + return QueryDocumentsResponse(chunks=chunks, scores=scores) + + +class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): + def __init__(self, config: QdrantConfig) -> None: + self.config = config + self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) + self.cache = {} + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + await self.client.close() + + async def register_memory_bank( + self, + memory_bank: MemoryBankDef, + ) -> None: + assert ( + memory_bank.type == MemoryBankType.vector.value + ), f"Only vector banks are supported {memory_bank.type}" + + index = BankWithIndex( + bank=memory_bank, + index=QdrantIndex(self.client, memory_bank.identifier), + ) + + self.cache[memory_bank.identifier] = index + + async def list_memory_banks(self) -> List[MemoryBankDef]: + # Qdrant doesn't have collection level metadata to store the bank properties + # So we only return from the cache value + return [i.bank for i in self.cache.values()] + + async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: + if bank_id in self.cache: + return self.cache[bank_id] + + bank = await self.memory_bank_store.get_memory_bank(bank_id) + if not bank: + raise ValueError(f"Bank {bank_id} not found") + + index = BankWithIndex( + bank=bank, + index=QdrantIndex(client=self.client, collection_name=bank_id), + ) + self.cache[bank_id] = index + return index + + async def insert_documents( + 
self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + index = await self._get_and_cache_bank_index(bank_id) + if not index: + raise ValueError(f"Bank {bank_id} not found") + + await index.insert_documents(documents) + + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + index = await self._get_and_cache_bank_index(bank_id) + if not index: + raise ValueError(f"Bank {bank_id} not found") + + return await index.query_documents(query, params) diff --git a/llama_stack/providers/adapters/memory/weaviate/weaviate.py b/llama_stack/providers/adapters/memory/weaviate/weaviate.py index 3580b95f8..16fa03679 100644 --- a/llama_stack/providers/adapters/memory/weaviate/weaviate.py +++ b/llama_stack/providers/adapters/memory/weaviate/weaviate.py @@ -50,7 +50,9 @@ class WeaviateIndex(EmbeddingIndex): # TODO: make this async friendly collection.data.insert_many(data_objects) - async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: + async def query( + self, embedding: NDArray, k: int, score_threshold: float + ) -> QueryDocumentsResponse: collection = self.client.collections.get(self.collection_name) results = collection.query.near_vector( diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index a8d776c3f..a0fbf1636 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -75,4 +75,13 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.adapters.memory.sample.SampleConfig", ), ), + remote_provider_spec( + Api.memory, + AdapterSpec( + adapter_type="qdrant", + pip_packages=EMBEDDING_DEPS + ["qdrant-client"], + module="llama_stack.providers.adapters.memory.qdrant", + config_class="llama_stack.providers.adapters.memory.qdrant.QdrantConfig", + ), + ), ] diff --git 
a/llama_stack/providers/tests/memory/provider_config_example.yaml b/llama_stack/providers/tests/memory/provider_config_example.yaml index 5b5440f8d..13575a598 100644 --- a/llama_stack/providers/tests/memory/provider_config_example.yaml +++ b/llama_stack/providers/tests/memory/provider_config_example.yaml @@ -15,6 +15,11 @@ providers: - provider_id: test-weaviate provider_type: remote::weaviate config: {} + - provider_id: test-qdrant + provider_type: remote::qdrant + config: + host: localhost + port: 6333 # if a provider needs private keys from the client, they use the # "get_request_provider_data" function (see distribution/request_headers.py) # this is a place to provide such data. diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index d92feaba8..b26bf75a7 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -144,10 +144,11 @@ async def test_query_documents(memory_settings, sample_documents): # Test case 5: Query with threshold on similarity score query5 = "quantum computing" # Not directly related to any document - params5 = {"score_threshold": 0.5} + params5 = {"score_threshold": 0.2} response5 = await memory_impl.query_documents("test_bank", query5, params5) assert_valid_response(response5) - assert all(score >= 0.5 for score in response5.scores) + print("The scores are:", response5.scores) + assert all(score >= 0.2 for score in response5.scores) def assert_valid_response(response: QueryDocumentsResponse): diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index d0a7aed54..8e2a1550d 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -140,7 +140,9 @@ class EmbeddingIndex(ABC): raise NotImplementedError() @abstractmethod - async def query(self, embedding: NDArray, k: int) -> 
QueryDocumentsResponse: + async def query( + self, embedding: NDArray, k: int, score_threshold: float + ) -> QueryDocumentsResponse: raise NotImplementedError() @@ -177,6 +179,7 @@ class BankWithIndex: if params is None: params = {} k = params.get("max_chunks", 3) + score_threshold = params.get("score_threshold", 0.0) def _process(c) -> str: if isinstance(c, str): @@ -191,4 +194,4 @@ class BankWithIndex: model = get_embedding_model(self.bank.embedding_model) query_vector = model.encode([query_str])[0].astype(np.float32) - return await self.index.query(query_vector, k) + return await self.index.query(query_vector, k, score_threshold)