diff --git a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
index e83d58526..5b5372240 100644
--- a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
+++ b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
@@ -10,7 +10,8 @@ import os
 import tempfile
 import time
 import uuid
-from typing import Any, Callable, Generator, List, Optional, Union
+from enum import Enum
+from typing import Any, Callable, Generator, List, Literal, Optional, Union
 
 import torch
 
@@ -24,9 +25,10 @@ from fairscale.nn.model_parallel.initialize import (
 
 from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from typing_extensions import Annotated
 
 from .generation import TokenResult
 
@@ -39,40 +41,67 @@ class InferenceArgs(BaseModel):
     logprobs: bool
     tool_prompt_format: ToolPromptFormat
 
-    # avoid converting llama_models types to BaseModel for now
-    class Config:
-        arbitrary_types_allowed = True
+
+class ProcessingMessageName(str, Enum):
+    ready_request = "ready_request"
+    ready_response = "ready_response"
+    end_sentinel = "end_sentinel"
+    cancel_sentinel = "cancel_sentinel"
+    task_request = "task_request"
+    task_response = "task_response"
+    exception_response = "exception_response"
 
 
 class ReadyRequest(BaseModel):
+    type: Literal[ProcessingMessageName.ready_request] = (
+        ProcessingMessageName.ready_request
+    )
     message: str = "READY?"
 
 
 class ReadyResponse(BaseModel):
+    type: Literal[ProcessingMessageName.ready_response] = (
+        ProcessingMessageName.ready_response
+    )
     message: str = "YES READY"
 
 
 class EndSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.end_sentinel] = (
+        ProcessingMessageName.end_sentinel
+    )
     message: str = "__end_sentinel__"
 
 
 class CancelSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.cancel_sentinel] = (
+        ProcessingMessageName.cancel_sentinel
+    )
     message: str = "__cancel_sentinel__"
 
 
 class TaskRequest(BaseModel):
+    type: Literal[ProcessingMessageName.task_request] = (
+        ProcessingMessageName.task_request
+    )
     task: InferenceArgs
 
 
 class TaskResponse(BaseModel):
+    type: Literal[ProcessingMessageName.task_response] = (
+        ProcessingMessageName.task_response
+    )
     result: TokenResult
 
 
 class ExceptionResponse(BaseModel):
+    type: Literal[ProcessingMessageName.exception_response] = (
+        ProcessingMessageName.exception_response
+    )
     error: str
 
 
-Message = Union[
+ProcessingMessage = Union[
     ReadyRequest,
     ReadyResponse,
     EndSentinel,
@@ -83,10 +112,21 @@ Message = Union[
 ]
 
 
+class ProcessingMessageWrapper(BaseModel):
+    payload: Annotated[
+        ProcessingMessage,
+        Field(discriminator="type"),
+    ]
+
+
 def mp_rank_0() -> bool:
     return get_model_parallel_rank() == 0
 
 
+def encode_msg(msg: ProcessingMessage) -> bytes:
+    return msg.model_dump_json().encode("utf-8")
+
+
 def retrieve_requests(reply_socket_url: str):
     if mp_rank_0():
         context = zmq.Context()
@@ -100,13 +140,11 @@ def retrieve_requests(reply_socket_url: str):
                 continue
 
             ready_response = ReadyResponse()
-            reply_socket.send_multipart(
-                [client_id, ready_response.model_dump_json().encode("utf-8")]
-            )
+            reply_socket.send_multipart([client_id, encode_msg(ready_response)])
             break
 
-        def send_obj(obj: Union[Message, BaseModel]):
-            reply_socket.send_multipart([client_id, obj.model_dump_json().encode("utf-8")])
+        def send_obj(obj: Union[ProcessingMessage, BaseModel]):
+            reply_socket.send_multipart([client_id, encode_msg(obj)])
 
         while True:
             tasks = [None]
@@ -183,7 +221,7 @@ def maybe_get_work(sock: zmq.Socket):
     return client_id, message
 
 
-def maybe_parse_message(maybe_json: Optional[str]) -> Optional[Message]:
+def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
     if maybe_json is None:
         return None
     try:
@@ -194,25 +232,9 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[Message]:
         return None
 
 
-def parse_message(json_str: str) -> Message:
+def parse_message(json_str: str) -> ProcessingMessage:
     data = json.loads(json_str)
-    if "message" in data:
-        if data["message"] == "READY?":
-            return ReadyRequest.model_validate(data)
-        elif data["message"] == "YES READY":
-            return ReadyResponse.model_validate(data)
-        elif data["message"] == "__end_sentinel__":
-            return EndSentinel.model_validate(data)
-        elif data["message"] == "__cancel_sentinel__":
-            return CancelSentinel.model_validate(data)
-    elif "task" in data:
-        return TaskRequest.model_validate(data)
-    elif "result" in data:
-        return TaskResponse.model_validate(data)
-    elif "error" in data:
-        return ExceptionResponse.model_validate(data)
-    else:
-        raise ValueError(f"Unknown message type: {data}")
+    return ProcessingMessageWrapper(**{"payload": data}).payload
 
 
 def worker_process_entrypoint(
@@ -297,7 +319,7 @@ def start_model_parallel_process(
 
 
     # wait until the model is loaded; rank 0 will send a message to indicate it's ready
-    request_socket.send(ReadyRequest().model_dump_json().encode("utf-8"))
+    request_socket.send(encode_msg(ReadyRequest()))
 
     response = request_socket.recv()
     print(f"Finished model load {response}")
@@ -327,9 +349,7 @@ class ModelParallelProcessGroup:
     def stop(self):
         assert self.started, "process group not started"
         if self.process.is_alive():
-            self.request_socket.send(
-                EndSentinel().model_dump_json().encode("utf-8"), zmq.NOBLOCK
-            )
+            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
             self.process.join()
         self.started = False
 
@@ -337,9 +357,7 @@ class ModelParallelProcessGroup:
         assert not self.running, "inference already running"
 
         self.running = True
-        self.request_socket.send(
-            TaskRequest(task=inference_args).model_dump_json().encode("utf-8")
-        )
+        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
         try:
             while True:
                 obj_json = self.request_socket.recv()
@@ -356,7 +374,7 @@ class ModelParallelProcessGroup:
                 yield obj.result
 
         except GeneratorExit as e:
-            self.request_socket.send(CancelSentinel().model_dump_json().encode("utf-8"))
+            self.request_socket.send(encode_msg(CancelSentinel()))
             while True:
                 obj_json = self.request_socket.send()
                 obj = parse_message(obj_json)
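For reviewers, a minimal, self-contained sketch of the round trip this diff sets up (assumes pydantic v2 and `typing_extensions`, matching the imports above; only two of the seven message types are reproduced, and the zmq plumbing is omitted):

```python
import json
from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class ProcessingMessageName(str, Enum):
    ready_request = "ready_request"
    ready_response = "ready_response"


class ReadyRequest(BaseModel):
    type: Literal[ProcessingMessageName.ready_request] = (
        ProcessingMessageName.ready_request
    )
    message: str = "READY?"


class ReadyResponse(BaseModel):
    type: Literal[ProcessingMessageName.ready_response] = (
        ProcessingMessageName.ready_response
    )
    message: str = "YES READY"


ProcessingMessage = Union[ReadyRequest, ReadyResponse]


class ProcessingMessageWrapper(BaseModel):
    # pydantic dispatches on the "type" field, replacing the old if/elif
    # chain that sniffed for "message"/"task"/"result"/"error" keys
    payload: Annotated[ProcessingMessage, Field(discriminator="type")]


def encode_msg(msg: ProcessingMessage) -> bytes:
    # the bare message goes over the wire, e.g.
    # b'{"type":"ready_request","message":"READY?"}'
    return msg.model_dump_json().encode("utf-8")


def parse_message(json_str: str) -> ProcessingMessage:
    # nest the decoded dict under "payload" so the wrapper's
    # discriminated union selects the concrete model
    data = json.loads(json_str)
    return ProcessingMessageWrapper(**{"payload": data}).payload


msg = parse_message(encode_msg(ReadyRequest()).decode("utf-8"))
assert isinstance(msg, ReadyRequest)
```

One behavioral consequence worth noting: an unknown `type` tag now surfaces as a pydantic `ValidationError` listing the expected discriminator values, rather than the old `ValueError(f"Unknown message type: {data}")`.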