mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
add query_keyword function
This commit is contained in:
parent
554c78ba66
commit
26fb208877
1 changed files with 37 additions and 10 deletions
|
@ -10,7 +10,6 @@ from typing import Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.api.models.AsyncCollection import AsyncCollection
|
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from llama_stack.apis.files import Files
|
from llama_stack.apis.files import Files
|
||||||
|
@ -109,18 +108,46 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
await maybe_await(self.client.delete_collection(self.collection.name))
|
await maybe_await(self.client.delete_collection(self.collection.name))
|
||||||
|
|
||||||
async def query_keyword(
|
async def query_keyword(
|
||||||
self,
|
self,
|
||||||
query_string: str,
|
query_string: str,
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Chroma")
|
results = await maybe_await(
|
||||||
|
self.collection.query(
|
||||||
|
query_texts=[query_string],
|
||||||
|
where_document={"$contains": query_string},
|
||||||
|
n_results=k,
|
||||||
|
include=["documents", "distances"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
distances = results["distances"][0] if results["distances"] else []
|
||||||
|
documents = results["documents"][0] if results["documents"] else []
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
scores = []
|
||||||
|
|
||||||
|
for dist, doc in zip(distances, documents, strict=False):
|
||||||
|
try:
|
||||||
|
doc_data = json.loads(doc)
|
||||||
|
chunk = Chunk(**doc_data)
|
||||||
|
except Exception:
|
||||||
|
log.exception(f"Failed to parse document: {doc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
score = 1.0 / (1.0 + float(dist)) if dist is not None else 1.0
|
||||||
|
|
||||||
|
if score < score_threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(score)
|
||||||
|
|
||||||
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
async def delete_chunk(self, chunk_id: str) -> None:
|
async def delete_chunk(self, chunk_id: str) -> None:
|
||||||
if isinstance(self.collection, AsyncCollection):
|
await maybe_await(self.collection.delete([chunk_id]))
|
||||||
await self.collection.delete([chunk_id])
|
|
||||||
else:
|
|
||||||
self.collection.delete([chunk_id])
|
|
||||||
|
|
||||||
async def query_hybrid(
|
async def query_hybrid(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue