# What does this PR do?

Move around bits. This makes the copies from llama-models _much_ easier to maintain and ensures we don't entangle meta-reference-specific tidbits into llama-models code even by accident.

It also kills the meta-reference-quantized-gpu distro and rolls the quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
  --env INFERENCE_CHECKPOINT_DIR=<DIR> \
  --env MODEL_PARALLEL_SIZE=4 \
  --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point the integration tests at it using:

```
pytest -s -v tests/integration/inference/test_text_inference.py \
  --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import copy
import json
import logging
import multiprocessing
import os
import tempfile
import time
import uuid
from enum import Enum
from typing import Callable, Generator, Literal, Optional, Union

import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_group,
    get_model_parallel_rank,
    get_model_parallel_src_rank,
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from typing_extensions import Annotated

from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
    ChatCompletionRequestWithRawContent,
    CompletionRequestWithRawContent,
)

log = logging.getLogger(__name__)


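# Message protocol exchanged over ZMQ between the main (client) process and the
# rank-0 worker: a Ready handshake on startup, TaskRequest to start a generation,
# a stream of TaskResponse results, CancelSentinel to abort an in-flight request,
# ExceptionResponse on errors, and EndSentinel to mark the end of a stream (or to
# shut the workers down). Every message is wrapped in ProcessingMessageWrapper and
# serialized as JSON.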
class ProcessingMessageName(str, Enum):
    ready_request = "ready_request"
    ready_response = "ready_response"
    end_sentinel = "end_sentinel"
    cancel_sentinel = "cancel_sentinel"
    task_request = "task_request"
    task_response = "task_response"
    exception_response = "exception_response"


class ReadyRequest(BaseModel):
    type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request


class ReadyResponse(BaseModel):
    type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response


class EndSentinel(BaseModel):
    type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel


class CancelSentinel(BaseModel):
    type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel


class TaskRequest(BaseModel):
    type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
    task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]


class TaskResponse(BaseModel):
    type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
    result: GenerationResult


class ExceptionResponse(BaseModel):
    type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
    error: str


ProcessingMessage = Union[
    ReadyRequest,
    ReadyResponse,
    EndSentinel,
    CancelSentinel,
    TaskRequest,
    TaskResponse,
    ExceptionResponse,
]


class ProcessingMessageWrapper(BaseModel):
    payload: Annotated[
        ProcessingMessage,
        Field(discriminator="type"),
    ]


def mp_rank_0() -> bool:
    return get_model_parallel_rank() == 0


def encode_msg(msg: ProcessingMessage) -> bytes:
    return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")


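# Generator run on every model-parallel rank. Rank 0 owns the ZMQ ROUTER socket:
# it completes the Ready handshake, then polls for task messages and broadcasts
# each task to all ranks via torch.distributed. The worker loop drives this
# coroutine with send(); each result iterator sent back in is streamed to the
# client as TaskResponse messages, and a CancelSentinel from the client stops the
# stream early.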
def retrieve_requests(reply_socket_url: str):
    if mp_rank_0():
        context = zmq.Context()
        reply_socket = context.socket(zmq.ROUTER)
        reply_socket.connect(reply_socket_url)

        while True:
            client_id, obj = maybe_get_work(reply_socket)
            if obj is None:
                time.sleep(0.01)
                continue

            ready_response = ReadyResponse()
            reply_socket.send_multipart([client_id, encode_msg(ready_response)])
            break

    def send_obj(obj: ProcessingMessage):
        reply_socket.send_multipart([client_id, encode_msg(obj)])

    while True:
        tasks = [None]
        if mp_rank_0():
            client_id, maybe_task_json = maybe_get_work(reply_socket)
            if maybe_task_json is not None:
                task = maybe_parse_message(maybe_task_json)
                # there is still an unknown unclean GeneratorExit happening resulting in a
                # cancel sentinel getting queued _after_ we have finished sending everything :/
                # kind of a hack this is :/
                if task is not None and not isinstance(task, CancelSentinel):
                    tasks = [task]

        torch.distributed.broadcast_object_list(
            tasks,
            src=get_model_parallel_src_rank(),
            group=get_model_parallel_group(),
        )

        task = tasks[0]
        if task is None:
            time.sleep(0.1)
        else:
            try:
                out = yield task
                if out is None:
                    break

                for obj in out:
                    updates = [None]
                    if mp_rank_0():
                        _, update_json = maybe_get_work(reply_socket)
                        update = maybe_parse_message(update_json)
                        if isinstance(update, CancelSentinel):
                            updates = [update]
                        else:
                            # only send the update if it's not cancelled otherwise the object sits in the socket
                            # and gets pulled in the next request lol
                            send_obj(TaskResponse(result=obj))

                    torch.distributed.broadcast_object_list(
                        updates,
                        src=get_model_parallel_src_rank(),
                        group=get_model_parallel_group(),
                    )
                    if isinstance(updates[0], CancelSentinel):
                        log.info("quitting generation loop because request was cancelled")
                        break

                if mp_rank_0():
                    send_obj(EndSentinel())
            except Exception as e:
                log.exception("exception in generation loop")

                if mp_rank_0():
                    send_obj(ExceptionResponse(error=str(e)))

    if mp_rank_0():
        send_obj(EndSentinel())


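# Non-blocking receive on the ROUTER socket; returns (client_id, message) or
# (None, None) when nothing is queued.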
def maybe_get_work(sock: zmq.Socket):
    message = None
    client_id = None
    try:
        client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
        message = obj.decode("utf-8")
    except zmq.ZMQError as e:
        if e.errno != zmq.EAGAIN:
            raise e

    return client_id, message


def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
    if maybe_json is None:
        return None
    try:
        return parse_message(maybe_json)
    except json.JSONDecodeError:
        return None
    except ValueError:
        return None


def parse_message(json_str: str) -> ProcessingMessage:
    data = json.loads(json_str)
    return copy.deepcopy(ProcessingMessageWrapper(**data).payload)


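# Entrypoint for each torchelastic-launched worker rank: initializes the model via
# init_model_cb, then drives the retrieve_requests coroutine, running the model for
# every TaskRequest it yields.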
def worker_process_entrypoint(
    reply_socket_url: str,
    init_model_cb: Callable,
) -> None:
    model = init_model_cb()
    torch.distributed.barrier()
    time.sleep(1)

    # run the requests co-routine which retrieves requests from the socket
    # and sends responses (we provide) back to the caller
    req_gen = retrieve_requests(reply_socket_url)
    result = None
    while True:
        try:
            task = req_gen.send(result)
            if isinstance(task, EndSentinel):
                break

            assert isinstance(task, TaskRequest)
            result = model(task.task)
        except StopIteration:
            break

    log.info("[debug] worker process done")


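# Runs inside the spawned child process: uses torchelastic (elastic_launch) to fork
# model_parallel_size worker ranks, with a file-based c10d rendezvous stored under a
# temporary directory.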
def launch_dist_group(
    reply_socket_url: str,
    model_parallel_size: int,
    init_model_cb: Callable,
    **kwargs,
) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        # TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
        launch_config = LaunchConfig(
            max_nodes=1,
            min_nodes=1,
            nproc_per_node=model_parallel_size,
            start_method="fork",
            rdzv_backend="c10d",
            rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
            rdzv_configs={"store_type": "file", "timeout": 90},
            max_restarts=0,
            monitor_interval=1,
            run_id=str(uuid.uuid4()),
        )
        elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
            reply_socket_url,
            init_model_cb,
        )


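# Main-process side: binds a DEALER socket to a random local TCP port, spawns the
# launcher in a separate "spawn" process, then blocks on the Ready handshake until
# rank 0 signals that the model has finished loading.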
def start_model_parallel_process(
    model_parallel_size: int,
    init_model_cb: Callable,
    **kwargs,
):
    context = zmq.Context()
    request_socket = context.socket(zmq.DEALER)

    # Binding the request socket to a random port
    request_socket.bind("tcp://127.0.0.1:0")

    main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)

    ctx = multiprocessing.get_context("spawn")
    process = ctx.Process(
        target=launch_dist_group,
        args=(
            main_process_url,
            model_parallel_size,
            init_model_cb,
        ),
        kwargs=kwargs,
    )
    process.start()

    # wait until the model is loaded; rank 0 will send a message to indicate it's ready
    request_socket.send(encode_msg(ReadyRequest()))
    _response = request_socket.recv()
    log.info("Loaded model...")

    return request_socket, process


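# Main-process handle over the worker group. start()/stop() manage the process
# lifecycle; run_inference() streams GenerationResults for a single request and
# sends a CancelSentinel (then drains until EndSentinel) if the consumer exits
# early. Typical usage (illustrative names):
#
#   group = ModelParallelProcessGroup(model_parallel_size=4, init_model_cb=load_model)
#   group.start()
#   for result in group.run_inference(request):
#       ...
#   group.stop()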
class ModelParallelProcessGroup:
    def __init__(
        self,
        model_parallel_size: int,
        init_model_cb: Callable,
        **kwargs,
    ):
        self.model_parallel_size = model_parallel_size
        self.init_model_cb = init_model_cb
        self.started = False
        self.running = False

    def start(self):
        assert not self.started, "process group already started"
        self.request_socket, self.process = start_model_parallel_process(
            self.model_parallel_size,
            self.init_model_cb,
        )
        self.started = True

    def stop(self):
        assert self.started, "process group not started"
        if self.process.is_alive():
            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
            self.process.join()
        self.started = False

    def run_inference(
        self,
        req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
    ) -> Generator:
        assert not self.running, "inference already running"

        self.running = True
        try:
            self.request_socket.send(encode_msg(TaskRequest(task=req)))
            while True:
                obj_json = self.request_socket.recv()
                obj = parse_message(obj_json)

                if isinstance(obj, EndSentinel):
                    break

                if isinstance(obj, ExceptionResponse):
                    log.error(f"[debug] got exception {obj.error}")
                    raise Exception(obj.error)

                if isinstance(obj, TaskResponse):
                    yield obj.result

        except GeneratorExit:
            self.request_socket.send(encode_msg(CancelSentinel()))
            while True:
                obj_json = self.request_socket.recv()
                obj = parse_message(obj_json)
                if isinstance(obj, EndSentinel):
                    break
        finally:
            self.running = False