From f19eb8eee34f9c7caedbc8fd28fd2b0726064fd3 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 19 Dec 2024 13:58:20 -0800
Subject: [PATCH] Update types in parallel_utils for meta-reference-gpu impl

---
 .../inference/meta_reference/parallel_utils.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
index 830160578..36720612c 100644
--- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
+++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
@@ -34,7 +34,10 @@ from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 from typing_extensions import Annotated
 
-from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
+)
 
 from .generation import TokenResult
 
@@ -79,7 +82,7 @@ class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = (
         ProcessingMessageName.task_request
     )
-    task: Union[CompletionRequest, ChatCompletionRequest]
+    task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
 
 
 class TaskResponse(BaseModel):
@@ -264,9 +267,6 @@ def launch_dist_group(
     init_model_cb: Callable,
     **kwargs,
 ) -> None:
-    id = uuid.uuid4().hex
-    dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
-
     with tempfile.TemporaryDirectory() as tmpdir:
         # TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
         launch_config = LaunchConfig(
@@ -315,7 +315,7 @@ def start_model_parallel_process(
 
     # wait until the model is loaded; rank 0 will send a message to indicate it's ready
     request_socket.send(encode_msg(ReadyRequest()))
-    response = request_socket.recv()
+    _response = request_socket.recv()
     log.info("Loaded model...")
 
     return request_socket, process
@@ -349,7 +349,10 @@ class ModelParallelProcessGroup:
         self.started = False
 
     def run_inference(
-        self, req: Union[CompletionRequest, ChatCompletionRequest]
+        self,
+        req: Union[
+            CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
+        ],
     ) -> Generator:
         assert not self.running, "inference already running"
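
For readers unfamiliar with this file: the hunks above change the payload types of the tagged-union messages that the model-parallel worker queue exchanges. Below is a minimal, self-contained sketch (not part of the patch) of that discriminated-union pattern. The enum members and class names mirror identifiers visible in the diff, but the "task" and "result" payloads are simplified string placeholders rather than the actual llama_stack types, which after this patch use the *WithRawContent request union.

from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class ProcessingMessageName(str, Enum):
    task_request = "task_request"
    task_response = "task_response"


class TaskRequest(BaseModel):
    # The Literal tag lets pydantic discriminate the union when decoding.
    type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
    # Placeholder payload; the real field is the Union of the *WithRawContent requests.
    task: str


class TaskResponse(BaseModel):
    type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
    # Placeholder payload; the real message carries TokenResult data.
    result: str


ProcessingMessage = Annotated[
    Union[TaskRequest, TaskResponse], Field(discriminator="type")
]

# Decoding a serialized message selects the right model based on the "type" tag.
msg = TypeAdapter(ProcessingMessage).validate_json(
    '{"type": "task_request", "task": "hello"}'
)
assert isinstance(msg, TaskRequest)

Because both the queue message payload (TaskRequest.task) and ModelParallelProcessGroup.run_inference() now reference the same *WithRawContent union, requests are expected to be converted to their raw-content form before they reach the process group; swapping the Union members changes what the queue accepts without touching the encode/decode plumbing.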