chore: review update with score_threshold

This commit is contained in:
Anush008 2024-10-23 00:54:47 +05:30
parent d1bef44e2b
commit bc27046f36
No known key found for this signature in database
6 changed files with 24 additions and 13 deletions

View file

@ -38,7 +38,9 @@ class ChromaIndex(EmbeddingIndex):
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)], ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
) )
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
results = await self.collection.query( results = await self.collection.query(
query_embeddings=[embedding.tolist()], query_embeddings=[embedding.tolist()],
n_results=k, n_results=k,

View file

@ -91,7 +91,9 @@ class PGVectorIndex(EmbeddingIndex):
) )
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)") execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
self.cursor.execute( self.cursor.execute(
f""" f"""
SELECT document, embedding <-> %s::vector AS distance SELECT document, embedding <-> %s::vector AS distance

View file

@ -68,13 +68,16 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points) await self.client.upsert(collection_name=self.collection_name, points=points)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
results = ( results = (
await self.client.query_points( await self.client.query_points(
collection_name=self.collection_name, collection_name=self.collection_name,
query=embedding.tolist(), query=embedding.tolist(),
limit=k, limit=k,
with_payload=True, with_payload=True,
score_threshold=score_threshold,
) )
).points ).points

View file

@ -50,7 +50,9 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly # TODO: make this async friendly
collection.data.insert_many(data_objects) collection.data.insert_many(data_objects)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
collection = self.client.collections.get(self.collection_name) collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector( results = collection.query.near_vector(

View file

@ -142,14 +142,13 @@ async def test_query_documents(memory_settings, sample_documents):
assert_valid_response(response4) assert_valid_response(response4)
assert len(response4.chunks) <= 2 assert len(response4.chunks) <= 2
# Score threshold is not implemented in vector memory banks
# Test case 5: Query with threshold on similarity score # Test case 5: Query with threshold on similarity score
# query5 = "quantum computing" # Not directly related to any document query5 = "quantum computing" # Not directly related to any document
# params5 = {"score_threshold": 0.5} params5 = {"score_threshold": 0.2}
# response5 = await memory_impl.query_documents("test_bank", query5, params5) response5 = await memory_impl.query_documents("test_bank", query5, params5)
# assert_valid_response(response5) assert_valid_response(response5)
# print("The scores are:", response5.scores) print("The scores are:", response5.scores)
# assert all(score >= 0.5 for score in response5.scores) assert all(score >= 0.2 for score in response5.scores)
def assert_valid_response(response: QueryDocumentsResponse): def assert_valid_response(response: QueryDocumentsResponse):

View file

@ -140,7 +140,9 @@ class EmbeddingIndex(ABC):
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
raise NotImplementedError() raise NotImplementedError()
@ -177,6 +179,7 @@ class BankWithIndex:
if params is None: if params is None:
params = {} params = {}
k = params.get("max_chunks", 3) k = params.get("max_chunks", 3)
score_threshold = params.get("score_threshold", 0.0)
def _process(c) -> str: def _process(c) -> str:
if isinstance(c, str): if isinstance(c, str):
@ -191,4 +194,4 @@ class BankWithIndex:
model = get_embedding_model(self.bank.embedding_model) model = get_embedding_model(self.bank.embedding_model)
query_vector = model.encode([query_str])[0].astype(np.float32) query_vector = model.encode([query_str])[0].astype(np.float32)
return await self.index.query(query_vector, k) return await self.index.query(query_vector, k, score_threshold)