diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 5cfae633c..fffcf4692 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -6,7 +6,6 @@
 
 import asyncio
 import json
-import sys
 from typing import Any, AsyncGenerator, List, Optional
 
 import fire
@@ -101,7 +100,9 @@ class InferenceClient(Inference):
                             print(f"Error with parsing or validation: {e}")
 
 
-async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
+async def run_main(
+    host: str, port: int, stream: bool, model: Optional[str], logprobs: bool
+):
     client = InferenceClient(f"http://{host}:{port}")
 
     if not model:
@@ -111,13 +112,27 @@ async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
         content="hello world, write me a 2 sentence poem about the moon"
     )
     cprint(f"User>{message.content}", "green")
+
+    if logprobs:
+        logprobs_config = LogProbConfig(
+            top_k=1,
+        )
+    else:
+        logprobs_config = None
+
     iterator = client.chat_completion(
         model=model,
         messages=[message],
         stream=stream,
+        logprobs=logprobs_config,
     )
-    async for log in EventLogger().log(iterator):
-        log.print()
+
+    if logprobs:
+        async for chunk in iterator:
+            cprint(f"Response: {chunk}", "red")
+    else:
+        async for log in EventLogger().log(iterator):
+            log.print()
 
 
 async def run_mm_main(
@@ -149,13 +164,14 @@ def main(
     port: int,
     stream: bool = True,
     mm: bool = False,
+    logprobs: bool = False,
     file: Optional[str] = None,
     model: Optional[str] = None,
 ):
     if mm:
         asyncio.run(run_mm_main(host, port, stream, file, model))
     else:
-        asyncio.run(run_main(host, port, stream, model))
+        asyncio.run(run_main(host, port, stream, model, logprobs))
 
 
 if __name__ == "__main__":
diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index 4351a3d56..27e086e0f 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -297,7 +297,7 @@ class Llama:
                 token=next_token[0].item(),
                 text=self.tokenizer.decode(next_token.tolist()),
                 logprobs=(
-                    token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
+                    token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
                     if logprobs
                     else None
                 ),
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index e89d8ec4c..dca4ea6fb 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -132,7 +132,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
 
                 if not request.stream:
                     if request.logprobs:
-                        logprobs.append(token_result.logprob)
+                        assert (
+                            len(token_result.logprobs) == 1
+                        ), "Expected logprob to contain 1 result for the current token"
+                        assert (
+                            request.logprobs.top_k == 1
+                        ), "Only top_k=1 is supported for LogProbConfig"
+
+                        logprobs.append(
+                            TokenLogProbs(
+                                logprobs_by_token={
+                                    token_result.text: token_result.logprobs[0]
+                                }
+                            )
+                        )
 
                     continue
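
Test sketch for reviewers: the snippet below mirrors the new non-streaming path in `run_main`, requesting `LogProbConfig(top_k=1)` and printing the raw response chunk. It is a minimal sketch, not part of this diff; it assumes a stack server is already listening on `localhost:5000` with a model registered under the placeholder name used below, and that `LogProbConfig` / `UserMessage` are importable from `llama_stack.apis.inference` (adjust host, port, model name, and import paths to your setup).

```python
# Minimal sketch (not part of this diff): exercise the new logprobs path end to end.
import asyncio

from llama_stack.apis.inference import LogProbConfig, UserMessage  # import paths assumed
from llama_stack.apis.inference.client import InferenceClient


async def check_logprobs() -> None:
    # Assumes a stack server is listening on localhost:5000.
    client = InferenceClient("http://localhost:5000")

    iterator = client.chat_completion(
        model="Llama3.1-8B-Instruct",  # placeholder; use a model registered with your server
        messages=[
            UserMessage(content="hello world, write me a 2 sentence poem about the moon")
        ],
        stream=False,
        # Only top_k=1 is accepted by the meta-reference provider after this change.
        logprobs=LogProbConfig(top_k=1),
    )

    # Non-streaming responses still arrive through the async generator; each
    # TokenLogProbs entry maps a generated token's text to its log-probability.
    async for chunk in iterator:
        print(chunk)


if __name__ == "__main__":
    asyncio.run(check_logprobs())
```

The same path should be reachable through the module's fire-based CLI by passing `--logprobs True` alongside the existing flags (assuming the usual `fire.Fire(main)` entry point, which is not shown in these hunks). Note also how the three hunks fit together: the `generation.py` change narrows the returned slice to the current token only (`cur_pos : cur_pos + 1` instead of `prev_pos + 1 : cur_pos + 1`), which is what makes the `len(token_result.logprobs) == 1` assertion in `inference.py` safe before each logprob is wrapped in a `TokenLogProbs` entry.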