[bugfix] Fix logprobs on meta-reference impl (#213)

* fix log probs

* add back LogProbsConfig

* error handling

* bugfix
This commit is contained in:
Xi Yan 2024-10-07 19:42:39 -07:00 committed by GitHub
parent e4ae09d090
commit 4d5f7459aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 36 additions and 7 deletions

View file

@ -297,7 +297,7 @@ class Llama:
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
if logprobs
else None
),