From fedc11b72674a277873fac2a0baaf49fc07b4e21 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 18 Oct 2024 20:46:48 -0700
Subject: [PATCH] Fix

---
 .../providers/impls/meta_reference/inference/inference.py | 2 +-
 llama_stack/providers/tests/inference/test_inference.py   | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index e155ffd34..34053343e 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -95,7 +95,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         def impl():
             stop_reason = None
 
-            for token_result in self.generator.chat_completion(request):
+            for token_result in self.generator.completion(request):
                 if token_result.text == "<|eot_id|>":
                     stop_reason = StopReason.end_of_turn
                     text = ""
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 60ebe1766..09d6a69db 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -159,7 +159,10 @@ async def test_completion(inference_settings):
         )
     ]
 
-    print(chunks)
+    assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
+    assert len(chunks) == 51
+    last = chunks[-1]
+    assert last.stop_reason == StopReason.out_of_tokens
 
 
 @pytest.mark.asyncio