llama-stack/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
Ashwin Bharambe 530d4bdfe1
refactor: move all llama code to models/llama out of meta reference (#1887)
# What does this PR do?

Move around bits. This makes the copies from llama-models _much_ easier
to maintain and ensures we don't entangle meta-reference specific
tidbits into llama-models code even by accident.

Also, kills the meta-reference-quantized-gpu distro and rolls
quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
   --env INFERENCE_CHECKPOINT_DIR=<DIR> \
   --env MODEL_PARALLEL_SIZE=4 \
   --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point integration tests to
it using:

```
pytest -s -v  tests/integration/inference/test_text_inference.py \
   --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
2025-04-07 15:03:58 -07:00

363 lines
11 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import copy
import json
import logging
import multiprocessing
import os
import tempfile
import time
import uuid
from enum import Enum
from typing import Callable, Generator, Literal, Optional, Union
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_group,
get_model_parallel_rank,
get_model_parallel_src_rank,
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from typing_extensions import Annotated
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
log = logging.getLogger(__name__)
class ProcessingMessageName(str, Enum):
ready_request = "ready_request"
ready_response = "ready_response"
end_sentinel = "end_sentinel"
cancel_sentinel = "cancel_sentinel"
task_request = "task_request"
task_response = "task_response"
exception_response = "exception_response"
class ReadyRequest(BaseModel):
type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
class ReadyResponse(BaseModel):
type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
class EndSentinel(BaseModel):
type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
class CancelSentinel(BaseModel):
type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
result: GenerationResult
class ExceptionResponse(BaseModel):
type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
error: str
ProcessingMessage = Union[
ReadyRequest,
ReadyResponse,
EndSentinel,
CancelSentinel,
TaskRequest,
TaskResponse,
ExceptionResponse,
]
class ProcessingMessageWrapper(BaseModel):
payload: Annotated[
ProcessingMessage,
Field(discriminator="type"),
]
def mp_rank_0() -> bool:
return get_model_parallel_rank() == 0
def encode_msg(msg: ProcessingMessage) -> bytes:
return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")
def retrieve_requests(reply_socket_url: str):
if mp_rank_0():
context = zmq.Context()
reply_socket = context.socket(zmq.ROUTER)
reply_socket.connect(reply_socket_url)
while True:
client_id, obj = maybe_get_work(reply_socket)
if obj is None:
time.sleep(0.01)
continue
ready_response = ReadyResponse()
reply_socket.send_multipart([client_id, encode_msg(ready_response)])
break
def send_obj(obj: ProcessingMessage):
reply_socket.send_multipart([client_id, encode_msg(obj)])
while True:
tasks = [None]
if mp_rank_0():
client_id, maybe_task_json = maybe_get_work(reply_socket)
if maybe_task_json is not None:
task = maybe_parse_message(maybe_task_json)
# there is still an unknown unclean GeneratorExit happening resulting in a
# cancel sentinel getting queued _after_ we have finished sending everything :/
# kind of a hack this is :/
if task is not None and not isinstance(task, CancelSentinel):
tasks = [task]
torch.distributed.broadcast_object_list(
tasks,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
task = tasks[0]
if task is None:
time.sleep(0.1)
else:
try:
out = yield task
if out is None:
break
for obj in out:
updates = [None]
if mp_rank_0():
_, update_json = maybe_get_work(reply_socket)
update = maybe_parse_message(update_json)
if isinstance(update, CancelSentinel):
updates = [update]
else:
# only send the update if it's not cancelled otherwise the object sits in the socket
# and gets pulled in the next request lol
send_obj(TaskResponse(result=obj))
torch.distributed.broadcast_object_list(
updates,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
if isinstance(updates[0], CancelSentinel):
log.info("quitting generation loop because request was cancelled")
break
if mp_rank_0():
send_obj(EndSentinel())
except Exception as e:
log.exception("exception in generation loop")
if mp_rank_0():
send_obj(ExceptionResponse(error=str(e)))
if mp_rank_0():
send_obj(EndSentinel())
def maybe_get_work(sock: zmq.Socket):
message = None
client_id = None
try:
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
message = obj.decode("utf-8")
except zmq.ZMQError as e:
if e.errno != zmq.EAGAIN:
raise e
return client_id, message
def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
if maybe_json is None:
return None
try:
return parse_message(maybe_json)
except json.JSONDecodeError:
return None
except ValueError:
return None
def parse_message(json_str: str) -> ProcessingMessage:
data = json.loads(json_str)
return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
def worker_process_entrypoint(
reply_socket_url: str,
init_model_cb: Callable,
) -> None:
model = init_model_cb()
torch.distributed.barrier()
time.sleep(1)
# run the requests co-routine which retrieves requests from the socket
# and sends responses (we provide) back to the caller
req_gen = retrieve_requests(reply_socket_url)
result = None
while True:
try:
task = req_gen.send(result)
if isinstance(task, str) and task == EndSentinel():
break
assert isinstance(task, TaskRequest)
result = model(task.task)
except StopIteration:
break
log.info("[debug] worker process done")
def launch_dist_group(
reply_socket_url: str,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
launch_config = LaunchConfig(
max_nodes=1,
min_nodes=1,
nproc_per_node=model_parallel_size,
start_method="fork",
rdzv_backend="c10d",
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
rdzv_configs={"store_type": "file", "timeout": 90},
max_restarts=0,
monitor_interval=1,
run_id=str(uuid.uuid4()),
)
elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
reply_socket_url,
init_model_cb,
)
def start_model_parallel_process(
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
context = zmq.Context()
request_socket = context.socket(zmq.DEALER)
# Binding the request socket to a random port
request_socket.bind("tcp://127.0.0.1:0")
main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
ctx = multiprocessing.get_context("spawn")
process = ctx.Process(
target=launch_dist_group,
args=(
main_process_url,
model_parallel_size,
init_model_cb,
),
kwargs=kwargs,
)
process.start()
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
request_socket.send(encode_msg(ReadyRequest()))
_response = request_socket.recv()
log.info("Loaded model...")
return request_socket, process
class ModelParallelProcessGroup:
def __init__(
self,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
self.model_parallel_size = model_parallel_size
self.init_model_cb = init_model_cb
self.started = False
self.running = False
def start(self):
assert not self.started, "process group already started"
self.request_socket, self.process = start_model_parallel_process(
self.model_parallel_size,
self.init_model_cb,
)
self.started = True
def stop(self):
assert self.started, "process group not started"
if self.process.is_alive():
self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
self.process.join()
self.started = False
def run_inference(
self,
req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
) -> Generator:
assert not self.running, "inference already running"
self.running = True
try:
self.request_socket.send(encode_msg(TaskRequest(task=req)))
while True:
obj_json = self.request_socket.recv()
obj = parse_message(obj_json)
if isinstance(obj, EndSentinel):
break
if isinstance(obj, ExceptionResponse):
log.error(f"[debug] got exception {obj.error}")
raise Exception(obj.error)
if isinstance(obj, TaskResponse):
yield obj.result
except GeneratorExit:
self.request_socket.send(encode_msg(CancelSentinel()))
while True:
obj_json = self.request_socket.send()
obj = parse_message(obj_json)
if isinstance(obj, EndSentinel):
break
finally:
self.running = False