chore: review update with score_threshold

This commit is contained in:
Anush008 2024-10-23 00:54:47 +05:30
parent d1bef44e2b
commit bc27046f36
No known key found for this signature in database
6 changed files with 24 additions and 13 deletions

View file

@ -38,7 +38,9 @@ class ChromaIndex(EmbeddingIndex):
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
results = await self.collection.query(
query_embeddings=[embedding.tolist()],
n_results=k,

View file

@ -91,7 +91,9 @@ class PGVectorIndex(EmbeddingIndex):
)
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
self.cursor.execute(
f"""
SELECT document, embedding <-> %s::vector AS distance

View file

@ -68,13 +68,16 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
results = (
await self.client.query_points(
collection_name=self.collection_name,
query=embedding.tolist(),
limit=k,
with_payload=True,
score_threshold=score_threshold,
)
).points

View file

@ -50,7 +50,9 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector(