diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index e155ffd34..34053343e 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -95,7 +95,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         def impl():
             stop_reason = None
 
-            for token_result in self.generator.chat_completion(request):
+            for token_result in self.generator.completion(request):
                 if token_result.text == "<|eot_id|>":
                     stop_reason = StopReason.end_of_turn
                     text = ""
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 60ebe1766..09d6a69db 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -159,7 +159,10 @@ async def test_completion(inference_settings):
         )
     ]
 
-    print(chunks)
+    assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
+    assert len(chunks) == 51
+    last = chunks[-1]
+    assert last.stop_reason == StopReason.out_of_tokens
 
 
 @pytest.mark.asyncio
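For context, the test change replaces a debug print with real assertions on the collected stream: every chunk is a CompletionResponseStreamChunk, exactly 51 chunks arrive, and the final chunk carries StopReason.out_of_tokens. Below is a minimal, self-contained sketch of that collect-then-assert pattern. The generator and chunk/StopReason types here are hypothetical stand-ins, not the llama_stack implementations; only the assertion shape is taken from the diff.

import asyncio
from dataclasses import dataclass
from enum import Enum


class StopReason(Enum):
    end_of_turn = "end_of_turn"
    out_of_tokens = "out_of_tokens"


@dataclass
class CompletionResponseStreamChunk:
    delta: str
    stop_reason: StopReason | None = None


async def fake_completion_stream(n_tokens: int):
    # Hypothetical stream: emit n_tokens chunks, with the final chunk
    # carrying the stop reason, mimicking a generator that exhausts its
    # max-token budget rather than reaching end of turn.
    for i in range(n_tokens - 1):
        yield CompletionResponseStreamChunk(delta=f"tok{i} ")
    yield CompletionResponseStreamChunk(delta="", stop_reason=StopReason.out_of_tokens)


async def main():
    # Collect the whole stream, as the reworked test does.
    chunks = [chunk async for chunk in fake_completion_stream(51)]

    # Same assertion shape as the additions to test_completion.
    assert all(isinstance(c, CompletionResponseStreamChunk) for c in chunks)
    assert len(chunks) == 51
    assert chunks[-1].stop_reason == StopReason.out_of_tokens


asyncio.run(main())

Asserting on the last chunk's stop_reason (instead of printing) makes the test fail loudly if the generator stops early or never reports exhausting its token budget, which is exactly the regression the inference.py fix guards against.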