diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index 27e086e0f..37aef5ede 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -35,12 +35,14 @@ from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
 from llama_models.sku_list import resolve_model
-from termcolor import cprint
 
 from llama_stack.apis.inference import QuantizationType
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
 
+from pydantic import BaseModel
+from termcolor import cprint
+
 from .config import MetaReferenceImplConfig
 
 
@@ -58,8 +60,7 @@ def model_checkpoint_dir(model) -> str:
     return str(checkpoint_dir)
 
 
-@dataclass
-class TokenResult:
+class TokenResult(BaseModel):
     token: int
     text: str
     logprobs: Optional[List[float]] = None
diff --git a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
index 833f99efd..798fadcbe 100644
--- a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
+++ b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
@@ -17,17 +17,7 @@ from llama_models.sku_list import resolve_model
 
 from .config import MetaReferenceImplConfig
 from .generation import Llama, model_checkpoint_dir
-from .parallel_utils import ModelParallelProcessGroup
-
-
-@dataclass
-class InferenceArgs:
-    messages: List[Message]
-    temperature: float
-    top_p: float
-    max_gen_len: int
-    logprobs: bool
-    tool_prompt_format: ToolPromptFormat
+from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
 
 
 class ModelRunner:
@@ -102,7 +92,7 @@ class LlamaModelParallelGenerator:
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
-            logprobs=logprobs,
+            logprobs=logprobs or False,
             tool_prompt_format=tool_prompt_format,
         )
 
diff --git a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
index 180f7de1f..c6eacc73c 100644
--- a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
+++ b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 import multiprocessing
 import os
-import pickle
 import tempfile
 import time
 import uuid
-
-from typing import Callable, Generator
+from enum import Enum
+from typing import Any, Callable, Generator, List, Literal, Optional, Union
 
 import torch
 
@@ -23,17 +23,106 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_src_rank,
 )
 
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
+
+from pydantic import BaseModel, Field
+
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from typing_extensions import Annotated
+
+from .generation import TokenResult
 
-_END_SENTINEL = "__end_sentinel__"
-_CANCEL_SENTINEL = "__cancel_sentinel__"
 
+class InferenceArgs(BaseModel):
+    messages: List[Message]
+    temperature: float
+    top_p: float
+    max_gen_len: int
+    logprobs: bool
+    tool_prompt_format: ToolPromptFormat
+
+
+class ProcessingMessageName(str, Enum):
+    ready_request = "ready_request"
+    ready_response = "ready_response"
+    end_sentinel = "end_sentinel"
+    cancel_sentinel = "cancel_sentinel"
+    task_request = "task_request"
+    task_response = "task_response"
+    exception_response = "exception_response"
+
+
+class ReadyRequest(BaseModel):
+    type: Literal[ProcessingMessageName.ready_request] = (
+        ProcessingMessageName.ready_request
+    )
+
+
+class ReadyResponse(BaseModel):
+    type: Literal[ProcessingMessageName.ready_response] = (
+        ProcessingMessageName.ready_response
+    )
+
+
+class EndSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.end_sentinel] = (
+        ProcessingMessageName.end_sentinel
+    )
+
+
+class CancelSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.cancel_sentinel] = (
+        ProcessingMessageName.cancel_sentinel
+    )
+
+
+class TaskRequest(BaseModel):
+    type: Literal[ProcessingMessageName.task_request] = (
+        ProcessingMessageName.task_request
+    )
+    task: InferenceArgs
+
+
+class TaskResponse(BaseModel):
+    type: Literal[ProcessingMessageName.task_response] = (
+        ProcessingMessageName.task_response
+    )
+    result: TokenResult
+
+
+class ExceptionResponse(BaseModel):
+    type: Literal[ProcessingMessageName.exception_response] = (
+        ProcessingMessageName.exception_response
+    )
+    error: str
+
+
+ProcessingMessage = Union[
+    ReadyRequest,
+    ReadyResponse,
+    EndSentinel,
+    CancelSentinel,
+    TaskRequest,
+    TaskResponse,
+    ExceptionResponse,
+]
+
+
+class ProcessingMessageWrapper(BaseModel):
+    payload: Annotated[
+        ProcessingMessage,
+        Field(discriminator="type"),
+    ]
 
 
 def mp_rank_0() -> bool:
     return get_model_parallel_rank() == 0
 
 
+def encode_msg(msg: ProcessingMessage) -> bytes:
+    return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")
+
+
 def retrieve_requests(reply_socket_url: str):
     if mp_rank_0():
         context = zmq.Context()
@@ -46,21 +135,24 @@ def retrieve_requests(reply_socket_url: str):
                 time.sleep(0.01)
                 continue
 
-            reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
+            ready_response = ReadyResponse()
+            reply_socket.send_multipart([client_id, encode_msg(ready_response)])
             break
 
-    def send_obj(obj):
-        reply_socket.send_multipart([client_id, pickle.dumps(obj)])
+    def send_obj(obj: ProcessingMessage):
+        reply_socket.send_multipart([client_id, encode_msg(obj)])
 
     while True:
         tasks = [None]
         if mp_rank_0():
-            client_id, task = maybe_get_work(reply_socket)
-            # there is still an unknown unclean GeneratorExit happening resulting in a
-            # cancel sentinel getting queued _after_ we have finished sending everything :/
-            # kind of a hack this is :/
-            if task != _CANCEL_SENTINEL:
-                tasks = [task]
+            client_id, maybe_task_json = maybe_get_work(reply_socket)
+            if maybe_task_json is not None:
+                task = maybe_parse_message(maybe_task_json)
+                # there is still an unknown unclean GeneratorExit happening resulting in a
+                # cancel sentinel getting queued _after_ we have finished sending everything :/
+                # kind of a hack this is :/
+                if task is not None and not isinstance(task, CancelSentinel):
+                    tasks = [task]
 
         torch.distributed.broadcast_object_list(
             tasks,
@@ -80,35 +172,36 @@ def retrieve_requests(reply_socket_url: str):
                 for obj in out:
                     updates = [None]
                     if mp_rank_0():
-                        _, update = maybe_get_work(reply_socket)
-                        if update == _CANCEL_SENTINEL:
+                        _, update_json = maybe_get_work(reply_socket)
+                        update = maybe_parse_message(update_json)
+                        if isinstance(update, CancelSentinel):
                             updates = [update]
                         else:
                             # only send the update if it's not cancelled otherwise the object sits in the socket
                            # and gets pulled in the next request lol
-                            send_obj(obj)
+                            send_obj(TaskResponse(result=obj))
 
                     torch.distributed.broadcast_object_list(
                         updates,
                         src=get_model_parallel_src_rank(),
                         group=get_model_parallel_group(),
                     )
-                    if updates[0] == _CANCEL_SENTINEL:
+                    if isinstance(updates[0], CancelSentinel):
                         print("quitting generation loop because request was cancelled")
                         break
 
                 if mp_rank_0():
-                    send_obj(_END_SENTINEL)
+                    send_obj(EndSentinel())
             except Exception as e:
                 print(f"[debug] got exception {e}")
                 import traceback
 
                 traceback.print_exc()
                 if mp_rank_0():
-                    send_obj(e)
+                    send_obj(ExceptionResponse(error=str(e)))
 
     if mp_rank_0():
-        send_obj("DONE")
+        send_obj(EndSentinel())
 
 
 def maybe_get_work(sock: zmq.Socket):
@@ -116,7 +209,7 @@
     client_id = None
     try:
         client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
-        message = pickle.loads(obj)
+        message = obj.decode("utf-8")
     except zmq.ZMQError as e:
         if e.errno != zmq.EAGAIN:
             raise e
@@ -124,6 +217,22 @@
     return client_id, message
 
 
+def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
+    if maybe_json is None:
+        return None
+    try:
+        return parse_message(maybe_json)
+    except json.JSONDecodeError:
+        return None
+    except ValueError as e:
+        return None
+
+
+def parse_message(json_str: str) -> ProcessingMessage:
+    data = json.loads(json_str)
+    return ProcessingMessageWrapper(**data).payload
+
+
 def worker_process_entrypoint(
     reply_socket_url: str,
     init_model_cb: Callable,
@@ -142,7 +251,8 @@ def worker_process_entrypoint(
             if isinstance(task, str) and task == _END_SENTINEL:
                 break
 
-            result = model(task)
+            assert isinstance(task, TaskRequest)
+            result = model(task.task)
         except StopIteration:
             break
 
@@ -205,8 +315,8 @@ def start_model_parallel_process(
 
     # wait until the model is loaded; rank 0 will send a message to indicate it's ready
-    request_socket.send_pyobj("READY?")
-    response = request_socket.recv_pyobj()
+    request_socket.send(encode_msg(ReadyRequest()))
+    response = request_socket.recv()
     print(f"Finished model load {response}")
 
     return request_socket, process
 
@@ -235,31 +345,36 @@ class ModelParallelProcessGroup:
     def stop(self):
         assert self.started, "process group not started"
         if self.process.is_alive():
-            self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
+            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
             self.process.join()
         self.started = False
 
-    def run_inference(self, request) -> Generator:
+    def run_inference(self, inference_args: InferenceArgs) -> Generator:
        assert not self.running, "inference already running"
 
         self.running = True
-        self.request_socket.send_pyobj(request)
+        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
         try:
             while True:
-                obj = self.request_socket.recv_pyobj()
-                if obj == _END_SENTINEL:
+                obj_json = self.request_socket.recv()
+                obj = parse_message(obj_json)
+
+                if isinstance(obj, EndSentinel):
                     break
 
-                if isinstance(obj, Exception):
-                    print(f"[debug] got exception {obj}")
-                    raise obj
+                if isinstance(obj, ExceptionResponse):
+                    print(f"[debug] got exception {obj.error}")
+                    raise Exception(obj.error)
+
+                if isinstance(obj, TaskResponse):
+                    yield obj.result
 
-                yield obj
         except GeneratorExit as e:
-            self.request_socket.send_pyobj(_CANCEL_SENTINEL)
+            self.request_socket.send(encode_msg(CancelSentinel()))
             while True:
-                obj = self.request_socket.recv_pyobj()
-                if obj == _END_SENTINEL:
+                obj_json = self.request_socket.recv()
+                obj = parse_message(obj_json)
+                if isinstance(obj, EndSentinel):
                     break
         finally:
             self.running = False
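Context for reviewers, not part of the patch: the sketch below shows the discriminated-union pattern this change relies on, mirroring encode_msg() and parse_message() above. The simplified names (MessageName, Wrapper, encode, parse) are illustrative stand-ins, and it assumes pydantic v2 (model_dump_json).

import json
from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class MessageName(str, Enum):
    cancel_sentinel = "cancel_sentinel"
    end_sentinel = "end_sentinel"


class CancelSentinel(BaseModel):
    # the "type" field doubles as the discriminator tag in the wire format
    type: Literal[MessageName.cancel_sentinel] = MessageName.cancel_sentinel


class EndSentinel(BaseModel):
    type: Literal[MessageName.end_sentinel] = MessageName.end_sentinel


WireMessage = Union[CancelSentinel, EndSentinel]


class Wrapper(BaseModel):
    # pydantic picks the concrete model by inspecting the "type" field
    payload: Annotated[WireMessage, Field(discriminator="type")]


def encode(msg: WireMessage) -> bytes:
    # serialize to JSON bytes, as the zmq sockets in the patch do
    return Wrapper(payload=msg).model_dump_json().encode("utf-8")


def parse(raw: bytes) -> WireMessage:
    # validating through the wrapper recovers the original concrete type
    data = json.loads(raw.decode("utf-8"))
    return Wrapper(**data).payload


assert isinstance(parse(encode(CancelSentinel())), CancelSentinel)

Compared with the previous pickle-based sentinels, every message now has an explicit, typed schema, and only JSON crosses the process boundary.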