JSON serialization for parallel processing queue (#232)

* send/recv pydantic json over socket

* fixup

* address feedback

* bidirectional wrapper

* second round of feedback
Authored by Dalton Flanagan on 2024-10-09 17:24:12 -04:00, committed by GitHub
parent 0f66ae0f61
commit 7a8aa775e5
3 changed files with 158 additions and 52 deletions
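
The change this commit makes is to stop pickling Python objects across the worker socket and instead send pydantic models serialized as JSON, wrapped in a single envelope whose "type" field acts as a discriminator so the receiver can parse raw bytes back into a typed message. Below is a minimal standalone sketch of that pattern, assuming pydantic v2; the Ping/Pong/Wrapper names are illustrative only, not the classes added in this commit.

import json
from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class MsgName(str, Enum):
    ping = "ping"
    pong = "pong"


class Ping(BaseModel):
    type: Literal[MsgName.ping] = MsgName.ping


class Pong(BaseModel):
    type: Literal[MsgName.pong] = MsgName.pong
    payload: str


class Wrapper(BaseModel):
    # the "type" field tells pydantic which concrete message class to instantiate
    msg: Annotated[Union[Ping, Pong], Field(discriminator="type")]


def encode(m) -> bytes:
    return Wrapper(msg=m).model_dump_json().encode("utf-8")


def decode(raw: bytes):
    return Wrapper(**json.loads(raw.decode("utf-8"))).msg


assert isinstance(decode(encode(Pong(payload="ok"))), Pong)

The diffs below apply the same idea with ReadyRequest/ReadyResponse, TaskRequest/TaskResponse, EndSentinel, CancelSentinel, and ExceptionResponse as the message types.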

File: generation.py

@@ -35,12 +35,14 @@ from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
 from llama_models.sku_list import resolve_model
-from termcolor import cprint
 
 from llama_stack.apis.inference import QuantizationType
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from pydantic import BaseModel
+from termcolor import cprint
 
 from .config import MetaReferenceImplConfig
@@ -58,8 +60,7 @@ def model_checkpoint_dir(model) -> str:
     return str(checkpoint_dir)
 
 
-@dataclass
-class TokenResult:
+class TokenResult(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None

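A quick aside on the TokenResult change above: moving it from a dataclass to a pydantic BaseModel is what allows it to be nested inside the JSON task-response messages introduced in parallel_utils.py below. A tiny standalone round trip using pydantic v2's model_dump_json / model_validate_json:

from typing import List, Optional

from pydantic import BaseModel


class TokenResult(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None


raw = TokenResult(token=1, text="hi").model_dump_json()
# raw == '{"token":1,"text":"hi","logprobs":null}'
assert TokenResult.model_validate_json(raw).text == "hi"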
File: model_parallel.py

@@ -17,17 +17,7 @@ from llama_models.sku_list import resolve_model
 
 from .config import MetaReferenceImplConfig
 from .generation import Llama, model_checkpoint_dir
-from .parallel_utils import ModelParallelProcessGroup
-
-
-@dataclass
-class InferenceArgs:
-    messages: List[Message]
-    temperature: float
-    top_p: float
-    max_gen_len: int
-    logprobs: bool
-    tool_prompt_format: ToolPromptFormat
+from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
 
 
 class ModelRunner:
@@ -102,7 +92,7 @@ class LlamaModelParallelGenerator:
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
-            logprobs=logprobs,
+            logprobs=logprobs or False,
             tool_prompt_format=tool_prompt_format,
        )

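InferenceArgs now lives in parallel_utils.py as a pydantic model (see the next file), so the generator hands run_inference a typed, JSON-serializable request; note the logprobs or False normalization above, since the field is declared as a plain bool. A hedged sketch of what that construction might look like, written as if it sat next to these modules; UserMessage and ToolPromptFormat.json are assumed to come from llama_models.llama3.api.datatypes as in the imports elsewhere in this commit.

from llama_models.llama3.api.datatypes import ToolPromptFormat, UserMessage  # assumed imports

from .parallel_utils import InferenceArgs, TaskRequest, encode_msg

inference_args = InferenceArgs(
    messages=[UserMessage(content="Hello!")],
    temperature=0.7,
    top_p=0.95,
    max_gen_len=64,
    logprobs=False,
    tool_prompt_format=ToolPromptFormat.json,
)
# what actually crosses the socket: UTF-8 encoded JSON bytes
wire_bytes = encode_msg(TaskRequest(task=inference_args))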
File: parallel_utils.py

@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 import multiprocessing
 import os
-import pickle
 import tempfile
 import time
 import uuid
-from typing import Callable, Generator
+from enum import Enum
+from typing import Any, Callable, Generator, List, Literal, Optional, Union
 
 import torch
@@ -23,17 +23,106 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_src_rank,
 )
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
+from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from typing_extensions import Annotated
 
-_END_SENTINEL = "__end_sentinel__"
-_CANCEL_SENTINEL = "__cancel_sentinel__"
+from .generation import TokenResult
+
+
+class InferenceArgs(BaseModel):
+    messages: List[Message]
+    temperature: float
+    top_p: float
+    max_gen_len: int
+    logprobs: bool
+    tool_prompt_format: ToolPromptFormat
+
+
+class ProcessingMessageName(str, Enum):
+    ready_request = "ready_request"
+    ready_response = "ready_response"
+    end_sentinel = "end_sentinel"
+    cancel_sentinel = "cancel_sentinel"
+    task_request = "task_request"
+    task_response = "task_response"
+    exception_response = "exception_response"
+
+
+class ReadyRequest(BaseModel):
+    type: Literal[ProcessingMessageName.ready_request] = (
+        ProcessingMessageName.ready_request
+    )
+
+
+class ReadyResponse(BaseModel):
+    type: Literal[ProcessingMessageName.ready_response] = (
+        ProcessingMessageName.ready_response
+    )
+
+
+class EndSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.end_sentinel] = (
+        ProcessingMessageName.end_sentinel
+    )
+
+
+class CancelSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.cancel_sentinel] = (
+        ProcessingMessageName.cancel_sentinel
+    )
+
+
+class TaskRequest(BaseModel):
+    type: Literal[ProcessingMessageName.task_request] = (
+        ProcessingMessageName.task_request
+    )
+    task: InferenceArgs
+
+
+class TaskResponse(BaseModel):
+    type: Literal[ProcessingMessageName.task_response] = (
+        ProcessingMessageName.task_response
+    )
+    result: TokenResult
+
+
+class ExceptionResponse(BaseModel):
+    type: Literal[ProcessingMessageName.exception_response] = (
+        ProcessingMessageName.exception_response
+    )
+    error: str
+
+
+ProcessingMessage = Union[
+    ReadyRequest,
+    ReadyResponse,
+    EndSentinel,
+    CancelSentinel,
+    TaskRequest,
+    TaskResponse,
+    ExceptionResponse,
+]
+
+
+class ProcessingMessageWrapper(BaseModel):
+    payload: Annotated[
+        ProcessingMessage,
+        Field(discriminator="type"),
+    ]
 
 
 def mp_rank_0() -> bool:
     return get_model_parallel_rank() == 0
 
 
+def encode_msg(msg: ProcessingMessage) -> bytes:
+    return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")
+
+
 def retrieve_requests(reply_socket_url: str):
     if mp_rank_0():
         context = zmq.Context()
@@ -46,21 +135,24 @@ def retrieve_requests(reply_socket_url: str):
                 time.sleep(0.01)
                 continue
 
-            reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
+            ready_response = ReadyResponse()
+            reply_socket.send_multipart([client_id, encode_msg(ready_response)])
             break
 
-        def send_obj(obj):
-            reply_socket.send_multipart([client_id, pickle.dumps(obj)])
+        def send_obj(obj: ProcessingMessage):
+            reply_socket.send_multipart([client_id, encode_msg(obj)])
 
     while True:
         tasks = [None]
         if mp_rank_0():
-            client_id, task = maybe_get_work(reply_socket)
-            # there is still an unknown unclean GeneratorExit happening resulting in a
-            # cancel sentinel getting queued _after_ we have finished sending everything :/
-            # kind of a hack this is :/
-            if task != _CANCEL_SENTINEL:
-                tasks = [task]
+            client_id, maybe_task_json = maybe_get_work(reply_socket)
+            if maybe_task_json is not None:
+                task = maybe_parse_message(maybe_task_json)
+                # there is still an unknown unclean GeneratorExit happening resulting in a
+                # cancel sentinel getting queued _after_ we have finished sending everything :/
+                # kind of a hack this is :/
+                if task is not None and not isinstance(task, CancelSentinel):
+                    tasks = [task]
 
         torch.distributed.broadcast_object_list(
             tasks,
@@ -80,35 +172,36 @@ def retrieve_requests(reply_socket_url: str):
             for obj in out:
                 updates = [None]
                 if mp_rank_0():
-                    _, update = maybe_get_work(reply_socket)
-                    if update == _CANCEL_SENTINEL:
+                    _, update_json = maybe_get_work(reply_socket)
+                    update = maybe_parse_message(update_json)
+                    if isinstance(update, CancelSentinel):
                         updates = [update]
                     else:
                         # only send the update if it's not cancelled otherwise the object sits in the socket
                         # and gets pulled in the next request lol
-                        send_obj(obj)
+                        send_obj(TaskResponse(result=obj))
 
                 torch.distributed.broadcast_object_list(
                     updates,
                     src=get_model_parallel_src_rank(),
                     group=get_model_parallel_group(),
                 )
-                if updates[0] == _CANCEL_SENTINEL:
+                if isinstance(updates[0], CancelSentinel):
                     print("quitting generation loop because request was cancelled")
                     break
 
             if mp_rank_0():
-                send_obj(_END_SENTINEL)
+                send_obj(EndSentinel())
         except Exception as e:
             print(f"[debug] got exception {e}")
             import traceback
 
             traceback.print_exc()
             if mp_rank_0():
-                send_obj(e)
+                send_obj(ExceptionResponse(error=str(e)))
 
     if mp_rank_0():
-        send_obj("DONE")
+        send_obj(EndSentinel())
 
 
 def maybe_get_work(sock: zmq.Socket):
@@ -116,7 +209,7 @@ def maybe_get_work(sock: zmq.Socket):
     client_id = None
     try:
         client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
-        message = pickle.loads(obj)
+        message = obj.decode("utf-8")
     except zmq.ZMQError as e:
         if e.errno != zmq.EAGAIN:
             raise e
@@ -124,6 +217,22 @@
     return client_id, message
 
 
+def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
+    if maybe_json is None:
+        return None
+    try:
+        return parse_message(maybe_json)
+    except json.JSONDecodeError:
+        return None
+    except ValueError as e:
+        return None
+
+
+def parse_message(json_str: str) -> ProcessingMessage:
+    data = json.loads(json_str)
+    return ProcessingMessageWrapper(**data).payload
+
+
 def worker_process_entrypoint(
     reply_socket_url: str,
     init_model_cb: Callable,
@@ -142,7 +251,8 @@ def worker_process_entrypoint(
             if isinstance(task, str) and task == _END_SENTINEL:
                 break
 
-            result = model(task)
+            assert isinstance(task, TaskRequest)
+            result = model(task.task)
         except StopIteration:
             break
@@ -205,8 +315,8 @@
     # wait until the model is loaded; rank 0 will send a message to indicate it's ready
 
-    request_socket.send_pyobj("READY?")
-    response = request_socket.recv_pyobj()
+    request_socket.send(encode_msg(ReadyRequest()))
+    response = request_socket.recv()
     print(f"Finished model load {response}")
 
     return request_socket, process
@@ -235,31 +345,36 @@ class ModelParallelProcessGroup:
     def stop(self):
         assert self.started, "process group not started"
         if self.process.is_alive():
-            self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
+            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
             self.process.join()
         self.started = False
 
-    def run_inference(self, request) -> Generator:
+    def run_inference(self, inference_args: InferenceArgs) -> Generator:
         assert not self.running, "inference already running"
 
         self.running = True
-        self.request_socket.send_pyobj(request)
+        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
         try:
             while True:
-                obj = self.request_socket.recv_pyobj()
-                if obj == _END_SENTINEL:
+                obj_json = self.request_socket.recv()
+                obj = parse_message(obj_json)
+                if isinstance(obj, EndSentinel):
                     break
 
-                if isinstance(obj, Exception):
-                    print(f"[debug] got exception {obj}")
-                    raise obj
+                if isinstance(obj, ExceptionResponse):
+                    print(f"[debug] got exception {obj.error}")
+                    raise Exception(obj.error)
 
-                yield obj
+                if isinstance(obj, TaskResponse):
+                    yield obj.result
 
         except GeneratorExit as e:
-            self.request_socket.send_pyobj(_CANCEL_SENTINEL)
+            self.request_socket.send(encode_msg(CancelSentinel()))
             while True:
-                obj = self.request_socket.recv_pyobj()
-                if obj == _END_SENTINEL:
+                obj_json = self.request_socket.recv()
+                obj = parse_message(obj_json)
+                if isinstance(obj, EndSentinel):
                     break
         finally:
             self.running = False
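
To close the loop, here is a small hedged sketch of how the new helpers behave, assuming they are importable alongside this parallel_utils module: encode_msg and parse_message round-trip a typed message through JSON, while maybe_parse_message returns None on missing or malformed input instead of raising, which is what keeps the generation loop alive when a stray payload shows up.

from .parallel_utils import (  # assumed import path, sibling to this module
    CancelSentinel,
    EndSentinel,
    encode_msg,
    maybe_parse_message,
    parse_message,
)

wire = encode_msg(CancelSentinel()).decode("utf-8")
assert isinstance(parse_message(wire), CancelSentinel)

# tolerant parsing: bad payloads become None rather than exceptions
assert maybe_parse_message(None) is None
assert maybe_parse_message("not json at all") is None
assert isinstance(maybe_parse_message(encode_msg(EndSentinel()).decode("utf-8")), EndSentinel)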