diff --git a/llama_stack/providers/adapters/memory/chroma/chroma.py b/llama_stack/providers/adapters/memory/chroma/chroma.py
index 954acc09b..7c206d531 100644
--- a/llama_stack/providers/adapters/memory/chroma/chroma.py
+++ b/llama_stack/providers/adapters/memory/chroma/chroma.py
@@ -38,7 +38,9 @@ class ChromaIndex(EmbeddingIndex):
             ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
         )

-    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
+    async def query(
+        self, embedding: NDArray, k: int, score_threshold: float
+    ) -> QueryDocumentsResponse:
         results = await self.collection.query(
             query_embeddings=[embedding.tolist()],
             n_results=k,
diff --git a/llama_stack/providers/adapters/memory/pgvector/pgvector.py b/llama_stack/providers/adapters/memory/pgvector/pgvector.py
index 251402b46..87d6dbdab 100644
--- a/llama_stack/providers/adapters/memory/pgvector/pgvector.py
+++ b/llama_stack/providers/adapters/memory/pgvector/pgvector.py
@@ -91,7 +91,9 @@ class PGVectorIndex(EmbeddingIndex):
         )
         execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")

-    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
+    async def query(
+        self, embedding: NDArray, k: int, score_threshold: float
+    ) -> QueryDocumentsResponse:
         self.cursor.execute(
             f"""
         SELECT document, embedding <-> %s::vector AS distance
diff --git a/llama_stack/providers/adapters/memory/qdrant/qdrant.py b/llama_stack/providers/adapters/memory/qdrant/qdrant.py
index 313292993..45a8024ac 100644
--- a/llama_stack/providers/adapters/memory/qdrant/qdrant.py
+++ b/llama_stack/providers/adapters/memory/qdrant/qdrant.py
@@ -68,13 +68,16 @@ class QdrantIndex(EmbeddingIndex):

         await self.client.upsert(collection_name=self.collection_name, points=points)

-    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
+    async def query(
+        self, embedding: NDArray, k: int, score_threshold: float
+    ) -> QueryDocumentsResponse:
         results = (
             await self.client.query_points(
                 collection_name=self.collection_name,
                 query=embedding.tolist(),
                 limit=k,
                 with_payload=True,
+                score_threshold=score_threshold,
             )
         ).points

diff --git a/llama_stack/providers/adapters/memory/weaviate/weaviate.py b/llama_stack/providers/adapters/memory/weaviate/weaviate.py
index 3580b95f8..16fa03679 100644
--- a/llama_stack/providers/adapters/memory/weaviate/weaviate.py
+++ b/llama_stack/providers/adapters/memory/weaviate/weaviate.py
@@ -50,7 +50,9 @@ class WeaviateIndex(EmbeddingIndex):
         # TODO: make this async friendly
         collection.data.insert_many(data_objects)

-    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
+    async def query(
+        self, embedding: NDArray, k: int, score_threshold: float
+    ) -> QueryDocumentsResponse:
         collection = self.client.collections.get(self.collection_name)

         results = collection.query.near_vector(
diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py
index a77516154..b26bf75a7 100644
--- a/llama_stack/providers/tests/memory/test_memory.py
+++ b/llama_stack/providers/tests/memory/test_memory.py
@@ -142,14 +142,13 @@ async def test_query_documents(memory_settings, sample_documents):
     assert_valid_response(response4)
     assert len(response4.chunks) <= 2

-    # Score threshold is not implemented in vector memory banks
     # Test case 5: Query with threshold on similarity score
-    # query5 = "quantum computing"  # Not directly related to any document
-    # params5 = {"score_threshold": 0.5}
-    # response5 = await memory_impl.query_documents("test_bank", query5, params5)
-    # assert_valid_response(response5)
-    # print("The scores are:", response5.scores)
-    # assert all(score >= 0.5 for score in response5.scores)
+    query5 = "quantum computing"  # Not directly related to any document
+    params5 = {"score_threshold": 0.2}
+    response5 = await memory_impl.query_documents("test_bank", query5, params5)
+    assert_valid_response(response5)
+    print("The scores are:", response5.scores)
+    assert all(score >= 0.2 for score in response5.scores)


 def assert_valid_response(response: QueryDocumentsResponse):
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index d0a7aed54..8e2a1550d 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -140,7 +140,9 @@ class EmbeddingIndex(ABC):
         raise NotImplementedError()

     @abstractmethod
-    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
+    async def query(
+        self, embedding: NDArray, k: int, score_threshold: float
+    ) -> QueryDocumentsResponse:
         raise NotImplementedError()


@@ -177,6 +179,7 @@ class BankWithIndex:
         if params is None:
             params = {}
         k = params.get("max_chunks", 3)
+        score_threshold = params.get("score_threshold", 0.0)

         def _process(c) -> str:
             if isinstance(c, str):
@@ -191,4 +194,4 @@ class BankWithIndex:

         model = get_embedding_model(self.bank.embedding_model)
         query_vector = model.encode([query_str])[0].astype(np.float32)
-        return await self.index.query(query_vector, k)
+        return await self.index.query(query_vector, k, score_threshold)
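Usage sketch (not part of the diff; assumes a memory_impl client and the registered "test_bank" bank used in the test above): since BankWithIndex.query_documents now reads score_threshold from the same params dict as max_chunks, a caller can filter low-similarity chunks like this:

    params = {"max_chunks": 3, "score_threshold": 0.2}
    response = await memory_impl.query_documents("test_bank", "quantum computing", params)
    # Providers that apply the threshold (Qdrant forwards it to query_points above)
    # should only return chunks whose score is >= 0.2; others default to 0.0, i.e. no filtering.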