## What does this PR do?
This PR refactors the meta reference inference logic to support:
- loading the model during model registration instead of at server startup
- running inference on a fine-tuned model checkpoint on top of a native llama model
## Why we need these changes
To address the existing pain points:
- users cannot lazily load a model or hot-swap the inference checkpoint after the server has started
  - this blocks running inference and eval on the same server against a fine-tuned checkpoint produced by post training
- users cannot run inference on a fine-tuned checkpoint built on top of native llama models
## Expected user experience changes
- The inference model is no longer loaded when the server spins up; instead, it is loaded when the model is registered. If the user adds the model as a `models` resource in run.yaml, it is registered and loaded automatically at server startup. An optional `skip_initialize` flag in the model metadata skips model loading during registration.
- An optional `llama_model` flag in the model metadata identifies the base model of the `Model` class, used for validation and for initializing the model architecture; the model identifier no longer needs to be a native llama model (see the sketch after this list).
- The default inference model name changes from `meta-llama/Llama-3.2-3B-Instruct` to `Llama3.2-3B-Instruct`:
  - it aligns with the checkpoint folder name produced by `llama model download`
  - it aligns with the descriptor name defined in the llama-models SKU list:
    bf5b0c4fe7/models/datatypes.py (L95)
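
To illustrate the new metadata flags, here is a minimal sketch of registering a model from the Python client. The `LlamaStackClient` import and the exact `models.register` parameter names (`model_id`, `metadata`) are assumptions about the llama-stack-client SDK, not something this PR defines:

```python
# Sketch only: register a model after the server is up, exercising the new
# `llama_model` / `skip_initialize` metadata flags introduced in this PR.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # example endpoint

# Register and load immediately (the default behavior after this PR).
client.models.register(
    model_id="Llama3.2-3B-Instruct",
    metadata={"llama_model": "Llama3.2-3B-Instruct"},
)

# Alternatively, register the model without loading it during registration.
client.models.register(
    model_id="Llama3.2-3B-Instruct",
    metadata={"llama_model": "Llama3.2-3B-Instruct", "skip_initialize": True},
)
```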
## Test
Run `python llama_stack/scripts/distro_codegen.py`

**Run unit tests**
- `torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py`
- `torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_model_registration.py`
**Test the post-training experience**

On the server side, run: `llama stack run llama_stack/templates/experimental-post-training/run.yaml`

The server spins up without loading the model:
<img width="812" alt="Screenshot 2024-12-17 at 1 24 50 PM"
src="https://github.com/user-attachments/assets/ce1f606b-3b6f-452f-b48e-b3761ffd90f3"
/>

On the client side, run: `llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 models register Llama3.2-3B-Instruct`

The model is registered successfully and loaded:
<img width="1111" alt="Screenshot 2024-12-17 at 1 26 30 PM"
src="https://github.com/user-attachments/assets/56e02131-cf7d-4de5-8f63-fbdcb8c55c26"
/>
<img width="1541" alt="Screenshot 2024-12-17 at 1 26 09 PM"
src="https://github.com/user-attachments/assets/a83255a1-20f5-40a2-af51-55641410a115"
/>

If `skip_initialize` is added to the metadata, the model is registered but not loaded.

On the client side, run: `llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 inference chat-completion --message "hello, what model are you?"`

Inference on the model succeeds:
<img width="1121" alt="Screenshot 2024-12-17 at 1 27 33 PM"
src="https://github.com/user-attachments/assets/8e708545-3fe7-4a73-8754-1470fa5f1e75"
/>
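
The same chat-completion call can also be issued from the Python client. This is a sketch; the `chat_completion` parameter and response field names (`model_id`, `messages`, `completion_message`) are assumed from the llama-stack-client SDK rather than taken from this PR:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://devgpu018.nha2.facebook.com:5000")

# Mirrors the CLI `inference chat-completion` command above (assumed SDK surface).
response = client.inference.chat_completion(
    model_id="Llama3.2-3B-Instruct",
    messages=[{"role": "user", "content": "hello, what model are you?"}],
)
print(response.completion_message.content)  # response field names are assumptions
```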
**Test the inference experience**

Run: `llama stack run llama_stack/templates/meta-reference-gpu/run.yaml`

The model is loaded at startup since it is listed as a resource in run.yaml:
<img width="1537" alt="Screenshot 2024-12-17 at 1 30 19 PM"
src="https://github.com/user-attachments/assets/5c8af817-66eb-43f8-bf4c-f5e24b0a12c6"
/>

On the client side, run: `llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 inference chat-completion --message "hello, what model are you?"`

Inference succeeds:
<img width="1123" alt="Screenshot 2024-12-17 at 1 31 08 PM"
src="https://github.com/user-attachments/assets/471809aa-c65e-46dc-a37e-7094fb857f97"
/>
## Inference on a fine-tuned model
**Register a fine-tuned model produced by the post-training API (torchtune)**
- the model is registered and loaded successfully
- the model shows up in the model list
<img width="974" alt="Screenshot 2024-12-18 at 3 56 33 PM"
src="https://github.com/user-attachments/assets/2994b4f5-4fa9-40c6-acc6-4b971479f3e2"
/>
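
For reference, registering a fine-tuned checkpoint from the Python client might look like the sketch below. The model id is a placeholder, and how the provider locates the torchtune checkpoint on disk is not shown; only the `llama_model` metadata key comes from this PR:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://devgpu018.nha2.facebook.com:5000")

# "my-torchtune-finetune" is a hypothetical identifier for a post-training checkpoint;
# `llama_model` tells the provider which native llama architecture to initialize.
client.models.register(
    model_id="my-torchtune-finetune",
    metadata={"llama_model": "Llama3.2-3B-Instruct"},
)

# Confirm the fine-tuned model shows up in the model list.
for model in client.models.list():
    print(model.identifier)
```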
**Run inference**
<img width="977" alt="Screenshot 2024-12-18 at 3 57 59 PM"
src="https://github.com/user-attachments/assets/d117abbc-b2a0-41d8-a028-1a13128787b2"
/>
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import logging
import multiprocessing
import os
import tempfile
import time
import uuid
from enum import Enum
from typing import Callable, Generator, Literal, Optional, Union

import torch
import zmq

from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_group,
    get_model_parallel_rank,
    get_model_parallel_src_rank,
)

from pydantic import BaseModel, Field

from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated

from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest

from .generation import TokenResult

log = logging.getLogger(__name__)


class ProcessingMessageName(str, Enum):
    ready_request = "ready_request"
    ready_response = "ready_response"
    end_sentinel = "end_sentinel"
    cancel_sentinel = "cancel_sentinel"
    task_request = "task_request"
    task_response = "task_response"
    exception_response = "exception_response"


class ReadyRequest(BaseModel):
    type: Literal[ProcessingMessageName.ready_request] = (
        ProcessingMessageName.ready_request
    )


class ReadyResponse(BaseModel):
    type: Literal[ProcessingMessageName.ready_response] = (
        ProcessingMessageName.ready_response
    )


class EndSentinel(BaseModel):
    type: Literal[ProcessingMessageName.end_sentinel] = (
        ProcessingMessageName.end_sentinel
    )


class CancelSentinel(BaseModel):
    type: Literal[ProcessingMessageName.cancel_sentinel] = (
        ProcessingMessageName.cancel_sentinel
    )


class TaskRequest(BaseModel):
    type: Literal[ProcessingMessageName.task_request] = (
        ProcessingMessageName.task_request
    )
    task: Union[CompletionRequest, ChatCompletionRequest]


class TaskResponse(BaseModel):
    type: Literal[ProcessingMessageName.task_response] = (
        ProcessingMessageName.task_response
    )
    result: TokenResult


class ExceptionResponse(BaseModel):
    type: Literal[ProcessingMessageName.exception_response] = (
        ProcessingMessageName.exception_response
    )
    error: str


ProcessingMessage = Union[
    ReadyRequest,
    ReadyResponse,
    EndSentinel,
    CancelSentinel,
    TaskRequest,
    TaskResponse,
    ExceptionResponse,
]


class ProcessingMessageWrapper(BaseModel):
    payload: Annotated[
        ProcessingMessage,
        Field(discriminator="type"),
    ]


def mp_rank_0() -> bool:
    return get_model_parallel_rank() == 0


def encode_msg(msg: ProcessingMessage) -> bytes:
    return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")


def retrieve_requests(reply_socket_url: str):
    if mp_rank_0():
        context = zmq.Context()
        reply_socket = context.socket(zmq.ROUTER)
        reply_socket.connect(reply_socket_url)

        while True:
            client_id, obj = maybe_get_work(reply_socket)
            if obj is None:
                time.sleep(0.01)
                continue

            ready_response = ReadyResponse()
            reply_socket.send_multipart([client_id, encode_msg(ready_response)])
            break

    def send_obj(obj: ProcessingMessage):
        reply_socket.send_multipart([client_id, encode_msg(obj)])

    while True:
        tasks = [None]
        if mp_rank_0():
            client_id, maybe_task_json = maybe_get_work(reply_socket)
            if maybe_task_json is not None:
                task = maybe_parse_message(maybe_task_json)
                # there is still an unknown unclean GeneratorExit happening resulting in a
                # cancel sentinel getting queued _after_ we have finished sending everything :/
                # kind of a hack this is :/
                if task is not None and not isinstance(task, CancelSentinel):
                    tasks = [task]

        torch.distributed.broadcast_object_list(
            tasks,
            src=get_model_parallel_src_rank(),
            group=get_model_parallel_group(),
        )

        task = tasks[0]
        if task is None:
            time.sleep(0.1)
        else:
            try:
                out = yield task
                if out is None:
                    break

                for obj in out:
                    updates = [None]
                    if mp_rank_0():
                        _, update_json = maybe_get_work(reply_socket)
                        update = maybe_parse_message(update_json)
                        if isinstance(update, CancelSentinel):
                            updates = [update]
                        else:
                            # only send the update if it's not cancelled otherwise the object sits in the socket
                            # and gets pulled in the next request lol
                            send_obj(TaskResponse(result=obj))

                    torch.distributed.broadcast_object_list(
                        updates,
                        src=get_model_parallel_src_rank(),
                        group=get_model_parallel_group(),
                    )
                    if isinstance(updates[0], CancelSentinel):
                        log.info(
                            "quitting generation loop because request was cancelled"
                        )
                        break

                if mp_rank_0():
                    send_obj(EndSentinel())
            except Exception as e:
                log.exception("exception in generation loop")

                if mp_rank_0():
                    send_obj(ExceptionResponse(error=str(e)))

    if mp_rank_0():
        send_obj(EndSentinel())


def maybe_get_work(sock: zmq.Socket):
    message = None
    client_id = None
    try:
        client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
        message = obj.decode("utf-8")
    except zmq.ZMQError as e:
        if e.errno != zmq.EAGAIN:
            raise e

    return client_id, message


def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
    if maybe_json is None:
        return None
    try:
        return parse_message(maybe_json)
    except json.JSONDecodeError:
        return None
    except ValueError:
        return None


def parse_message(json_str: str) -> ProcessingMessage:
    data = json.loads(json_str)
    return ProcessingMessageWrapper(**data).payload


def worker_process_entrypoint(
    reply_socket_url: str,
    init_model_cb: Callable,
) -> None:
    model = init_model_cb()
    torch.distributed.barrier()
    time.sleep(1)

    # run the requests co-routine which retrieves requests from the socket
    # and sends responses (we provide) back to the caller
    req_gen = retrieve_requests(reply_socket_url)
    result = None
    while True:
        try:
            task = req_gen.send(result)
            # the parent signals shutdown with an EndSentinel (sent by ModelParallelProcessGroup.stop)
            if isinstance(task, EndSentinel):
                break

            assert isinstance(task, TaskRequest)
            result = model(task.task)
        except StopIteration:
            break

    log.info("[debug] worker process done")


def launch_dist_group(
    reply_socket_url: str,
    model_parallel_size: int,
    init_model_cb: Callable,
    **kwargs,
) -> None:
    id = uuid.uuid4().hex
    dist_url = f"file:///tmp/llama3_{id}_{time.time()}"

    with tempfile.TemporaryDirectory() as tmpdir:
        # TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
        launch_config = LaunchConfig(
            max_nodes=1,
            min_nodes=1,
            nproc_per_node=model_parallel_size,
            start_method="fork",
            rdzv_backend="c10d",
            rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
            rdzv_configs={"store_type": "file", "timeout": 90},
            max_restarts=0,
            monitor_interval=1,
            run_id=str(uuid.uuid4()),
        )
        elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
            reply_socket_url,
            init_model_cb,
        )


def start_model_parallel_process(
    model_parallel_size: int,
    init_model_cb: Callable,
    **kwargs,
):
    context = zmq.Context()
    request_socket = context.socket(zmq.DEALER)

    # Binding the request socket to a random port
    request_socket.bind("tcp://127.0.0.1:0")

    main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)

    ctx = multiprocessing.get_context("spawn")
    process = ctx.Process(
        target=launch_dist_group,
        args=(
            main_process_url,
            model_parallel_size,
            init_model_cb,
        ),
        kwargs=kwargs,
    )
    process.start()

    # wait until the model is loaded; rank 0 will send a message to indicate it's ready
    request_socket.send(encode_msg(ReadyRequest()))
    response = request_socket.recv()
    log.info("Loaded model...")

    return request_socket, process


class ModelParallelProcessGroup:
    def __init__(
        self,
        model_parallel_size: int,
        init_model_cb: Callable,
        **kwargs,
    ):
        self.model_parallel_size = model_parallel_size
        self.init_model_cb = init_model_cb
        self.started = False
        self.running = False

    def start(self):
        assert not self.started, "process group already started"
        self.request_socket, self.process = start_model_parallel_process(
            self.model_parallel_size,
            self.init_model_cb,
        )
        self.started = True

    def stop(self):
        assert self.started, "process group not started"
        if self.process.is_alive():
            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
            self.process.join()
        self.started = False

    def run_inference(
        self, req: Union[CompletionRequest, ChatCompletionRequest]
    ) -> Generator:
        assert not self.running, "inference already running"

        self.running = True
        self.request_socket.send(encode_msg(TaskRequest(task=req)))
        try:
            while True:
                obj_json = self.request_socket.recv()
                obj = parse_message(obj_json)

                if isinstance(obj, EndSentinel):
                    break

                if isinstance(obj, ExceptionResponse):
                    log.error(f"[debug] got exception {obj.error}")
                    raise Exception(obj.error)

                if isinstance(obj, TaskResponse):
                    yield obj.result

        except GeneratorExit:
            self.request_socket.send(encode_msg(CancelSentinel()))
            while True:
                # drain the socket until the worker acknowledges the cancel with an EndSentinel
                obj_json = self.request_socket.recv()
                obj = parse_message(obj_json)
                if isinstance(obj, EndSentinel):
                    break
        finally:
            self.running = False
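
For orientation, here is a minimal sketch of how this process group appears intended to be driven, based only on the code above. `load_model_somehow` and `request` are hypothetical placeholders; per `worker_process_entrypoint`, `init_model_cb` must return a callable that maps a `CompletionRequest`/`ChatCompletionRequest` to an iterable of `TokenResult`s.

```python
# Hypothetical driver for ModelParallelProcessGroup (names below are placeholders).
def init_model_cb():
    # must return a callable: request -> iterable of TokenResult (see worker_process_entrypoint)
    return load_model_somehow()

group = ModelParallelProcessGroup(
    model_parallel_size=2,
    init_model_cb=init_model_cb,
)
group.start()  # spawns the worker group and waits for the ReadyRequest/ReadyResponse handshake
try:
    # `request` would be a ChatCompletionRequest or CompletionRequest
    for token_result in group.run_inference(request):
        print(token_result)
finally:
    group.stop()  # sends an EndSentinel so the workers shut down cleanly
```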