From 7ed137e96310984e9a518beac21007e56bef881b Mon Sep 17 00:00:00 2001
From: ehhuang
Date: Thu, 24 Apr 2025 13:03:35 -0700
Subject: [PATCH] fix: meta ref inference (#2022)

MAX_BATCH_SIZE=10 LLAMA_MODELS_DEBUG=1 LLAMA_STACK_PORT=5002 LLAMA_STACK_LOGGING='all=info' llama stack run meta-reference-gpu --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct --env INFERENCE_CHECKPOINT_DIR=...

LLAMA_STACK_CONFIG=http://localhost:5002/ pytest -s -v tests/integration/inference --safety-shield meta-llama/Llama-Guard-3-8B --vision-model meta-llama/Llama-4-Scout-17B-16E-Instruct --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct

Co-authored-by: Eric Huang
---
 .../inline/inference/meta_reference/inference.py       |  3 ++-
 .../inference/meta_reference/parallel_utils.py         | 12 +++++++++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 0e69c2e7e..1bc098fab 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -253,7 +253,8 @@ class MetaReferenceInferenceImpl(
         def impl():
             stop_reason = None
 
-            for token_result in self.generator.completion(request):
+            for token_results in self.generator.completion([request]):
+                token_result = token_results[0]
                 if token_result.token == tokenizer.eot_id:
                     stop_reason = StopReason.end_of_turn
                     text = ""
diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
index 9ffcf99fe..8c0ffc632 100644
--- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
+++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
@@ -69,7 +69,10 @@ class CancelSentinel(BaseModel):
 
 class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-    task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
+    task: Tuple[
+        str,
+        List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
+    ]
 
 
 class TaskResponse(BaseModel):
@@ -234,7 +237,7 @@ def worker_process_entrypoint(
             if isinstance(task, EndSentinel):
                 break
 
-            assert isinstance(task, TaskRequest)
+            assert isinstance(task, TaskRequest), task
             result = model(task.task)
         except StopIteration:
             break
@@ -331,7 +334,10 @@ class ModelParallelProcessGroup:
 
     def run_inference(
         self,
-        req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
+        req: Tuple[
+            str,
+            List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
+        ],
     ) -> Generator:
         assert not self.running, "inference already running"
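
Note on the calling convention changed above: generator.completion() now takes a list of requests and, at each decoding step, yields a list of token results (one entry per request in the batch), which is why the single-request path wraps its request in a list and reads index 0. A minimal sketch of a single-request caller under that assumption follows; the helper name, the eot_id parameter, and any attributes beyond token_result.token are illustrative, not taken from the patch.

    def stream_single_request(generator, request, eot_id):
        # completion() is batched now, so wrap the lone request in a list.
        for token_results in generator.completion([request]):
            # One result per request in the batch; only one request was sent.
            token_result = token_results[0]
            if token_result.token == eot_id:
                # End-of-turn token terminates the stream.
                break
            yield token_result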