[bugfix] Fix logprobs on meta-reference impl (#213)

* fix log probs

* add back LogProbsConfig

* error handling

* bugfix
This commit is contained in:
Xi Yan 2024-10-07 19:42:39 -07:00 committed by GitHub
parent e4ae09d090
commit 4d5f7459aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 36 additions and 7 deletions

View file

@ -132,7 +132,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
if not request.stream:
if request.logprobs:
logprobs.append(token_result.logprob)
assert (
len(token_result.logprobs) == 1
), "Expected logprob to contain 1 result for the current token"
assert (
request.logprobs.top_k == 1
), "Only top_k=1 is supported for LogProbConfig"
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
continue