Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-10 03:30:58 +00:00
Merge remote-tracking branch 'upstream/main' into qdrant
This commit is contained in: commit d9531d17de
12 changed files with 76 additions and 18 deletions
@@ -6,7 +6,6 @@
import asyncio
import json
import sys
from typing import Any, AsyncGenerator, List, Optional

import fire

@@ -101,7 +100,9 @@ class InferenceClient(Inference):
            print(f"Error with parsing or validation: {e}")


-async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
+async def run_main(
+    host: str, port: int, stream: bool, model: Optional[str], logprobs: bool
+):
    client = InferenceClient(f"http://{host}:{port}")

    if not model:

@@ -111,13 +112,27 @@ async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
        content="hello world, write me a 2 sentence poem about the moon"
    )
    cprint(f"User>{message.content}", "green")

+    if logprobs:
+        logprobs_config = LogProbConfig(
+            top_k=1,
+        )
+    else:
+        logprobs_config = None
+
    iterator = client.chat_completion(
        model=model,
        messages=[message],
        stream=stream,
+        logprobs=logprobs_config,
    )
-    async for log in EventLogger().log(iterator):
-        log.print()
+    if logprobs:
+        async for chunk in iterator:
+            cprint(f"Response: {chunk}", "red")
+    else:
+        async for log in EventLogger().log(iterator):
+            log.print()


async def run_mm_main(

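Taken together with the signature change above, the new client-side flow is: build a LogProbConfig when the flag is set, thread it through chat_completion, and dump raw chunks instead of the formatted event log. A minimal sketch of that flow, assuming the types already imported in this file (InferenceClient, LogProbConfig, a UserMessage-style message class); the helper name and model id below are illustrative only:

import asyncio

async def demo_logprobs(host: str, port: int, model: str) -> None:
    # Assumes InferenceClient, LogProbConfig, and UserMessage from this module.
    client = InferenceClient(f"http://{host}:{port}")
    message = UserMessage(content="hello world, write me a 2 sentence poem about the moon")
    iterator = client.chat_completion(
        model=model,
        messages=[message],
        stream=True,
        logprobs=LogProbConfig(top_k=1),  # server side only accepts top_k=1 (see the provider hunk below)
    )
    async for chunk in iterator:
        print(chunk)  # raw chunks carry the per-token logprobs

# asyncio.run(demo_logprobs("localhost", 5000, "<model-id>"))
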
@@ -149,13 +164,14 @@ def main(
    port: int,
    stream: bool = True,
    mm: bool = False,
+    logprobs: bool = False,
    file: Optional[str] = None,
    model: Optional[str] = None,
):
    if mm:
        asyncio.run(run_mm_main(host, port, stream, file, model))
    else:
-        asyncio.run(run_main(host, port, stream, model))
+        asyncio.run(run_main(host, port, stream, model, logprobs))


if __name__ == "__main__":

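The file imports fire (first hunk), so if the __main__ guard hands main to fire.Fire as is conventional, the new keyword argument surfaces directly as a --logprobs command-line flag. A standalone sketch of that wiring; the script name in the comment is hypothetical:

import fire

def main(host: str, port: int, stream: bool = True, logprobs: bool = False) -> None:
    # fire.Fire exposes each parameter of main as a positional or --flag argument.
    print(host, port, stream, logprobs)

if __name__ == "__main__":
    fire.Fire(main)
    # e.g.: python client_sketch.py localhost 5000 --logprobs
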
@@ -169,7 +169,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
    meta_url = args.meta_url
    if not meta_url:
        meta_url = input(
-            "Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
+            "Please provide the signed URL you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
        )
    assert meta_url is not None and "llamameta.net" in meta_url
    _meta_download(model, meta_url, info)

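Only the prompt wording changes in this hunk; the contract enforced right after it is unchanged. A tiny standalone restatement of that check (the function name is illustrative):

def validate_meta_url(meta_url: str) -> str:
    # Same invariant as the assert in the diff: whichever page the user visited,
    # the signed download URL itself must be served from llamameta.net.
    assert meta_url and "llamameta.net" in meta_url, "expected a signed llamameta.net URL"
    return meta_url
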
@@ -673,7 +673,7 @@ class ChatAgent(ShieldRunnerMixin):

    async def _retrieve_context(
        self, session_id: str, messages: List[Message], attachments: List[Attachment]
-    ) -> Tuple[List[str], List[int]]:  # (rag_context, bank_ids)
+    ) -> Tuple[Optional[List[str]], Optional[List[int]]]:  # (rag_context, bank_ids)
        bank_ids = []

        memory = self._memory_tool_definition()

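Widening the annotation to Optional[...] documents that _retrieve_context can return (None, ...) when nothing is retrieved, so callers have to branch on it. A hedged caller-side sketch; the function below is illustrative, not the agent's actual prompt builder:

from typing import List, Optional

def build_prompt(rag_context: Optional[List[str]], user_query: str) -> str:
    # Fall back to the bare query when retrieval produced no context.
    if not rag_context:
        return user_query
    return "\n".join(rag_context) + "\n\n" + user_query

print(build_prompt(None, "What is the moon made of?"))
print(build_prompt(["chunk A", "chunk B"], "What is the moon made of?"))
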
@@ -722,12 +722,13 @@ class ChatAgent(ShieldRunnerMixin):
        chunks = [c for r in results for c in r.chunks]
        scores = [s for r in results for s in r.scores]

+        if not chunks:
+            return None, bank_ids
+
        # sort by score
        chunks, scores = zip(
            *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
        )
-        if not chunks:
-            return None, bank_ids

        tokens = 0
        picked = []

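The guard moves above the sort because zip(*sorted(zip(chunks, scores), ...)) cannot be unpacked into two names when both lists are empty, so the old post-sort check was never reached on the empty case. A self-contained sketch of that ordering, using a hypothetical helper:

from typing import List, Optional, Tuple

def rank_chunks(
    chunks: List[str], scores: List[float]
) -> Optional[Tuple[List[str], List[float]]]:
    if not chunks:  # must come first: zip(*[]) unpacks to zero values
        return None
    ranked = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
    ranked_chunks, ranked_scores = zip(*ranked)
    return list(ranked_chunks), list(ranked_scores)

print(rank_chunks([], []))                    # None
print(rank_chunks(["a", "b"], [0.2, 0.9]))    # (['b', 'a'], [0.9, 0.2])
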
@@ -297,7 +297,7 @@ class Llama:
                    token=next_token[0].item(),
                    text=self.tokenizer.decode(next_token.tolist()),
                    logprobs=(
-                        token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
+                        token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
                        if logprobs
                        else None
                    ),

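In the token-by-token decode loop, each yielded result should carry only the logprob of the token produced at this step; slicing prev_pos + 1 : cur_pos + 1 can span several positions whenever prev_pos lags cur_pos by more than one, while cur_pos : cur_pos + 1 is always exactly one entry. A toy illustration with plain lists (the real code slices a 2-D torch tensor):

# One logprob per sequence position, as plain floats for illustration.
token_logprobs = [-0.11, -0.42, -0.07, -0.93, -0.31]

prev_pos, cur_pos = 1, 3  # e.g. a step where prev_pos lags by more than one
old_slice = token_logprobs[prev_pos + 1 : cur_pos + 1]  # [-0.07, -0.93] -> two positions
new_slice = token_logprobs[cur_pos : cur_pos + 1]       # [-0.93] -> just the current token
print(old_slice, new_slice)
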
@@ -132,7 +132,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):

            if not request.stream:
                if request.logprobs:
-                    logprobs.append(token_result.logprob)
+                    assert (
+                        len(token_result.logprobs) == 1
+                    ), "Expected logprob to contain 1 result for the current token"
+                    assert (
+                        request.logprobs.top_k == 1
+                    ), "Only top_k=1 is supported for LogProbConfig"
+
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )

                continue

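The non-streaming path now accumulates structured logprobs rather than bare floats: one TokenLogProbs entry per generated token, mapping the token text to its logprob (a single pair, since only top_k=1 is accepted). A rough sketch of that shape; the dataclass below is a stand-in for the real API model:

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class TokenLogProbs:  # stand-in for the API model referenced in the diff
    logprobs_by_token: Dict[str, float]

# One entry per generated token; with top_k=1 each map holds exactly one pair.
logprobs: List[TokenLogProbs] = [
    TokenLogProbs(logprobs_by_token={"Hello": -0.21}),
    TokenLogProbs(logprobs_by_token={" world": -1.05}),
]
print(logprobs[0].logprobs_by_token)
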
@@ -59,7 +59,7 @@ def available_providers() -> List[ProviderSpec]:
        remote_provider_spec(
            Api.memory,
            AdapterSpec(
-                adapter_id="weaviate",
+                adapter_type="weaviate",
                pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
                module="llama_stack.providers.adapters.memory.weaviate",
                provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",