diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index 27e086e0f..37aef5ede 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -35,12 +35,14 @@ from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
 from llama_models.sku_list import resolve_model
-from termcolor import cprint
 
 from llama_stack.apis.inference import QuantizationType
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from pydantic import BaseModel
+from termcolor import cprint
+
 from .config import MetaReferenceImplConfig
 
 
@@ -58,8 +60,7 @@ def model_checkpoint_dir(model) -> str:
     return str(checkpoint_dir)
 
 
-@dataclass
-class TokenResult:
+class TokenResult(BaseModel):
     token: int
     text: str
     logprobs: Optional[List[float]] = None
diff --git a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
index 833f99efd..46ac3778c 100644
--- a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
+++ b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
@@ -17,17 +17,7 @@ from llama_models.sku_list import resolve_model
 from .config import MetaReferenceImplConfig
 from .generation import Llama, model_checkpoint_dir
-from .parallel_utils import ModelParallelProcessGroup
-
-
-@dataclass
-class InferenceArgs:
-    messages: List[Message]
-    temperature: float
-    top_p: float
-    max_gen_len: int
-    logprobs: bool
-    tool_prompt_format: ToolPromptFormat
+from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
 
 
 class ModelRunner:
@@ -97,12 +87,13 @@ class LlamaModelParallelGenerator:
         logprobs: bool = False,
         tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> Generator:
+        print("logprobs", logprobs)
         req_obj = InferenceArgs(
             messages=deepcopy(messages),
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
-            logprobs=logprobs,
+            logprobs=logprobs or False,
             tool_prompt_format=tool_prompt_format,
         )
 
diff --git a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
index 180f7de1f..e83d58526 100644
--- a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
+++ b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
@@ -4,14 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 import multiprocessing
 import os
-import pickle
 import tempfile
 import time
 import uuid
-
-from typing import Callable, Generator
+from typing import Any, Callable, Generator, List, Optional, Union
 
 import torch
 
@@ -23,11 +22,65 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_src_rank,
 )
 
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
+
+from pydantic import BaseModel
+
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 
+from .generation import TokenResult
 
-_END_SENTINEL = "__end_sentinel__"
-_CANCEL_SENTINEL = "__cancel_sentinel__"
+
+class InferenceArgs(BaseModel):
+    messages: List[Message]
+    temperature: float
+    top_p: float
+    max_gen_len: int
+    logprobs: bool
+    tool_prompt_format: ToolPromptFormat
+
+    # avoid converting llama_models types to BaseModel for now
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class ReadyRequest(BaseModel):
+    message: str = "READY?"
+
+
+class ReadyResponse(BaseModel):
+    message: str = "YES READY"
+
+
+class EndSentinel(BaseModel):
+    message: str = "__end_sentinel__"
+
+
+class CancelSentinel(BaseModel):
+    message: str = "__cancel_sentinel__"
+
+
+class TaskRequest(BaseModel):
+    task: InferenceArgs
+
+
+class TaskResponse(BaseModel):
+    result: TokenResult
+
+
+class ExceptionResponse(BaseModel):
+    error: str
+
+
+Message = Union[
+    ReadyRequest,
+    ReadyResponse,
+    EndSentinel,
+    CancelSentinel,
+    TaskRequest,
+    TaskResponse,
+    ExceptionResponse,
+]
 
 
 def mp_rank_0() -> bool:
@@ -46,21 +99,26 @@ def retrieve_requests(reply_socket_url: str):
                 time.sleep(0.01)
                 continue
 
-            reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
+            ready_response = ReadyResponse()
+            reply_socket.send_multipart(
+                [client_id, ready_response.model_dump_json().encode("utf-8")]
+            )
             break
 
-    def send_obj(obj):
-        reply_socket.send_multipart([client_id, pickle.dumps(obj)])
+    def send_obj(obj: Union[Message, BaseModel]):
+        reply_socket.send_multipart([client_id, obj.model_dump_json().encode("utf-8")])
 
     while True:
         tasks = [None]
         if mp_rank_0():
-            client_id, task = maybe_get_work(reply_socket)
-            # there is still an unknown unclean GeneratorExit happening resulting in a
-            # cancel sentinel getting queued _after_ we have finished sending everything :/
-            # kind of a hack this is :/
-            if task != _CANCEL_SENTINEL:
-                tasks = [task]
+            client_id, maybe_task_json = maybe_get_work(reply_socket)
+            if maybe_task_json is not None:
+                task = maybe_parse_message(maybe_task_json)
+                # there is still an unknown unclean GeneratorExit happening resulting in a
+                # cancel sentinel getting queued _after_ we have finished sending everything :/
+                # kind of a hack this is :/
+                if task is not None and not isinstance(task, CancelSentinel):
+                    tasks = [task]
 
         torch.distributed.broadcast_object_list(
             tasks,
@@ -80,35 +138,36 @@ def retrieve_requests(reply_socket_url: str):
                 for obj in out:
                     updates = [None]
                     if mp_rank_0():
-                        _, update = maybe_get_work(reply_socket)
-                        if update == _CANCEL_SENTINEL:
+                        _, update_json = maybe_get_work(reply_socket)
+                        update = maybe_parse_message(update_json)
+                        if isinstance(update, CancelSentinel):
                             updates = [update]
                         else:
                             # only send the update if it's not cancelled otherwise the object sits in the socket
                             # and gets pulled in the next request lol
-                            send_obj(obj)
+                            send_obj(TaskResponse(result=obj))
 
                     torch.distributed.broadcast_object_list(
                         updates,
                         src=get_model_parallel_src_rank(),
                         group=get_model_parallel_group(),
                     )
-                    if updates[0] == _CANCEL_SENTINEL:
+                    if isinstance(updates[0], CancelSentinel):
                         print("quitting generation loop because request was cancelled")
                         break
 
                 if mp_rank_0():
-                    send_obj(_END_SENTINEL)
+                    send_obj(EndSentinel())
             except Exception as e:
                 print(f"[debug] got exception {e}")
                 import traceback
 
                 traceback.print_exc()
 
                 if mp_rank_0():
-                    send_obj(e)
+                    send_obj(ExceptionResponse(error=str(e)))
 
     if mp_rank_0():
-        send_obj("DONE")
+        send_obj(EndSentinel())
 
 
 def maybe_get_work(sock: zmq.Socket):
@@ -116,7 +175,7 @@ def maybe_get_work(sock: zmq.Socket):
     client_id = None
     try:
         client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
-        message = pickle.loads(obj)
+        message = obj.decode("utf-8")
     except zmq.ZMQError as e:
         if e.errno != zmq.EAGAIN:
             raise e
@@ -124,6 +183,38 @@
     return client_id, message
 
 
+def maybe_parse_message(maybe_json: Optional[str]) -> Optional[Message]:
+    if maybe_json is None:
+        return None
+    try:
+        return parse_message(maybe_json)
+    except json.JSONDecodeError:
+        return None
+    except ValueError as e:
+        return None
+
+
+def parse_message(json_str: str) -> Message:
+    data = json.loads(json_str)
+    if "message" in data:
+        if data["message"] == "READY?":
+            return ReadyRequest.model_validate(data)
+        elif data["message"] == "YES READY":
+            return ReadyResponse.model_validate(data)
+        elif data["message"] == "__end_sentinel__":
+            return EndSentinel.model_validate(data)
+        elif data["message"] == "__cancel_sentinel__":
+            return CancelSentinel.model_validate(data)
+    elif "task" in data:
+        return TaskRequest.model_validate(data)
+    elif "result" in data:
+        return TaskResponse.model_validate(data)
+    elif "error" in data:
+        return ExceptionResponse.model_validate(data)
+    else:
+        raise ValueError(f"Unknown message type: {data}")
+
+
 def worker_process_entrypoint(
     reply_socket_url: str,
     init_model_cb: Callable,
@@ -142,7 +233,8 @@ def worker_process_entrypoint(
             if isinstance(task, str) and task == _END_SENTINEL:
                 break
 
-            result = model(task)
+            assert isinstance(task, TaskRequest)
+            result = model(task.task)
         except StopIteration:
             break
 
@@ -205,8 +297,8 @@ def start_model_parallel_process(
     # wait until the model is loaded; rank 0 will send a message to indicate it's ready
-    request_socket.send_pyobj("READY?")
-    response = request_socket.recv_pyobj()
+    request_socket.send(ReadyRequest().model_dump_json().encode("utf-8"))
+    response = request_socket.recv()
     print(f"Finished model load {response}")
 
     return request_socket, process
 
@@ -235,31 +327,40 @@ class ModelParallelProcessGroup:
     def stop(self):
         assert self.started, "process group not started"
         if self.process.is_alive():
-            self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
+            self.request_socket.send(
+                EndSentinel().model_dump_json().encode("utf-8"), zmq.NOBLOCK
+            )
             self.process.join()
         self.started = False
 
-    def run_inference(self, request) -> Generator:
+    def run_inference(self, inference_args: InferenceArgs) -> Generator:
         assert not self.running, "inference already running"
 
         self.running = True
-        self.request_socket.send_pyobj(request)
+        self.request_socket.send(
+            TaskRequest(task=inference_args).model_dump_json().encode("utf-8")
+        )
         try:
             while True:
-                obj = self.request_socket.recv_pyobj()
-                if obj == _END_SENTINEL:
+                obj_json = self.request_socket.recv()
+                obj = parse_message(obj_json)
+
+                if isinstance(obj, EndSentinel):
                     break
 
-                if isinstance(obj, Exception):
-                    print(f"[debug] got exception {obj}")
-                    raise obj
+                if isinstance(obj, ExceptionResponse):
+                    print(f"[debug] got exception {obj.error}")
+                    raise Exception(obj.error)
+
+                if isinstance(obj, TaskResponse):
+                    yield obj.result
 
-                yield obj
         except GeneratorExit as e:
-            self.request_socket.send_pyobj(_CANCEL_SENTINEL)
+            self.request_socket.send(CancelSentinel().model_dump_json().encode("utf-8"))
             while True:
-                obj = self.request_socket.recv_pyobj()
-                if obj == _END_SENTINEL:
+                obj_json = self.request_socket.recv()
+                obj = parse_message(obj_json)
+                if isinstance(obj, EndSentinel):
                     break
         finally:
             self.running = False
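For reference, a minimal round-trip sketch of the JSON wire format these changes introduce. This is a sketch only: it assumes the updated parallel_utils module and its dependencies (torch, zmq, fairscale, llama_models, pydantic v2) are installed, and that the import path mirrors the file path in the diff; adjust the path if your checkout differs.

    # Sentinels carry only a "message" field, so they round-trip without
    # constructing any llama_models objects.
    from llama_stack.providers.impls.meta_reference.inference.parallel_utils import (
        CancelSentinel,
        EndSentinel,
        parse_message,
    )

    # Encode exactly as the sockets now do: pydantic JSON, then UTF-8 bytes.
    wire_bytes = CancelSentinel().model_dump_json().encode("utf-8")

    # maybe_get_work() decodes the bytes back to str before parsing.
    obj = parse_message(wire_bytes.decode("utf-8"))
    assert isinstance(obj, CancelSentinel)

    # EndSentinel round-trips the same way; it is what terminates the client loop
    # in ModelParallelProcessGroup.run_inference().
    assert isinstance(parse_message(EndSentinel().model_dump_json()), EndSentinel)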