JSON serialization for parallel processing queue (#232)

* send/recv pydantic json over socket

* fixup

* address feedback

* bidirectional wrapper

* second round of feedback
Dalton Flanagan, 2024-10-09 17:24:12 -04:00 (committed by GitHub)
parent 0f66ae0f61
commit 7a8aa775e5
3 changed files with 158 additions and 52 deletions
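
The heart of the change is in the third file's diff below: every message that crosses the worker socket is now a pydantic model packed into a ProcessingMessageWrapper whose `type` field acts as a discriminator, and the bytes on the wire are plain JSON rather than a pickle. A minimal, self-contained sketch of the same pattern (simplified stand-in models, not the actual llama-stack classes) illustrates the round trip:

import json
from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class MsgName(str, Enum):
    ready = "ready"
    task = "task"


class Ready(BaseModel):
    type: Literal[MsgName.ready] = MsgName.ready


class Task(BaseModel):
    type: Literal[MsgName.task] = MsgName.task
    prompt: str


class Wrapper(BaseModel):
    # the "type" field tells pydantic which concrete model to rebuild on parse
    payload: Annotated[Union[Ready, Task], Field(discriminator="type")]


def encode(msg) -> bytes:
    return Wrapper(payload=msg).model_dump_json().encode("utf-8")


def decode(raw: bytes):
    return Wrapper(**json.loads(raw.decode("utf-8"))).payload


wire = encode(Task(prompt="hello"))    # plain JSON bytes, safe to push through a socket
assert isinstance(decode(wire), Task)  # parsed back into the right concrete class

decode() here plays the same role as parse_message in the diff: load the JSON, let the wrapper's discriminator pick the concrete message class, and return the payload.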


@@ -35,12 +35,14 @@ from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from termcolor import cprint
from llama_stack.apis.inference import QuantizationType
from llama_stack.distribution.utils.model_utils import model_local_dir
from pydantic import BaseModel
from termcolor import cprint
from .config import MetaReferenceImplConfig
@@ -58,8 +60,7 @@ def model_checkpoint_dir(model) -> str:
return str(checkpoint_dir)
@dataclass
class TokenResult:
class TokenResult(BaseModel):
token: int
text: str
logprobs: Optional[List[float]] = None

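Switching TokenResult from a dataclass to a pydantic BaseModel is what lets the worker stream each token back as JSON (see TaskResponse further down). A standalone sketch of the round trip, re-declaring the model locally rather than importing it:

from typing import List, Optional

from pydantic import BaseModel


class TokenResult(BaseModel):  # same fields as in the diff above
    token: int
    text: str
    logprobs: Optional[List[float]] = None


wire = TokenResult(token=42, text="hello").model_dump_json()
# -> '{"token":42,"text":"hello","logprobs":null}'
assert TokenResult.model_validate_json(wire).token == 42
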

@@ -17,17 +17,7 @@ from llama_models.sku_list import resolve_model
from .config import MetaReferenceImplConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import ModelParallelProcessGroup
@dataclass
class InferenceArgs:
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
tool_prompt_format: ToolPromptFormat
from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
class ModelRunner:
@@ -102,7 +92,7 @@ class LlamaModelParallelGenerator:
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
logprobs=logprobs or False,
tool_prompt_format=tool_prompt_format,
)

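One subtle line in this file is `logprobs=logprobs or False`. InferenceArgs is now a pydantic model whose `logprobs: bool` field has no default, so a None coming from the caller would be rejected at validation time rather than silently stored the way a dataclass would store it; the coercion keeps the field a real bool. A toy model (hypothetical, not the real InferenceArgs) shows the difference:

from pydantic import BaseModel, ValidationError


class Args(BaseModel):  # stand-in with the same strictness as InferenceArgs
    logprobs: bool


maybe_logprobs = None   # optional caller argument arriving as None
try:
    Args(logprobs=maybe_logprobs)
except ValidationError:
    print("None is not a valid bool")          # pydantic refuses the bare None

print(Args(logprobs=maybe_logprobs or False))  # logprobs=False, matching the diff's coercion
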

@@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import multiprocessing
import os
import pickle
import tempfile
import time
import uuid
from typing import Callable, Generator
from enum import Enum
from typing import Any, Callable, Generator, List, Literal, Optional, Union
import torch
@@ -23,17 +23,106 @@ from fairscale.nn.model_parallel.initialize import (
get_model_parallel_src_rank,
)
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from .generation import TokenResult
_END_SENTINEL = "__end_sentinel__"
_CANCEL_SENTINEL = "__cancel_sentinel__"
class InferenceArgs(BaseModel):
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
tool_prompt_format: ToolPromptFormat
class ProcessingMessageName(str, Enum):
ready_request = "ready_request"
ready_response = "ready_response"
end_sentinel = "end_sentinel"
cancel_sentinel = "cancel_sentinel"
task_request = "task_request"
task_response = "task_response"
exception_response = "exception_response"
class ReadyRequest(BaseModel):
type: Literal[ProcessingMessageName.ready_request] = (
ProcessingMessageName.ready_request
)
class ReadyResponse(BaseModel):
type: Literal[ProcessingMessageName.ready_response] = (
ProcessingMessageName.ready_response
)
class EndSentinel(BaseModel):
type: Literal[ProcessingMessageName.end_sentinel] = (
ProcessingMessageName.end_sentinel
)
class CancelSentinel(BaseModel):
type: Literal[ProcessingMessageName.cancel_sentinel] = (
ProcessingMessageName.cancel_sentinel
)
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
task: InferenceArgs
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = (
ProcessingMessageName.task_response
)
result: TokenResult
class ExceptionResponse(BaseModel):
type: Literal[ProcessingMessageName.exception_response] = (
ProcessingMessageName.exception_response
)
error: str
ProcessingMessage = Union[
ReadyRequest,
ReadyResponse,
EndSentinel,
CancelSentinel,
TaskRequest,
TaskResponse,
ExceptionResponse,
]
class ProcessingMessageWrapper(BaseModel):
payload: Annotated[
ProcessingMessage,
Field(discriminator="type"),
]
def mp_rank_0() -> bool:
return get_model_parallel_rank() == 0
def encode_msg(msg: ProcessingMessage) -> bytes:
return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")
def retrieve_requests(reply_socket_url: str):
if mp_rank_0():
context = zmq.Context()
@@ -46,21 +135,24 @@ def retrieve_requests(reply_socket_url: str):
time.sleep(0.01)
continue
reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
ready_response = ReadyResponse()
reply_socket.send_multipart([client_id, encode_msg(ready_response)])
break
def send_obj(obj):
reply_socket.send_multipart([client_id, pickle.dumps(obj)])
def send_obj(obj: ProcessingMessage):
reply_socket.send_multipart([client_id, encode_msg(obj)])
while True:
tasks = [None]
if mp_rank_0():
client_id, task = maybe_get_work(reply_socket)
# there is still an unknown unclean GeneratorExit happening resulting in a
# cancel sentinel getting queued _after_ we have finished sending everything :/
# kind of a hack this is :/
if task != _CANCEL_SENTINEL:
tasks = [task]
client_id, maybe_task_json = maybe_get_work(reply_socket)
if maybe_task_json is not None:
task = maybe_parse_message(maybe_task_json)
# there is still an unknown unclean GeneratorExit happening resulting in a
# cancel sentinel getting queued _after_ we have finished sending everything :/
# kind of a hack this is :/
if task is not None and not isinstance(task, CancelSentinel):
tasks = [task]
torch.distributed.broadcast_object_list(
tasks,
@@ -80,35 +172,36 @@ def retrieve_requests(reply_socket_url: str):
for obj in out:
updates = [None]
if mp_rank_0():
_, update = maybe_get_work(reply_socket)
if update == _CANCEL_SENTINEL:
_, update_json = maybe_get_work(reply_socket)
update = maybe_parse_message(update_json)
if isinstance(update, CancelSentinel):
updates = [update]
else:
# only send the update if it's not cancelled otherwise the object sits in the socket
# and gets pulled in the next request lol
send_obj(obj)
send_obj(TaskResponse(result=obj))
torch.distributed.broadcast_object_list(
updates,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
if updates[0] == _CANCEL_SENTINEL:
if isinstance(updates[0], CancelSentinel):
print("quitting generation loop because request was cancelled")
break
if mp_rank_0():
send_obj(_END_SENTINEL)
send_obj(EndSentinel())
except Exception as e:
print(f"[debug] got exception {e}")
import traceback
traceback.print_exc()
if mp_rank_0():
send_obj(e)
send_obj(ExceptionResponse(error=str(e)))
if mp_rank_0():
send_obj("DONE")
send_obj(EndSentinel())
def maybe_get_work(sock: zmq.Socket):
@@ -116,7 +209,7 @@ def maybe_get_work(sock: zmq.Socket):
client_id = None
try:
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
message = pickle.loads(obj)
message = obj.decode("utf-8")
except zmq.ZMQError as e:
if e.errno != zmq.EAGAIN:
raise e
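
Replacing `pickle.loads(obj)` with `obj.decode("utf-8")` also means bytes read off the socket can no longer execute code while being deserialized, which is the classic hazard of unpickling untrusted input (whether that motivated the change is not stated in the commit). A small illustration:

import json
import pickle


class Boom:
    def __reduce__(self):  # pickle runs whatever this returns when loading
        return (print, ("side effect ran during pickle.loads",))


payload = pickle.dumps(Boom())
pickle.loads(payload)      # prints: any callable could be substituted for print

try:
    json.loads(payload.decode("utf-8", errors="ignore"))
except json.JSONDecodeError:
    print("json.loads rejects it instead of executing anything")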
@@ -124,6 +217,22 @@ def maybe_get_work(sock: zmq.Socket):
return client_id, message
def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
if maybe_json is None:
return None
try:
return parse_message(maybe_json)
except json.JSONDecodeError:
return None
except ValueError as e:
return None
def parse_message(json_str: str) -> ProcessingMessage:
data = json.loads(json_str)
return ProcessingMessageWrapper(**data).payload
def worker_process_entrypoint(
reply_socket_url: str,
init_model_cb: Callable,
@@ -142,7 +251,8 @@ def worker_process_entrypoint(
if isinstance(task, str) and task == _END_SENTINEL:
break
result = model(task)
assert isinstance(task, TaskRequest)
result = model(task.task)
except StopIteration:
break
@@ -205,8 +315,8 @@ def start_model_parallel_process(
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
request_socket.send_pyobj("READY?")
response = request_socket.recv_pyobj()
request_socket.send(encode_msg(ReadyRequest()))
response = request_socket.recv()
print(f"Finished model load {response}")
return request_socket, process
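
The READY? handshake now follows the same pattern as everything else: the launcher sends an encoded ReadyRequest and blocks on recv() until rank 0 answers with a ReadyResponse via send_multipart. A toy version over a local ROUTER/DEALER pair (made-up address and a trivial stand-in model, not the actual socket setup):

import zmq
from pydantic import BaseModel


class Ready(BaseModel):  # trivial stand-in for ReadyRequest / ReadyResponse
    ok: bool = True


ctx = zmq.Context()
server = ctx.socket(zmq.ROUTER)
server.bind("ipc:///tmp/ready_demo")   # made-up address
client = ctx.socket(zmq.DEALER)
client.connect("ipc:///tmp/ready_demo")

client.send(Ready().model_dump_json().encode("utf-8"))  # JSON bytes, no pickled "READY?"
identity, raw = server.recv_multipart()                 # ROUTER prepends the sender identity
server.send_multipart([identity, Ready().model_dump_json().encode("utf-8")])
print(Ready.model_validate_json(client.recv()))         # ok=True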
@@ -235,31 +345,36 @@ class ModelParallelProcessGroup:
def stop(self):
assert self.started, "process group not started"
if self.process.is_alive():
self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
self.process.join()
self.started = False
def run_inference(self, request) -> Generator:
def run_inference(self, inference_args: InferenceArgs) -> Generator:
assert not self.running, "inference already running"
self.running = True
self.request_socket.send_pyobj(request)
self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
try:
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
obj_json = self.request_socket.recv()
obj = parse_message(obj_json)
if isinstance(obj, EndSentinel):
break
if isinstance(obj, Exception):
print(f"[debug] got exception {obj}")
raise obj
if isinstance(obj, ExceptionResponse):
print(f"[debug] got exception {obj.error}")
raise Exception(obj.error)
if isinstance(obj, TaskResponse):
yield obj.result
yield obj
except GeneratorExit as e:
self.request_socket.send_pyobj(_CANCEL_SENTINEL)
self.request_socket.send(encode_msg(CancelSentinel()))
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
obj_json = self.request_socket.recv()
obj = parse_message(obj_json)
if isinstance(obj, EndSentinel):
break
finally:
self.running = False
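
End to end, a caller of the new API sends one TaskRequest, iterates TokenResults parsed out of TaskResponse messages, and relies on GeneratorExit to push a CancelSentinel if it stops early. A rough consumer sketch (the construction of the process group and of InferenceArgs is assumed, not shown here):

def collect_tokens(group, inference_args, max_tokens=16):
    # group: an already-started ModelParallelProcessGroup
    # inference_args: an InferenceArgs instance built by the caller
    pieces = []
    gen = group.run_inference(inference_args)
    try:
        for token_result in gen:   # each item is a TokenResult from a TaskResponse
            pieces.append(token_result.text)
            if len(pieces) >= max_tokens:
                break              # stopping early ...
    finally:
        gen.close()                # ... raises GeneratorExit inside run_inference, which
                                   # sends a CancelSentinel and drains to the EndSentinel
    return "".join(pieces)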