mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 07:22:37 +00:00
chore: review update with score_threshold
This commit is contained in:
parent
d1bef44e2b
commit
bc27046f36
6 changed files with 24 additions and 13 deletions
|
|
@@ -140,7 +140,9 @@ class EmbeddingIndex(ABC):
|
|||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
||||
async def query(
|
||||
self, embedding: NDArray, k: int, score_threshold: float
|
||||
) -> QueryDocumentsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
|
|
@@ -177,6 +179,7 @@ class BankWithIndex:
|
|||
if params is None:
|
||||
params = {}
|
||||
k = params.get("max_chunks", 3)
|
||||
score_threshold = params.get("score_threshold", 0.0)
|
||||
|
||||
def _process(c) -> str:
|
||||
if isinstance(c, str):
|
||||
|
|
@@ -191,4 +194,4 @@ class BankWithIndex:
|
|||
|
||||
model = get_embedding_model(self.bank.embedding_model)
|
||||
query_vector = model.encode([query_str])[0].astype(np.float32)
|
||||
return await self.index.query(query_vector, k)
|
||||
return await self.index.query(query_vector, k, score_threshold)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue