Make each inference provider into its own subdirectory

This commit is contained in:
Ashwin Bharambe 2024-08-05 15:13:52 -07:00
parent f64668319c
commit 0de5a807c7
42 changed files with 123 additions and 103 deletions

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceImplConfig # noqa
from .inference import get_provider_impl # noqa

View file

@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Literal, Optional, Union
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated
from llama_toolchain.inference.api import QuantizationConfig
@json_schema_type
class CheckpointType(Enum):
pytorch = "pytorch"
huggingface = "huggingface"
@json_schema_type
class PytorchCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
CheckpointType.pytorch.value
)
checkpoint_dir: str
tokenizer_path: str
model_parallel_size: int
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
@json_schema_type
class HuggingFaceCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.huggingface.value] = (
CheckpointType.huggingface.value
)
repo_id: str # or model_name ?
model_parallel_size: int
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
@json_schema_type
class ModelCheckpointConfig(BaseModel):
checkpoint: Annotated[
Union[PytorchCheckpoint, HuggingFaceCheckpoint],
Field(discriminator="checkpoint_type"),
]
@json_schema_type
class MetaReferenceImplConfig(BaseModel):
model: str
checkpoint_config: ModelCheckpointConfig
quantization: Optional[QuantizationConfig] = None
torch_seed: Optional[int] = None
max_seq_len: int
max_batch_size: int = 1

View file

@ -0,0 +1,320 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from llama_models.llama3_1.api.args import ModelArgs
from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.model import Transformer
from llama_models.llama3_1.api.tokenizer import Tokenizer
from termcolor import cprint
from llama_toolchain.inference.api import QuantizationType
from .config import CheckpointType, MetaReferenceImplConfig
@dataclass
class TokenResult:
token: int
text: str
logprobs: Optional[List[float]] = None
class Llama:
@staticmethod
def build(config: MetaReferenceImplConfig):
"""
Build a Llama instance by initializing and loading a model checkpoint.
Note:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
checkpoint = config.checkpoint_config.checkpoint
if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
raise NotImplementedError("HuggingFace checkpoints not supported yet")
if (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
):
from .quantization.loader import is_fbgemm_available
if not is_fbgemm_available():
raise ImportError("fbgemm-gpu is required for FP8 quantization")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
model_parallel_size = checkpoint.model_parallel_size
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
if config.torch_seed is not None:
torch.manual_seed(config.torch_seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_dir = checkpoint.checkpoint_dir
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
# TODO(ashwin): this block is so we can load internal checkpoints without additional
# fuss. the final code should _not_ have this blurb
if "model" in params:
params = params["model"]
model_args: ModelArgs = ModelArgs(
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
fp8 = (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
)
if fp8:
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
if fp8:
# load on CPU first since if we are doing fp8, we probably don't
# have enough memory on GPU for bf16
model.load_state_dict(state_dict, strict=False)
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
if not fp8:
model.load_state_dict(state_dict, strict=False)
if config.quantization:
from .quantization.loader import convert_to_quantized_model
model = convert_to_quantized_model(model, config)
else:
model = model.to("cuda")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
@torch.inference_mode()
def generate(
self,
model_input: ModelInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
include_stop_token: bool = False,
) -> Generator:
params = self.model.params
# cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
prompt_tokens = [model_input.tokens]
bsz = 1
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
)
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
logits = self.model.forward(tokens, prev_pos)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
if logprobs
else None
),
)
prev_pos = cur_pos
if all(eos_reached):
break
def text_completion(
self,
prompt: str,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
max_gen_len = self.model.params.max_seq_len - 1
prompt_tokens = self.tokenizer.encode(x, bos=True, eos=False)
yield from self.generate(
model_input=ModelInput(tokens=prompt_tokens),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
)
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
) -> Generator:
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
max_gen_len = self.model.params.max_seq_len - 1
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(messages),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
include_stop_token=True,
)
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

View file

@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import AsyncIterator, Dict, Union
from llama_models.llama3_1.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
Inference,
ToolCallDelta,
ToolCallParseStatus,
)
from .config import MetaReferenceImplConfig
from .model_parallel import LlamaModelParallelGenerator
async def get_provider_impl(
config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
):
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference):
def __init__(self, config: MetaReferenceImplConfig) -> None:
self.config = config
model = resolve_model(config.model)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model
# verify that the checkpoint actually is for this model lol
async def initialize(self) -> None:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
async def shutdown(self) -> None:
self.generator.stop()
# hm, when stream=False, we should not be doing SSE :/ which is what the
# top-level server is going to do. make the typing more specific here
async def chat_completion(
self, request: ChatCompletionRequest
) -> AsyncIterator[
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
]:
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
async with SEMAPHORE:
if request.stream:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
tokens = []
logprobs = []
stop_reason = None
buffer = ""
ipython = False
for token_result in self.generator.chat_completion(
messages=request.messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
):
buffer += token_result.text
tokens.append(token_result.token)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if not request.stream:
if request.logprobs:
logprobs.append(token_result.logprob)
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
if stop_reason is None:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# TODO(ashwin): parse tool calls separately here and report errors?
# if someone breaks the iteration before coming here we are toast
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
if request.stream:
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)

View file

@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
from llama_models.llama3_1.api.chat_format import ChatFormat
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from .config import MetaReferenceImplConfig
from .generation import Llama
from .parallel_utils import ModelParallelProcessGroup
@dataclass
class InferenceArgs:
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
class ModelRunner:
def __init__(self, llama):
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: InferenceArgs):
return self.llama.chat_completion(
task.messages,
task.temperature,
task.top_p,
task.max_gen_len,
task.logprobs,
)
def init_model_cb(config: MetaReferenceImplConfig):
llama = Llama.build(config)
return ModelRunner(llama)
class LlamaModelParallelGenerator:
"""
This abstraction exists so
- we can run model parallel code without needing to run the CLIs via torchrun
- this also enables use model parallel code within a notebook context.
A Context Manager is used to ensure that the model parallel process is started and stopped
correctly. This does make the ergonomics a little awkward, because it isn't immediately
clear at the callsite why we need to use a context manager.
"""
def __init__(self, config: MetaReferenceImplConfig):
self.config = config
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
# while the tool-use loop is going
checkpoint = self.config.checkpoint_config.checkpoint
self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
def start(self):
self.__enter__()
def stop(self):
self.__exit__(None, None, None)
def __enter__(self):
checkpoint = self.config.checkpoint_config.checkpoint
self.group = ModelParallelProcessGroup(
checkpoint.model_parallel_size,
init_model_cb=partial(init_model_cb, self.config),
)
self.group.start()
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
) -> Generator:
req_obj = InferenceArgs(
messages=deepcopy(messages),
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
)
gen = self.group.run_inference(req_obj)
yield from gen

View file

@ -0,0 +1,265 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import multiprocessing
import os
import pickle
import tempfile
import time
import uuid
from typing import Callable, Generator
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_group,
get_model_parallel_rank,
get_model_parallel_src_rank,
)
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
_END_SENTINEL = "__end_sentinel__"
_CANCEL_SENTINEL = "__cancel_sentinel__"
def mp_rank_0() -> bool:
return get_model_parallel_rank() == 0
def retrieve_requests(reply_socket_url: str):
if mp_rank_0():
context = zmq.Context()
reply_socket = context.socket(zmq.ROUTER)
reply_socket.connect(reply_socket_url)
while True:
client_id, obj = maybe_get_work(reply_socket)
if obj is None:
time.sleep(0.01)
continue
reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
break
def send_obj(obj):
reply_socket.send_multipart([client_id, pickle.dumps(obj)])
while True:
tasks = [None]
if mp_rank_0():
client_id, task = maybe_get_work(reply_socket)
# there is still an unknown unclean GeneratorExit happening resulting in a
# cancel sentinel getting queued _after_ we have finished sending everything :/
# kind of a hack this is :/
if task != _CANCEL_SENTINEL:
tasks = [task]
torch.distributed.broadcast_object_list(
tasks,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
task = tasks[0]
if task is None:
time.sleep(0.1)
else:
try:
out = yield task
if out is None:
break
for obj in out:
updates = [None]
if mp_rank_0():
_, update = maybe_get_work(reply_socket)
if update == _CANCEL_SENTINEL:
updates = [update]
else:
# only send the update if it's not cancelled otherwise the object sits in the socket
# and gets pulled in the next request lol
send_obj(obj)
torch.distributed.broadcast_object_list(
updates,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
if updates[0] == _CANCEL_SENTINEL:
print("quitting generation loop because request was cancelled")
break
if mp_rank_0():
send_obj(_END_SENTINEL)
except Exception as e:
print(f"[debug] got exception {e}")
import traceback
traceback.print_exc()
if mp_rank_0():
send_obj(e)
if mp_rank_0():
send_obj("DONE")
def maybe_get_work(sock: zmq.Socket):
message = None
client_id = None
try:
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
message = pickle.loads(obj)
except zmq.ZMQError as e:
if e.errno != zmq.EAGAIN:
raise e
return client_id, message
def worker_process_entrypoint(
reply_socket_url: str,
init_model_cb: Callable,
) -> None:
model = init_model_cb()
torch.distributed.barrier()
time.sleep(1)
# run the requests co-routine which retrieves requests from the socket
# and sends responses (we provide) back to the caller
req_gen = retrieve_requests(reply_socket_url)
result = None
while True:
try:
task = req_gen.send(result)
if isinstance(task, str) and task == _END_SENTINEL:
break
result = model(task)
except StopIteration:
break
print("[debug] worker process done")
def launch_dist_group(
reply_socket_url: str,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
) -> None:
id = uuid.uuid4().hex
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
with tempfile.TemporaryDirectory() as tmpdir:
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
launch_config = LaunchConfig(
max_nodes=1,
min_nodes=1,
nproc_per_node=model_parallel_size,
start_method="fork",
rdzv_backend="c10d",
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
rdzv_configs={"store_type": "file", "timeout": 90},
max_restarts=0,
monitor_interval=1,
run_id=str(uuid.uuid4()),
)
elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
reply_socket_url,
init_model_cb,
)
def start_model_parallel_process(
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
context = zmq.Context()
request_socket = context.socket(zmq.DEALER)
# Binding the request socket to a random port
request_socket.bind("tcp://127.0.0.1:0")
main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
ctx = multiprocessing.get_context("fork")
process = ctx.Process(
target=launch_dist_group,
args=(
main_process_url,
model_parallel_size,
init_model_cb,
),
kwargs=kwargs,
)
process.start()
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
request_socket.send_pyobj("READY?")
response = request_socket.recv_pyobj()
print(f"Finished model load {response}")
return request_socket, process
class ModelParallelProcessGroup:
def __init__(
self,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
self.model_parallel_size = model_parallel_size
self.init_model_cb = init_model_cb
self.started = False
self.running = False
def start(self):
assert not self.started, "process group already started"
self.request_socket, self.process = start_model_parallel_process(
self.model_parallel_size,
self.init_model_cb,
)
self.started = True
def stop(self):
assert self.started, "process group not started"
if self.process.is_alive():
self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
self.process.join()
self.started = False
def run_inference(self, request) -> Generator:
assert not self.running, "inference already running"
self.running = True
self.request_socket.send_pyobj(request)
try:
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
break
if isinstance(obj, Exception):
print(f"[debug] got exception {obj}")
raise obj
yield obj
except GeneratorExit as e:
self.request_socket.send_pyobj(_CANCEL_SENTINEL)
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
break
finally:
self.running = False