Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-29 03:14:19 +00:00)
JSON serialization for parallel processing queue (#232)
* send/recv pydantic json over socket
* fixup
* address feedback
* bidirectional wrapper
* second round of feedback
Parent: 0f66ae0f61
Commit: 7a8aa775e5
3 changed files with 158 additions and 52 deletions
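The change replaces pickled Python objects on the model-parallel queue sockets with pydantic models serialized as JSON. Before the hunks, a minimal sketch of the round trip the diff introduces: it reuses the names added below (ProcessingMessageWrapper, encode_msg, parse_message), assumes pydantic v2 (model_dump_json), and is trimmed to two message types for brevity.

# Minimal sketch of the new wire format (assumes pydantic v2).
import json
from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class ProcessingMessageName(str, Enum):
    ready_request = "ready_request"
    ready_response = "ready_response"


class ReadyRequest(BaseModel):
    type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request


class ReadyResponse(BaseModel):
    type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response


ProcessingMessage = Union[ReadyRequest, ReadyResponse]


class ProcessingMessageWrapper(BaseModel):
    # the "type" field on each payload class is the discriminator, so the
    # receiving side knows which concrete class to rebuild from the JSON
    payload: Annotated[ProcessingMessage, Field(discriminator="type")]


def encode_msg(msg: ProcessingMessage) -> bytes:
    return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")


def parse_message(json_str: str) -> ProcessingMessage:
    data = json.loads(json_str)
    return ProcessingMessageWrapper(**data).payload


# round trip: these bytes go over the zmq socket instead of a pickled object
wire = encode_msg(ReadyRequest())
assert isinstance(parse_message(wire.decode("utf-8")), ReadyRequest)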
@@ -35,12 +35,14 @@ from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
 from llama_models.sku_list import resolve_model
-from termcolor import cprint
 
 from llama_stack.apis.inference import QuantizationType
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
 
+from pydantic import BaseModel
+from termcolor import cprint
+
 from .config import MetaReferenceImplConfig
 
@@ -58,8 +60,7 @@ def model_checkpoint_dir(model) -> str:
     return str(checkpoint_dir)
 
 
-@dataclass
-class TokenResult:
+class TokenResult(BaseModel):
     token: int
     text: str
     logprobs: Optional[List[float]] = None
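Context for the TokenResult change above: the worker now returns each token wrapped in a TaskResponse that is serialized with model_dump_json(), so TokenResult is promoted from a plain dataclass to a pydantic BaseModel and nests cleanly in that envelope. A small stand-alone illustration (field names match the diff; pydantic v2 assumed):

from typing import List, Optional

from pydantic import BaseModel


class TokenResult(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None


# round-trips as JSON, which is what the new queue format needs
print(TokenResult(token=1, text="hello").model_dump_json())
# {"token":1,"text":"hello","logprobs":null}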
@@ -17,17 +17,7 @@ from llama_models.sku_list import resolve_model
 
 from .config import MetaReferenceImplConfig
 from .generation import Llama, model_checkpoint_dir
-from .parallel_utils import ModelParallelProcessGroup
+from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
 
 
-@dataclass
-class InferenceArgs:
-    messages: List[Message]
-    temperature: float
-    top_p: float
-    max_gen_len: int
-    logprobs: bool
-    tool_prompt_format: ToolPromptFormat
-
-
 class ModelRunner:
@@ -102,7 +92,7 @@ class LlamaModelParallelGenerator:
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
-            logprobs=logprobs,
+            logprobs=logprobs or False,
             tool_prompt_format=tool_prompt_format,
         )
 
@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 import multiprocessing
 import os
-import pickle
 import tempfile
 import time
 import uuid
-from typing import Callable, Generator
+from enum import Enum
+from typing import Any, Callable, Generator, List, Literal, Optional, Union
 
 import torch
 
@@ -23,17 +23,106 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_src_rank,
 )
 
+from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
+
+from pydantic import BaseModel, Field
+
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from typing_extensions import Annotated
+
+from .generation import TokenResult
 
 
-_END_SENTINEL = "__end_sentinel__"
-_CANCEL_SENTINEL = "__cancel_sentinel__"
+class InferenceArgs(BaseModel):
+    messages: List[Message]
+    temperature: float
+    top_p: float
+    max_gen_len: int
+    logprobs: bool
+    tool_prompt_format: ToolPromptFormat
+
+
+class ProcessingMessageName(str, Enum):
+    ready_request = "ready_request"
+    ready_response = "ready_response"
+    end_sentinel = "end_sentinel"
+    cancel_sentinel = "cancel_sentinel"
+    task_request = "task_request"
+    task_response = "task_response"
+    exception_response = "exception_response"
+
+
+class ReadyRequest(BaseModel):
+    type: Literal[ProcessingMessageName.ready_request] = (
+        ProcessingMessageName.ready_request
+    )
+
+
+class ReadyResponse(BaseModel):
+    type: Literal[ProcessingMessageName.ready_response] = (
+        ProcessingMessageName.ready_response
+    )
+
+
+class EndSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.end_sentinel] = (
+        ProcessingMessageName.end_sentinel
+    )
+
+
+class CancelSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.cancel_sentinel] = (
+        ProcessingMessageName.cancel_sentinel
+    )
+
+
+class TaskRequest(BaseModel):
+    type: Literal[ProcessingMessageName.task_request] = (
+        ProcessingMessageName.task_request
+    )
+    task: InferenceArgs
+
+
+class TaskResponse(BaseModel):
+    type: Literal[ProcessingMessageName.task_response] = (
+        ProcessingMessageName.task_response
+    )
+    result: TokenResult
+
+
+class ExceptionResponse(BaseModel):
+    type: Literal[ProcessingMessageName.exception_response] = (
+        ProcessingMessageName.exception_response
+    )
+    error: str
+
+
+ProcessingMessage = Union[
+    ReadyRequest,
+    ReadyResponse,
+    EndSentinel,
+    CancelSentinel,
+    TaskRequest,
+    TaskResponse,
+    ExceptionResponse,
+]
+
+
+class ProcessingMessageWrapper(BaseModel):
+    payload: Annotated[
+        ProcessingMessage,
+        Field(discriminator="type"),
+    ]
 
 
 def mp_rank_0() -> bool:
     return get_model_parallel_rank() == 0
 
 
+def encode_msg(msg: ProcessingMessage) -> bytes:
+    return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")
+
+
 def retrieve_requests(reply_socket_url: str):
     if mp_rank_0():
         context = zmq.Context()
@@ -46,20 +135,23 @@ def retrieve_requests(reply_socket_url: str):
                 time.sleep(0.01)
                 continue
 
-            reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
+            ready_response = ReadyResponse()
+            reply_socket.send_multipart([client_id, encode_msg(ready_response)])
             break
 
-    def send_obj(obj):
-        reply_socket.send_multipart([client_id, pickle.dumps(obj)])
+    def send_obj(obj: ProcessingMessage):
+        reply_socket.send_multipart([client_id, encode_msg(obj)])
 
     while True:
         tasks = [None]
         if mp_rank_0():
-            client_id, task = maybe_get_work(reply_socket)
+            client_id, maybe_task_json = maybe_get_work(reply_socket)
+            if maybe_task_json is not None:
+                task = maybe_parse_message(maybe_task_json)
             # there is still an unknown unclean GeneratorExit happening resulting in a
             # cancel sentinel getting queued _after_ we have finished sending everything :/
             # kind of a hack this is :/
-            if task != _CANCEL_SENTINEL:
+            if task is not None and not isinstance(task, CancelSentinel):
                 tasks = [task]
 
         torch.distributed.broadcast_object_list(
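The encoded bytes travel over ZMQ exactly as before; only the payload format changes from pickle to JSON. A tiny, self-contained sketch of the multipart exchange follows. The ROUTER/DEALER socket types and the inproc URL are chosen here purely for illustration; the actual socket setup lives elsewhere in the file and is not part of this diff.

import zmq

ctx = zmq.Context()
url = "inproc://parallel-utils-demo"  # illustrative only

worker = ctx.socket(zmq.ROUTER)   # plays the role of reply_socket
worker.bind(url)
client = ctx.socket(zmq.DEALER)   # plays the role of request_socket
client.connect(url)

# client -> worker: raw JSON bytes, as produced by encode_msg(...)
client.send(b'{"payload": {"type": "ready_request"}}')

# worker side: ROUTER prepends the client identity frame
client_id, raw = worker.recv_multipart()
print(raw.decode("utf-8"))

# worker -> client: reply goes back via send_multipart([client_id, bytes])
worker.send_multipart([client_id, b'{"payload": {"type": "ready_response"}}'])
print(client.recv().decode("utf-8"))

ctx.destroy()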
@@ -80,35 +172,36 @@ def retrieve_requests(reply_socket_url: str):
             for obj in out:
                 updates = [None]
                 if mp_rank_0():
-                    _, update = maybe_get_work(reply_socket)
-                    if update == _CANCEL_SENTINEL:
+                    _, update_json = maybe_get_work(reply_socket)
+                    update = maybe_parse_message(update_json)
+                    if isinstance(update, CancelSentinel):
                         updates = [update]
                     else:
                         # only send the update if it's not cancelled otherwise the object sits in the socket
                         # and gets pulled in the next request lol
-                        send_obj(obj)
+                        send_obj(TaskResponse(result=obj))
 
                 torch.distributed.broadcast_object_list(
                     updates,
                     src=get_model_parallel_src_rank(),
                     group=get_model_parallel_group(),
                 )
-                if updates[0] == _CANCEL_SENTINEL:
+                if isinstance(updates[0], CancelSentinel):
                     print("quitting generation loop because request was cancelled")
                     break
 
             if mp_rank_0():
-                send_obj(_END_SENTINEL)
+                send_obj(EndSentinel())
         except Exception as e:
             print(f"[debug] got exception {e}")
             import traceback
 
             traceback.print_exc()
             if mp_rank_0():
-                send_obj(e)
+                send_obj(ExceptionResponse(error=str(e)))
 
     if mp_rank_0():
-        send_obj("DONE")
+        send_obj(EndSentinel())
 
 
 def maybe_get_work(sock: zmq.Socket):
@@ -116,7 +209,7 @@ def maybe_get_work(sock: zmq.Socket):
     client_id = None
     try:
         client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
-        message = pickle.loads(obj)
+        message = obj.decode("utf-8")
     except zmq.ZMQError as e:
         if e.errno != zmq.EAGAIN:
             raise e
@@ -124,6 +217,22 @@ def maybe_get_work(sock: zmq.Socket):
     return client_id, message
 
 
+def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
+    if maybe_json is None:
+        return None
+    try:
+        return parse_message(maybe_json)
+    except json.JSONDecodeError:
+        return None
+    except ValueError as e:
+        return None
+
+
+def parse_message(json_str: str) -> ProcessingMessage:
+    data = json.loads(json_str)
+    return ProcessingMessageWrapper(**data).payload
+
+
 def worker_process_entrypoint(
     reply_socket_url: str,
     init_model_cb: Callable,
@@ -142,7 +251,8 @@ def worker_process_entrypoint(
             if isinstance(task, str) and task == _END_SENTINEL:
                 break
 
-            result = model(task)
+            assert isinstance(task, TaskRequest)
+            result = model(task.task)
         except StopIteration:
             break
 
@@ -205,8 +315,8 @@ def start_model_parallel_process(
 
     # wait until the model is loaded; rank 0 will send a message to indicate it's ready
 
-    request_socket.send_pyobj("READY?")
-    response = request_socket.recv_pyobj()
+    request_socket.send(encode_msg(ReadyRequest()))
+    response = request_socket.recv()
     print(f"Finished model load {response}")
 
     return request_socket, process
@@ -235,31 +345,36 @@ class ModelParallelProcessGroup:
     def stop(self):
         assert self.started, "process group not started"
         if self.process.is_alive():
-            self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
+            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
             self.process.join()
             self.started = False
 
-    def run_inference(self, request) -> Generator:
+    def run_inference(self, inference_args: InferenceArgs) -> Generator:
         assert not self.running, "inference already running"
 
         self.running = True
-        self.request_socket.send_pyobj(request)
+        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
         try:
             while True:
-                obj = self.request_socket.recv_pyobj()
-                if obj == _END_SENTINEL:
+                obj_json = self.request_socket.recv()
+                obj = parse_message(obj_json)
+
+                if isinstance(obj, EndSentinel):
                     break
 
-                if isinstance(obj, Exception):
-                    print(f"[debug] got exception {obj}")
-                    raise obj
+                if isinstance(obj, ExceptionResponse):
+                    print(f"[debug] got exception {obj.error}")
+                    raise Exception(obj.error)
 
-                yield obj
+                if isinstance(obj, TaskResponse):
+                    yield obj.result
+
         except GeneratorExit as e:
-            self.request_socket.send_pyobj(_CANCEL_SENTINEL)
+            self.request_socket.send(encode_msg(CancelSentinel()))
             while True:
-                obj = self.request_socket.recv_pyobj()
-                if obj == _END_SENTINEL:
+                obj_json = self.request_socket.recv()
+                obj = parse_message(obj_json)
+                if isinstance(obj, EndSentinel):
                     break
         finally:
             self.running = False
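One behavioral detail of run_inference worth calling out: cancellation is driven by the consumer closing the generator, which raises GeneratorExit at the suspended yield and lands in the branch above that sends a CancelSentinel and then drains the socket until an EndSentinel arrives. A plain-Python stand-in for that control flow (sockets replaced by prints; names and values are placeholders, not from the diff):

from typing import Generator


def run_inference_stub() -> Generator[int, None, None]:
    try:
        for token in range(5):
            yield token
    except GeneratorExit:
        # in the real code: send CancelSentinel, then recv until EndSentinel
        print("cancelled: would send CancelSentinel and drain the socket")
        raise
    finally:
        # mirrors `self.running = False`
        print("running = False")


gen = run_inference_stub()
print(next(gen))  # 0
print(next(gen))  # 1
gen.close()       # raises GeneratorExit inside the generator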