Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 23:29:43 +00:00)
chore: review update with score_threshold
This commit is contained in:
Parent: d1bef44e2b
Commit: bc27046f36
6 changed files with 24 additions and 13 deletions
|
@ -38,7 +38,9 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
|
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, k: int, score_threshold: float
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
results = await self.collection.query(
|
results = await self.collection.query(
|
||||||
query_embeddings=[embedding.tolist()],
|
query_embeddings=[embedding.tolist()],
|
||||||
n_results=k,
|
n_results=k,
|
||||||
|
|
|
@ -91,7 +91,9 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
)
|
)
|
||||||
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
|
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, k: int, score_threshold: float
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
self.cursor.execute(
|
self.cursor.execute(
|
||||||
f"""
|
f"""
|
||||||
SELECT document, embedding <-> %s::vector AS distance
|
SELECT document, embedding <-> %s::vector AS distance
|
||||||
|
|
|
@ -68,13 +68,16 @@ class QdrantIndex(EmbeddingIndex):
|
||||||
|
|
||||||
await self.client.upsert(collection_name=self.collection_name, points=points)
|
await self.client.upsert(collection_name=self.collection_name, points=points)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, k: int, score_threshold: float
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
results = (
|
results = (
|
||||||
await self.client.query_points(
|
await self.client.query_points(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
query=embedding.tolist(),
|
query=embedding.tolist(),
|
||||||
limit=k,
|
limit=k,
|
||||||
with_payload=True,
|
with_payload=True,
|
||||||
|
score_threshold=score_threshold,
|
||||||
)
|
)
|
||||||
).points
|
).points
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,9 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
# TODO: make this async friendly
|
# TODO: make this async friendly
|
||||||
collection.data.insert_many(data_objects)
|
collection.data.insert_many(data_objects)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, k: int, score_threshold: float
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
collection = self.client.collections.get(self.collection_name)
|
collection = self.client.collections.get(self.collection_name)
|
||||||
|
|
||||||
results = collection.query.near_vector(
|
results = collection.query.near_vector(
|
||||||
|
|
|
@ -142,14 +142,13 @@ async def test_query_documents(memory_settings, sample_documents):
|
||||||
assert_valid_response(response4)
|
assert_valid_response(response4)
|
||||||
assert len(response4.chunks) <= 2
|
assert len(response4.chunks) <= 2
|
||||||
|
|
||||||
# Score threshold is not implemented in vector memory banks
|
|
||||||
# Test case 5: Query with threshold on similarity score
|
# Test case 5: Query with threshold on similarity score
|
||||||
# query5 = "quantum computing" # Not directly related to any document
|
query5 = "quantum computing" # Not directly related to any document
|
||||||
# params5 = {"score_threshold": 0.5}
|
params5 = {"score_threshold": 0.2}
|
||||||
# response5 = await memory_impl.query_documents("test_bank", query5, params5)
|
response5 = await memory_impl.query_documents("test_bank", query5, params5)
|
||||||
# assert_valid_response(response5)
|
assert_valid_response(response5)
|
||||||
# print("The scores are:", response5.scores)
|
print("The scores are:", response5.scores)
|
||||||
# assert all(score >= 0.5 for score in response5.scores)
|
assert all(score >= 0.2 for score in response5.scores)
|
||||||
|
|
||||||
|
|
||||||
def assert_valid_response(response: QueryDocumentsResponse):
|
def assert_valid_response(response: QueryDocumentsResponse):
|
||||||
|
|
|
@ -140,7 +140,9 @@ class EmbeddingIndex(ABC):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, k: int, score_threshold: float
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
@ -177,6 +179,7 @@ class BankWithIndex:
|
||||||
if params is None:
|
if params is None:
|
||||||
params = {}
|
params = {}
|
||||||
k = params.get("max_chunks", 3)
|
k = params.get("max_chunks", 3)
|
||||||
|
score_threshold = params.get("score_threshold", 0.0)
|
||||||
|
|
||||||
def _process(c) -> str:
|
def _process(c) -> str:
|
||||||
if isinstance(c, str):
|
if isinstance(c, str):
|
||||||
|
@ -191,4 +194,4 @@ class BankWithIndex:
|
||||||
|
|
||||||
model = get_embedding_model(self.bank.embedding_model)
|
model = get_embedding_model(self.bank.embedding_model)
|
||||||
query_vector = model.encode([query_str])[0].astype(np.float32)
|
query_vector = model.encode([query_str])[0].astype(np.float32)
|
||||||
return await self.index.query(query_vector, k)
|
return await self.index.query(query_vector, k, score_threshold)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue