mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
fix: ChromaDB provider (#2413)
fixes the remote::chromaDB provider for vector_io by updating the method definition appropriately. Fixed impl to use score_threshold properly. ### Test Plan ``` # Start Chroma Docker docker run --rm \ --name chromadb \ -p 8800:8000 \ -v ~/chroma:/chroma/chroma \ -e IS_PERSISTENT=TRUE \ -e ANONYMIZED_TELEMETRY=FALSE \ chromadb/chroma:latest # run pytest CHROMADB_URL="http://localhost:8800" pytest -sv tests/integration/vector_io/test_vector_io.py --stack-config vector_io=remote::chromadb,inference=fireworks --embedding-model nomic-ai/nomic-embed-text-v1.5 ```
This commit is contained in:
parent
0d0b8d2be1
commit
1f48577a02
1 changed files with 6 additions and 2 deletions
|
@ -55,7 +55,7 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
results = await maybe_await(
|
results = await maybe_await(
|
||||||
self.collection.query(
|
self.collection.query(
|
||||||
query_embeddings=[embedding.tolist()],
|
query_embeddings=[embedding.tolist()],
|
||||||
|
@ -76,8 +76,12 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
log.exception(f"Failed to parse document: {doc}")
|
log.exception(f"Failed to parse document: {doc}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
score = 1.0 / float(dist) if dist != 0 else float("inf")
|
||||||
|
if score < score_threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
scores.append(1.0 / float(dist))
|
scores.append(score)
|
||||||
|
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue