mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
[bugfix] Fix logprobs on meta-reference impl (#213)
* fix log probs * add back LogProbsConfig * error handling * bugfix
This commit is contained in:
parent
e4ae09d090
commit
4d5f7459aa
3 changed files with 36 additions and 7 deletions
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import sys
|
|
||||||
from typing import Any, AsyncGenerator, List, Optional
|
from typing import Any, AsyncGenerator, List, Optional
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
|
@ -101,7 +100,9 @@ class InferenceClient(Inference):
|
||||||
print(f"Error with parsing or validation: {e}")
|
print(f"Error with parsing or validation: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
|
async def run_main(
|
||||||
|
host: str, port: int, stream: bool, model: Optional[str], logprobs: bool
|
||||||
|
):
|
||||||
client = InferenceClient(f"http://{host}:{port}")
|
client = InferenceClient(f"http://{host}:{port}")
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
|
@ -111,13 +112,27 @@ async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
|
||||||
content="hello world, write me a 2 sentence poem about the moon"
|
content="hello world, write me a 2 sentence poem about the moon"
|
||||||
)
|
)
|
||||||
cprint(f"User>{message.content}", "green")
|
cprint(f"User>{message.content}", "green")
|
||||||
|
|
||||||
|
if logprobs:
|
||||||
|
logprobs_config = LogProbConfig(
|
||||||
|
top_k=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logprobs_config = None
|
||||||
|
|
||||||
iterator = client.chat_completion(
|
iterator = client.chat_completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[message],
|
messages=[message],
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
logprobs=logprobs_config,
|
||||||
)
|
)
|
||||||
async for log in EventLogger().log(iterator):
|
|
||||||
log.print()
|
if logprobs:
|
||||||
|
async for chunk in iterator:
|
||||||
|
cprint(f"Response: {chunk}", "red")
|
||||||
|
else:
|
||||||
|
async for log in EventLogger().log(iterator):
|
||||||
|
log.print()
|
||||||
|
|
||||||
|
|
||||||
async def run_mm_main(
|
async def run_mm_main(
|
||||||
|
@ -149,13 +164,14 @@ def main(
|
||||||
port: int,
|
port: int,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
mm: bool = False,
|
mm: bool = False,
|
||||||
|
logprobs: bool = False,
|
||||||
file: Optional[str] = None,
|
file: Optional[str] = None,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
):
|
):
|
||||||
if mm:
|
if mm:
|
||||||
asyncio.run(run_mm_main(host, port, stream, file, model))
|
asyncio.run(run_mm_main(host, port, stream, file, model))
|
||||||
else:
|
else:
|
||||||
asyncio.run(run_main(host, port, stream, model))
|
asyncio.run(run_main(host, port, stream, model, logprobs))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -297,7 +297,7 @@ class Llama:
|
||||||
token=next_token[0].item(),
|
token=next_token[0].item(),
|
||||||
text=self.tokenizer.decode(next_token.tolist()),
|
text=self.tokenizer.decode(next_token.tolist()),
|
||||||
logprobs=(
|
logprobs=(
|
||||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
|
token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
|
||||||
if logprobs
|
if logprobs
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
|
|
|
@ -132,7 +132,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
||||||
|
|
||||||
if not request.stream:
|
if not request.stream:
|
||||||
if request.logprobs:
|
if request.logprobs:
|
||||||
logprobs.append(token_result.logprob)
|
assert (
|
||||||
|
len(token_result.logprobs) == 1
|
||||||
|
), "Expected logprob to contain 1 result for the current token"
|
||||||
|
assert (
|
||||||
|
request.logprobs.top_k == 1
|
||||||
|
), "Only top_k=1 is supported for LogProbConfig"
|
||||||
|
|
||||||
|
logprobs.append(
|
||||||
|
TokenLogProbs(
|
||||||
|
logprobs_by_token={
|
||||||
|
token_result.text: token_result.logprobs[0]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue