mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
meta reference inference fixes (#797)
Miscellaneous fixes for meta reference inference Tests for log probs dont pass because meta reference does not support top_k > 1
This commit is contained in:
parent
cb41848a2a
commit
9f14382d82
5 changed files with 20 additions and 12 deletions
|
@ -193,14 +193,14 @@ class MetaReferenceInferenceImpl(
|
|||
]
|
||||
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=TextDelta(text=text),
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=TextDelta(text=""),
|
||||
delta="",
|
||||
stop_reason=StopReason.out_of_tokens,
|
||||
)
|
||||
|
||||
|
@ -223,10 +223,10 @@ class MetaReferenceInferenceImpl(
|
|||
tokenizer = self.generator.formatter.tokenizer
|
||||
for token_result in self.generator.completion(request):
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if token_result.token in tokenizer.stop_tokens:
|
||||
# not quite right semantically
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
@ -243,6 +243,10 @@ class MetaReferenceInferenceImpl(
|
|||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
content = self.generator.formatter.tokenizer.decode(tokens)
|
||||
if content.endswith("<|eot_id|>"):
|
||||
content = content[: -len("<|eot_id|>")]
|
||||
elif content.endswith("<|eom_id|>"):
|
||||
content = content[: -len("<|eom_id|>")]
|
||||
return CompletionResponse(
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue