Merge remote-tracking branch 'upstream/main' into qdrant

This commit is contained in:
Anush008 2024-10-09 01:37:29 +05:30
commit d9531d17de
No known key found for this signature in database
12 changed files with 76 additions and 18 deletions

View file

@ -673,7 +673,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _retrieve_context(
self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
) -> Tuple[Optional[List[str]], Optional[List[int]]]: # (rag_context, bank_ids)
bank_ids = []
memory = self._memory_tool_definition()
@ -722,12 +722,13 @@ class ChatAgent(ShieldRunnerMixin):
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
if not chunks:
return None, bank_ids
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
if not chunks:
return None, bank_ids
tokens = 0
picked = []

View file

@ -297,7 +297,7 @@ class Llama:
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
if logprobs
else None
),

View file

@ -132,7 +132,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
if not request.stream:
if request.logprobs:
logprobs.append(token_result.logprob)
assert (
len(token_result.logprobs) == 1
), "Expected logprob to contain 1 result for the current token"
assert (
request.logprobs.top_k == 1
), "Only top_k=1 is supported for LogProbConfig"
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
continue

View file

@ -59,7 +59,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
Api.memory,
AdapterSpec(
adapter_id="weaviate",
adapter_type="weaviate",
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
module="llama_stack.providers.adapters.memory.weaviate",
provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",