forked from phoenix-oss/llama-stack-mirror
[bugfix] Fix logprobs on meta-reference impl (#213)
* fix log probs * add back LogProbsConfig * error handling * bugfix
This commit is contained in:
parent
e4ae09d090
commit
4d5f7459aa
3 changed files with 36 additions and 7 deletions
|
@ -297,7 +297,7 @@ class Llama:
|
|||
token=next_token[0].item(),
|
||||
text=self.tokenizer.decode(next_token.tolist()),
|
||||
logprobs=(
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
|
||||
token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
|
||||
if logprobs
|
||||
else None
|
||||
),
|
||||
|
|
|
@ -132,7 +132,20 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
|
||||
if not request.stream:
|
||||
if request.logprobs:
|
||||
logprobs.append(token_result.logprob)
|
||||
assert (
|
||||
len(token_result.logprobs) == 1
|
||||
), "Expected logprob to contain 1 result for the current token"
|
||||
assert (
|
||||
request.logprobs.top_k == 1
|
||||
), "Only top_k=1 is supported for LogProbConfig"
|
||||
|
||||
logprobs.append(
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
token_result.text: token_result.logprobs[0]
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue