diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 0e69c2e7e..1bc098fab 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -253,7 +253,8 @@ class MetaReferenceInferenceImpl( def impl(): stop_reason = None - for token_result in self.generator.completion(request): + for token_results in self.generator.completion([request]): + token_result = token_results[0] if token_result.token == tokenizer.eot_id: stop_reason = StopReason.end_of_turn text = "" diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 9ffcf99fe..8c0ffc632 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -69,7 +69,10 @@ class CancelSentinel(BaseModel): class TaskRequest(BaseModel): type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request - task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]] + task: Tuple[ + str, + List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent], + ] class TaskResponse(BaseModel): @@ -234,7 +237,7 @@ def worker_process_entrypoint( if isinstance(task, EndSentinel): break - assert isinstance(task, TaskRequest) + assert isinstance(task, TaskRequest), task result = model(task.task) except StopIteration: break @@ -331,7 +334,10 @@ class ModelParallelProcessGroup: def run_inference( self, - req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]], + req: Tuple[ + str, + List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent], + ], ) -> Generator: assert not self.running, "inference already running"