forked from phoenix-oss/llama-stack-mirror
[bugfix] Fix logprobs on meta-reference impl (#213)
* fix log probs * add back LogProbsConfig * error handling * bugfix
This commit is contained in:
parent
e4ae09d090
commit
4d5f7459aa
3 changed files with 36 additions and 7 deletions
|
@ -6,7 +6,6 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, AsyncGenerator, List, Optional
|
||||
|
||||
import fire
|
||||
|
@ -101,7 +100,9 @@ class InferenceClient(Inference):
|
|||
print(f"Error with parsing or validation: {e}")
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
|
||||
async def run_main(
|
||||
host: str, port: int, stream: bool, model: Optional[str], logprobs: bool
|
||||
):
|
||||
client = InferenceClient(f"http://{host}:{port}")
|
||||
|
||||
if not model:
|
||||
|
@ -111,13 +112,27 @@ async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
|
|||
content="hello world, write me a 2 sentence poem about the moon"
|
||||
)
|
||||
cprint(f"User>{message.content}", "green")
|
||||
|
||||
if logprobs:
|
||||
logprobs_config = LogProbConfig(
|
||||
top_k=1,
|
||||
)
|
||||
else:
|
||||
logprobs_config = None
|
||||
|
||||
iterator = client.chat_completion(
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=stream,
|
||||
logprobs=logprobs_config,
|
||||
)
|
||||
async for log in EventLogger().log(iterator):
|
||||
log.print()
|
||||
|
||||
if logprobs:
|
||||
async for chunk in iterator:
|
||||
cprint(f"Response: {chunk}", "red")
|
||||
else:
|
||||
async for log in EventLogger().log(iterator):
|
||||
log.print()
|
||||
|
||||
|
||||
async def run_mm_main(
|
||||
|
@ -149,13 +164,14 @@ def main(
|
|||
port: int,
|
||||
stream: bool = True,
|
||||
mm: bool = False,
|
||||
logprobs: bool = False,
|
||||
file: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
if mm:
|
||||
asyncio.run(run_mm_main(host, port, stream, file, model))
|
||||
else:
|
||||
asyncio.run(run_main(host, port, stream, model))
|
||||
asyncio.run(run_main(host, port, stream, model, logprobs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue