meta reference inference fixes (#797)

Miscellaneous fixes for meta reference inference

Tests for log probs dont pass because meta reference does not support
top_k > 1
This commit is contained in:
Ashwin Bharambe 2025-01-16 18:17:46 -08:00 committed by GitHub
parent cb41848a2a
commit 9f14382d82
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 20 additions and 12 deletions

View file

@ -193,14 +193,14 @@ class MetaReferenceInferenceImpl(
]
yield CompletionResponseStreamChunk(
delta=TextDelta(text=text),
delta=text,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta=TextDelta(text=""),
delta="",
stop_reason=StopReason.out_of_tokens,
)
@ -223,10 +223,10 @@ class MetaReferenceInferenceImpl(
tokenizer = self.generator.formatter.tokenizer
for token_result in self.generator.completion(request):
tokens.append(token_result.token)
if token_result.token in tokenizer.stop_tokens:
# not quite right semantically
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if request.logprobs:
assert len(token_result.logprobs) == 1
@ -243,6 +243,10 @@ class MetaReferenceInferenceImpl(
stop_reason = StopReason.out_of_tokens
content = self.generator.formatter.tokenizer.decode(tokens)
if content.endswith("<|eot_id|>"):
content = content[: -len("<|eot_id|>")]
elif content.endswith("<|eom_id|>"):
content = content[: -len("<|eom_id|>")]
return CompletionResponse(
content=content,
stop_reason=stop_reason,