mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
refactor: Remove double filtering based on score threshold (#3019)
# What does this PR do?

Remove the `score_threshold`-based check from `OpenAIVectorStoreMixin`.

Closes: https://github.com/meta-llama/llama-stack/issues/3018

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

## Test Plan

<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->
This commit is contained in:
parent
1e3b5aa9b8
commit
3c2aee610d
4 changed files with 13 additions and 7 deletions
|
@@ -160,8 +160,11 @@ class FaissIndex(EmbeddingIndex):
|
||||||
for d, i in zip(distances[0], indices[0], strict=False):
|
for d, i in zip(distances[0], indices[0], strict=False):
|
||||||
if i < 0:
|
if i < 0:
|
||||||
continue
|
continue
|
||||||
|
score = 1.0 / float(d) if d != 0 else float("inf")
|
||||||
|
if score < score_threshold:
|
||||||
|
continue
|
||||||
chunks.append(self.chunk_by_index[int(i)])
|
chunks.append(self.chunk_by_index[int(i)])
|
||||||
scores.append(1.0 / float(d) if d != 0 else float("inf"))
|
scores.append(score)
|
||||||
|
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
|
@@ -132,8 +132,11 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
chunks = []
|
chunks = []
|
||||||
scores = []
|
scores = []
|
||||||
for doc, dist in results:
|
for doc, dist in results:
|
||||||
|
score = 1.0 / float(dist) if dist != 0 else float("inf")
|
||||||
|
if score < score_threshold:
|
||||||
|
continue
|
||||||
chunks.append(Chunk(**doc))
|
chunks.append(Chunk(**doc))
|
||||||
scores.append(1.0 / float(dist) if dist != 0 else float("inf"))
|
scores.append(score)
|
||||||
|
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
|
@@ -105,8 +105,12 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
log.exception(f"Failed to parse document: {chunk_json}")
|
log.exception(f"Failed to parse document: {chunk_json}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
|
||||||
|
if score < score_threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
scores.append(1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf"))
|
scores.append(score)
|
||||||
|
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
|
@@ -433,10 +433,6 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
# Convert response to OpenAI format
|
# Convert response to OpenAI format
|
||||||
data = []
|
data = []
|
||||||
for chunk, score in zip(response.chunks, response.scores, strict=False):
|
for chunk, score in zip(response.chunks, response.scores, strict=False):
|
||||||
# Apply score based filtering
|
|
||||||
if score < score_threshold:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Apply filters if provided
|
# Apply filters if provided
|
||||||
if filters:
|
if filters:
|
||||||
# Simple metadata filtering
|
# Simple metadata filtering
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue