forked from phoenix-oss/llama-stack-mirror
		
	
		
			
				
	
	
		
			363 lines
		
	
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			363 lines
		
	
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| # Copyright (c) Meta Platforms, IAny, nc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import json
 | |
| import logging
 | |
| import multiprocessing
 | |
| import os
 | |
| import tempfile
 | |
| import time
 | |
| import uuid
 | |
| from enum import Enum
 | |
| from typing import Callable, Generator, Literal, Optional, Union
 | |
| 
 | |
| import torch
 | |
| import zmq
 | |
| from fairscale.nn.model_parallel.initialize import (
 | |
|     get_model_parallel_group,
 | |
|     get_model_parallel_rank,
 | |
|     get_model_parallel_src_rank,
 | |
| )
 | |
| from pydantic import BaseModel, Field
 | |
| from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 | |
| from typing_extensions import Annotated
 | |
| 
 | |
| from llama_stack.providers.utils.inference.prompt_adapter import (
 | |
|     ChatCompletionRequestWithRawContent,
 | |
|     CompletionRequestWithRawContent,
 | |
| )
 | |
| 
 | |
| from .common import TokenResult
 | |
| 
 | |
| log = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
| class ProcessingMessageName(str, Enum):
 | |
|     ready_request = "ready_request"
 | |
|     ready_response = "ready_response"
 | |
|     end_sentinel = "end_sentinel"
 | |
|     cancel_sentinel = "cancel_sentinel"
 | |
|     task_request = "task_request"
 | |
|     task_response = "task_response"
 | |
|     exception_response = "exception_response"
 | |
| 
 | |
| 
 | |
| class ReadyRequest(BaseModel):
 | |
|     type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
 | |
| 
 | |
| 
 | |
| class ReadyResponse(BaseModel):
 | |
|     type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
 | |
| 
 | |
| 
 | |
| class EndSentinel(BaseModel):
 | |
|     type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
 | |
| 
 | |
| 
 | |
| class CancelSentinel(BaseModel):
 | |
|     type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
 | |
| 
 | |
| 
 | |
| class TaskRequest(BaseModel):
 | |
|     type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
 | |
|     task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
 | |
| 
 | |
| 
 | |
| class TaskResponse(BaseModel):
 | |
|     type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
 | |
|     result: TokenResult
 | |
| 
 | |
| 
 | |
| class ExceptionResponse(BaseModel):
 | |
|     type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
 | |
|     error: str
 | |
| 
 | |
| 
 | |
| ProcessingMessage = Union[
 | |
|     ReadyRequest,
 | |
|     ReadyResponse,
 | |
|     EndSentinel,
 | |
|     CancelSentinel,
 | |
|     TaskRequest,
 | |
|     TaskResponse,
 | |
|     ExceptionResponse,
 | |
| ]
 | |
| 
 | |
| 
 | |
| class ProcessingMessageWrapper(BaseModel):
 | |
|     payload: Annotated[
 | |
|         ProcessingMessage,
 | |
|         Field(discriminator="type"),
 | |
|     ]
 | |
| 
 | |
| 
 | |
| def mp_rank_0() -> bool:
 | |
|     return get_model_parallel_rank() == 0
 | |
| 
 | |
| 
 | |
| def encode_msg(msg: ProcessingMessage) -> bytes:
 | |
|     return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")
 | |
| 
 | |
| 
 | |
| def retrieve_requests(reply_socket_url: str):
 | |
|     if mp_rank_0():
 | |
|         context = zmq.Context()
 | |
|         reply_socket = context.socket(zmq.ROUTER)
 | |
|         reply_socket.connect(reply_socket_url)
 | |
| 
 | |
|         while True:
 | |
|             client_id, obj = maybe_get_work(reply_socket)
 | |
|             if obj is None:
 | |
|                 time.sleep(0.01)
 | |
|                 continue
 | |
| 
 | |
|             ready_response = ReadyResponse()
 | |
|             reply_socket.send_multipart([client_id, encode_msg(ready_response)])
 | |
|             break
 | |
| 
 | |
|     def send_obj(obj: ProcessingMessage):
 | |
|         reply_socket.send_multipart([client_id, encode_msg(obj)])
 | |
| 
 | |
|     while True:
 | |
|         tasks = [None]
 | |
|         if mp_rank_0():
 | |
|             client_id, maybe_task_json = maybe_get_work(reply_socket)
 | |
|             if maybe_task_json is not None:
 | |
|                 task = maybe_parse_message(maybe_task_json)
 | |
|                 # there is still an unknown unclean GeneratorExit happening resulting in a
 | |
|                 # cancel sentinel getting queued _after_ we have finished sending everything :/
 | |
|                 # kind of a hack this is :/
 | |
|                 if task is not None and not isinstance(task, CancelSentinel):
 | |
|                     tasks = [task]
 | |
| 
 | |
|         torch.distributed.broadcast_object_list(
 | |
|             tasks,
 | |
|             src=get_model_parallel_src_rank(),
 | |
|             group=get_model_parallel_group(),
 | |
|         )
 | |
| 
 | |
|         task = tasks[0]
 | |
|         if task is None:
 | |
|             time.sleep(0.1)
 | |
|         else:
 | |
|             try:
 | |
|                 out = yield task
 | |
|                 if out is None:
 | |
|                     break
 | |
| 
 | |
|                 for obj in out:
 | |
|                     updates = [None]
 | |
|                     if mp_rank_0():
 | |
|                         _, update_json = maybe_get_work(reply_socket)
 | |
|                         update = maybe_parse_message(update_json)
 | |
|                         if isinstance(update, CancelSentinel):
 | |
|                             updates = [update]
 | |
|                         else:
 | |
|                             # only send the update if it's not cancelled otherwise the object sits in the socket
 | |
|                             # and gets pulled in the next request lol
 | |
|                             send_obj(TaskResponse(result=obj))
 | |
| 
 | |
|                     torch.distributed.broadcast_object_list(
 | |
|                         updates,
 | |
|                         src=get_model_parallel_src_rank(),
 | |
|                         group=get_model_parallel_group(),
 | |
|                     )
 | |
|                     if isinstance(updates[0], CancelSentinel):
 | |
|                         log.info("quitting generation loop because request was cancelled")
 | |
|                         break
 | |
| 
 | |
|                 if mp_rank_0():
 | |
|                     send_obj(EndSentinel())
 | |
|             except Exception as e:
 | |
|                 log.exception("exception in generation loop")
 | |
| 
 | |
|                 if mp_rank_0():
 | |
|                     send_obj(ExceptionResponse(error=str(e)))
 | |
| 
 | |
|     if mp_rank_0():
 | |
|         send_obj(EndSentinel())
 | |
| 
 | |
| 
 | |
| def maybe_get_work(sock: zmq.Socket):
 | |
|     message = None
 | |
|     client_id = None
 | |
|     try:
 | |
|         client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
 | |
|         message = obj.decode("utf-8")
 | |
|     except zmq.ZMQError as e:
 | |
|         if e.errno != zmq.EAGAIN:
 | |
|             raise e
 | |
| 
 | |
|     return client_id, message
 | |
| 
 | |
| 
 | |
| def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
 | |
|     if maybe_json is None:
 | |
|         return None
 | |
|     try:
 | |
|         return parse_message(maybe_json)
 | |
|     except json.JSONDecodeError:
 | |
|         return None
 | |
|     except ValueError:
 | |
|         return None
 | |
| 
 | |
| 
 | |
| def parse_message(json_str: str) -> ProcessingMessage:
 | |
|     data = json.loads(json_str)
 | |
|     return ProcessingMessageWrapper(**data).payload
 | |
| 
 | |
| 
 | |
| def worker_process_entrypoint(
 | |
|     reply_socket_url: str,
 | |
|     init_model_cb: Callable,
 | |
| ) -> None:
 | |
|     model = init_model_cb()
 | |
|     torch.distributed.barrier()
 | |
|     time.sleep(1)
 | |
| 
 | |
|     # run the requests co-routine which retrieves requests from the socket
 | |
|     # and sends responses (we provide) back to the caller
 | |
|     req_gen = retrieve_requests(reply_socket_url)
 | |
|     result = None
 | |
|     while True:
 | |
|         try:
 | |
|             task = req_gen.send(result)
 | |
|             if isinstance(task, str) and task == EndSentinel():
 | |
|                 break
 | |
| 
 | |
|             assert isinstance(task, TaskRequest)
 | |
|             result = model(task.task)
 | |
|         except StopIteration:
 | |
|             break
 | |
| 
 | |
|     log.info("[debug] worker process done")
 | |
| 
 | |
| 
 | |
| def launch_dist_group(
 | |
|     reply_socket_url: str,
 | |
|     model_parallel_size: int,
 | |
|     init_model_cb: Callable,
 | |
|     **kwargs,
 | |
| ) -> None:
 | |
|     with tempfile.TemporaryDirectory() as tmpdir:
 | |
|         # TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
 | |
|         launch_config = LaunchConfig(
 | |
|             max_nodes=1,
 | |
|             min_nodes=1,
 | |
|             nproc_per_node=model_parallel_size,
 | |
|             start_method="fork",
 | |
|             rdzv_backend="c10d",
 | |
|             rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
 | |
|             rdzv_configs={"store_type": "file", "timeout": 90},
 | |
|             max_restarts=0,
 | |
|             monitor_interval=1,
 | |
|             run_id=str(uuid.uuid4()),
 | |
|         )
 | |
|         elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
 | |
|             reply_socket_url,
 | |
|             init_model_cb,
 | |
|         )
 | |
| 
 | |
| 
 | |
| def start_model_parallel_process(
 | |
|     model_parallel_size: int,
 | |
|     init_model_cb: Callable,
 | |
|     **kwargs,
 | |
| ):
 | |
|     context = zmq.Context()
 | |
|     request_socket = context.socket(zmq.DEALER)
 | |
| 
 | |
|     # Binding the request socket to a random port
 | |
|     request_socket.bind("tcp://127.0.0.1:0")
 | |
| 
 | |
|     main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
 | |
| 
 | |
|     ctx = multiprocessing.get_context("spawn")
 | |
|     process = ctx.Process(
 | |
|         target=launch_dist_group,
 | |
|         args=(
 | |
|             main_process_url,
 | |
|             model_parallel_size,
 | |
|             init_model_cb,
 | |
|         ),
 | |
|         kwargs=kwargs,
 | |
|     )
 | |
|     process.start()
 | |
| 
 | |
|     # wait until the model is loaded; rank 0 will send a message to indicate it's ready
 | |
| 
 | |
|     request_socket.send(encode_msg(ReadyRequest()))
 | |
|     _response = request_socket.recv()
 | |
|     log.info("Loaded model...")
 | |
| 
 | |
|     return request_socket, process
 | |
| 
 | |
| 
 | |
| class ModelParallelProcessGroup:
 | |
|     def __init__(
 | |
|         self,
 | |
|         model_parallel_size: int,
 | |
|         init_model_cb: Callable,
 | |
|         **kwargs,
 | |
|     ):
 | |
|         self.model_parallel_size = model_parallel_size
 | |
|         self.init_model_cb = init_model_cb
 | |
|         self.started = False
 | |
|         self.running = False
 | |
| 
 | |
|     def start(self):
 | |
|         assert not self.started, "process group already started"
 | |
|         self.request_socket, self.process = start_model_parallel_process(
 | |
|             self.model_parallel_size,
 | |
|             self.init_model_cb,
 | |
|         )
 | |
|         self.started = True
 | |
| 
 | |
|     def stop(self):
 | |
|         assert self.started, "process group not started"
 | |
|         if self.process.is_alive():
 | |
|             self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
 | |
|             self.process.join()
 | |
|         self.started = False
 | |
| 
 | |
|     def run_inference(
 | |
|         self,
 | |
|         req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
 | |
|     ) -> Generator:
 | |
|         assert not self.running, "inference already running"
 | |
| 
 | |
|         self.running = True
 | |
|         try:
 | |
|             self.request_socket.send(encode_msg(TaskRequest(task=req)))
 | |
|             while True:
 | |
|                 obj_json = self.request_socket.recv()
 | |
|                 obj = parse_message(obj_json)
 | |
| 
 | |
|                 if isinstance(obj, EndSentinel):
 | |
|                     break
 | |
| 
 | |
|                 if isinstance(obj, ExceptionResponse):
 | |
|                     log.error(f"[debug] got exception {obj.error}")
 | |
|                     raise Exception(obj.error)
 | |
| 
 | |
|                 if isinstance(obj, TaskResponse):
 | |
|                     yield obj.result
 | |
| 
 | |
|         except GeneratorExit:
 | |
|             self.request_socket.send(encode_msg(CancelSentinel()))
 | |
|             while True:
 | |
|                 obj_json = self.request_socket.send()
 | |
|                 obj = parse_message(obj_json)
 | |
|                 if isinstance(obj, EndSentinel):
 | |
|                     break
 | |
|         finally:
 | |
|             self.running = False
 |