[bugfix] Fix logprobs on meta-reference impl (#213)

* fix log probs

* add back LogProbConfig

* error handling

* bugfix
Xi Yan, 2024-10-07 19:42:39 -07:00 (committed by GitHub)
parent e4ae09d090
commit 4d5f7459aa
3 changed files with 36 additions and 7 deletions


@@ -6,7 +6,6 @@
 import asyncio
 import json
-import sys
 from typing import Any, AsyncGenerator, List, Optional

 import fire
@@ -101,7 +100,9 @@ class InferenceClient(Inference):
             print(f"Error with parsing or validation: {e}")


-async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
+async def run_main(
+    host: str, port: int, stream: bool, model: Optional[str], logprobs: bool
+):
     client = InferenceClient(f"http://{host}:{port}")

     if not model:
@@ -111,13 +112,27 @@ async def run_main(host: str, port: int, stream: bool, model: Optional[str]):
         content="hello world, write me a 2 sentence poem about the moon"
     )
     cprint(f"User>{message.content}", "green")
+
+    if logprobs:
+        logprobs_config = LogProbConfig(
+            top_k=1,
+        )
+    else:
+        logprobs_config = None
+
     iterator = client.chat_completion(
         model=model,
         messages=[message],
         stream=stream,
+        logprobs=logprobs_config,
     )
-    async for log in EventLogger().log(iterator):
-        log.print()
+
+    if logprobs:
+        async for chunk in iterator:
+            cprint(f"Response: {chunk}", "red")
+    else:
+        async for log in EventLogger().log(iterator):
+            log.print()


 async def run_mm_main(
@@ -149,13 +164,14 @@ def main(
     port: int,
     stream: bool = True,
     mm: bool = False,
+    logprobs: bool = False,
     file: Optional[str] = None,
     model: Optional[str] = None,
 ):
     if mm:
         asyncio.run(run_mm_main(host, port, stream, file, model))
     else:
-        asyncio.run(run_main(host, port, stream, model))
+        asyncio.run(run_main(host, port, stream, model, logprobs))


 if __name__ == "__main__":
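With this change the test client can exercise logprobs end to end: the new logprobs flag on main/run_main builds a LogProbConfig(top_k=1), passes it to chat_completion, and prints the raw chunks instead of routing them through EventLogger. Below is a minimal sketch of the same flow outside the CLI wrapper; the import path, server address, and model id are assumptions for illustration and are not part of this diff.

import asyncio

# Assumed import location; the real client script touched by this diff defines
# InferenceClient and pulls LogProbConfig/UserMessage from the inference API.
from llama_stack.apis.inference import InferenceClient, LogProbConfig, UserMessage


async def demo() -> None:
    client = InferenceClient("http://localhost:5000")  # assumed local server
    message = UserMessage(content="write me a 2 sentence poem about the moon")

    iterator = client.chat_completion(
        model="Llama3.1-8B-Instruct",  # assumed model id
        messages=[message],
        stream=True,
        logprobs=LogProbConfig(top_k=1),  # only top_k=1 is supported
    )

    # Mirrors the new logprobs branch: print raw chunks rather than the
    # EventLogger view, so the per-token logprobs stay visible.
    async for chunk in iterator:
        print(chunk)


asyncio.run(demo())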


@@ -297,7 +297,7 @@ class Llama:
                 token=next_token[0].item(),
                 text=self.tokenizer.decode(next_token.tolist()),
                 logprobs=(
-                    token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
+                    token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
                     if logprobs
                     else None
                 ),
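This one-line slice change is the core of the bugfix. With the old expression, the first token yielded after prefill (while prev_pos still points at the start of the prompt) carried the logprobs for the whole prompt span rather than a single value; the new expression always returns exactly one logprob, the one for the token emitted at cur_pos, which the assertion added in the next file relies on. A toy illustration, assuming token_logprobs is a (batch, seq_len) tensor as in the generation loop:

import torch

# Stand-in for token_logprobs with shape (1, 5); the values are made up.
token_logprobs = torch.tensor([[-0.5, -0.25, -0.75, -1.0, -0.125]])
prev_pos, cur_pos = 0, 3  # first decode step: prev_pos is still at the prompt start

# Old slice: every logprob since prev_pos, i.e. three values for one token.
old = token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
print(old)  # [-0.25, -0.75, -1.0]

# Fixed slice: exactly the logprob of the token at cur_pos.
new = token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
print(new)  # [-1.0]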


@@ -132,7 +132,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
             if not request.stream:
                 if request.logprobs:
-                    logprobs.append(token_result.logprob)
+                    assert (
+                        len(token_result.logprobs) == 1
+                    ), "Expected logprob to contain 1 result for the current token"
+                    assert (
+                        request.logprobs.top_k == 1
+                    ), "Only top_k=1 is supported for LogProbConfig"
+
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )

                 continue
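On the non-streaming path the provider now validates each token's logprobs and wraps them in TokenLogProbs keyed by the token text, instead of appending token_result.logprob, an attribute the yielded token results do not have (they carry a logprobs list, as the generation change above shows). A condensed, self-contained sketch of that conversion; the TokenLogProbs and TokenResult classes below are hypothetical stand-ins for the real llama-stack types.

from dataclasses import dataclass, field
from typing import Dict, List


# Hypothetical stand-ins for illustration; the real TokenLogProbs and token
# result types are defined in the llama-stack codebase.
@dataclass
class TokenLogProbs:
    logprobs_by_token: Dict[str, float]


@dataclass
class TokenResult:
    text: str
    logprobs: List[float] = field(default_factory=list)


def to_token_logprobs(token_result: TokenResult, top_k: int) -> TokenLogProbs:
    # Mirrors the checks added in this commit: exactly one logprob per token,
    # and only top_k=1 is supported by the meta-reference provider.
    assert len(token_result.logprobs) == 1, (
        "Expected logprob to contain 1 result for the current token"
    )
    assert top_k == 1, "Only top_k=1 is supported for LogProbConfig"
    return TokenLogProbs(
        logprobs_by_token={token_result.text: token_result.logprobs[0]}
    )


# Example: one generated token and its (made-up) logprob.
print(to_token_logprobs(TokenResult(text=" moon", logprobs=[-0.42]), top_k=1))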