address feedback

dltn 2024-10-09 11:14:02 -07:00
parent d6924327ce
commit 3adf1dc20e


@@ -10,7 +10,8 @@ import os
 import tempfile
 import time
 import uuid
-from typing import Any, Callable, Generator, List, Optional, Union
+from enum import Enum
+from typing import Any, Callable, Generator, List, Literal, Optional, Union

 import torch
@@ -24,9 +25,10 @@ from fairscale.nn.model_parallel.initialize import (
 from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from typing_extensions import Annotated

 from .generation import TokenResult
@@ -39,40 +41,67 @@ class InferenceArgs(BaseModel):
     logprobs: bool
     tool_prompt_format: ToolPromptFormat

-    # avoid converting llama_models types to BaseModel for now
-    class Config:
-        arbitrary_types_allowed = True
+
+class ProcessingMessageName(str, Enum):
+    ready_request = "ready_request"
+    ready_response = "ready_response"
+    end_sentinel = "end_sentinel"
+    cancel_sentinel = "cancel_sentinel"
+    task_request = "task_request"
+    task_response = "task_response"
+    exception_response = "exception_response"


 class ReadyRequest(BaseModel):
+    type: Literal[ProcessingMessageName.ready_request] = (
+        ProcessingMessageName.ready_request
+    )
     message: str = "READY?"


 class ReadyResponse(BaseModel):
+    type: Literal[ProcessingMessageName.ready_response] = (
+        ProcessingMessageName.ready_response
+    )
     message: str = "YES READY"


 class EndSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.end_sentinel] = (
+        ProcessingMessageName.end_sentinel
+    )
     message: str = "__end_sentinel__"


 class CancelSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.cancel_sentinel] = (
+        ProcessingMessageName.cancel_sentinel
+    )
     message: str = "__cancel_sentinel__"


 class TaskRequest(BaseModel):
+    type: Literal[ProcessingMessageName.task_request] = (
+        ProcessingMessageName.task_request
+    )
     task: InferenceArgs


 class TaskResponse(BaseModel):
+    type: Literal[ProcessingMessageName.task_response] = (
+        ProcessingMessageName.task_response
+    )
     result: TokenResult


 class ExceptionResponse(BaseModel):
+    type: Literal[ProcessingMessageName.exception_response] = (
+        ProcessingMessageName.exception_response
+    )
     error: str


-Message = Union[
+ProcessingMessage = Union[
     ReadyRequest,
     ReadyResponse,
     EndSentinel,
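
The hunk above turns the process-control messages into a tagged union: each model now carries a `type: Literal[...]` field whose default is its own `ProcessingMessageName` member, which is exactly what pydantic's discriminated unions key on. A minimal, self-contained sketch of the same pattern, using illustrative `Ping`/`Pong` models rather than the file's own:

```python
# Minimal sketch of the tagged-union pattern in the hunk above.
# Ping/Pong/MsgName/Wrapper are illustrative names, not the file's models.
from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class MsgName(str, Enum):
    ping = "ping"
    pong = "pong"


class Ping(BaseModel):
    type: Literal[MsgName.ping] = MsgName.ping
    message: str = "PING"


class Pong(BaseModel):
    type: Literal[MsgName.pong] = MsgName.pong
    message: str = "PONG"


class Wrapper(BaseModel):
    # pydantic picks Ping or Pong by inspecting the "type" tag
    payload: Annotated[Union[Ping, Pong], Field(discriminator="type")]


# Round trip: the tag survives serialization and drives re-validation.
wire = Wrapper(payload=Pong()).model_dump_json()
assert Wrapper.model_validate_json(wire).payload.type == MsgName.pong
```

Because `ProcessingMessageName` subclasses `str`, the tag serializes as a plain JSON string, so the wire format stays human-readable.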
@@ -83,10 +112,21 @@ Message = Union[
 ]


+class ProcessingMessageWrapper(BaseModel):
+    payload: Annotated[
+        ProcessingMessage,
+        Field(discriminator="type"),
+    ]
+
+
 def mp_rank_0() -> bool:
     return get_model_parallel_rank() == 0


+def encode_msg(msg: ProcessingMessage) -> bytes:
+    return msg.model_dump_json().encode("utf-8")
+
+
 def retrieve_requests(reply_socket_url: str):
     if mp_rank_0():
         context = zmq.Context()
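
Here `encode_msg` folds the `model_dump_json().encode("utf-8")` boilerplate that previously appeared at every send site into one helper, and `ProcessingMessageWrapper` gives pydantic a single field annotated with `Field(discriminator="type")` to dispatch on. A sketch equivalent of the helper, reusing the `Ping`/`Pong` models from the sketch above (assumes pydantic v2 serialization of `str`-valued enums):

```python
import json


def encode(msg: Union[Ping, Pong]) -> bytes:
    # One place to serialize any tagged message to bytes for the zmq socket.
    return msg.model_dump_json().encode("utf-8")


# The str-valued enum tag comes out as a plain JSON string.
assert json.loads(encode(Ping())) == {"type": "ping", "message": "PING"}
```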
@@ -100,13 +140,11 @@ def retrieve_requests(reply_socket_url: str):
                 continue

             ready_response = ReadyResponse()
-            reply_socket.send_multipart(
-                [client_id, ready_response.model_dump_json().encode("utf-8")]
-            )
+            reply_socket.send_multipart([client_id, encode_msg(ready_response)])
             break

-        def send_obj(obj: Union[Message, BaseModel]):
-            reply_socket.send_multipart([client_id, obj.model_dump_json().encode("utf-8")])
+        def send_obj(obj: Union[ProcessingMessage, BaseModel]):
+            reply_socket.send_multipart([client_id, encode_msg(obj)])

         while True:
             tasks = [None]
@@ -183,7 +221,7 @@ def maybe_get_work(sock: zmq.Socket):
     return client_id, message


-def maybe_parse_message(maybe_json: Optional[str]) -> Optional[Message]:
+def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
     if maybe_json is None:
         return None
     try:
@@ -194,25 +232,9 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[Message]:
         return None


-def parse_message(json_str: str) -> Message:
+def parse_message(json_str: str) -> ProcessingMessage:
     data = json.loads(json_str)
-    if "message" in data:
-        if data["message"] == "READY?":
-            return ReadyRequest.model_validate(data)
-        elif data["message"] == "YES READY":
-            return ReadyResponse.model_validate(data)
-        elif data["message"] == "__end_sentinel__":
-            return EndSentinel.model_validate(data)
-        elif data["message"] == "__cancel_sentinel__":
-            return CancelSentinel.model_validate(data)
-    elif "task" in data:
-        return TaskRequest.model_validate(data)
-    elif "result" in data:
-        return TaskResponse.model_validate(data)
-    elif "error" in data:
-        return ExceptionResponse.model_validate(data)
-    else:
-        raise ValueError(f"Unknown message type: {data}")
+    return ProcessingMessageWrapper(**data).payload


 def worker_process_entrypoint(
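
With the wrapper in place, the hand-written if/elif dispatch above collapses to a single validation call, and an unknown or missing `type` tag now surfaces as a pydantic `ValidationError` rather than the hand-raised `ValueError`. Continuing the sketch:

```python
import json


def parse(json_str: str) -> Union[Ping, Pong]:
    # pydantic reads the "type" tag and validates against the matching
    # model; an unrecognized tag raises a ValidationError.
    return Wrapper(payload=json.loads(json_str)).payload


assert isinstance(parse('{"type": "ping", "message": "PING"}'), Ping)
```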
@@ -297,7 +319,7 @@ def start_model_parallel_process(
     # wait until the model is loaded; rank 0 will send a message to indicate it's ready

-    request_socket.send(ReadyRequest().model_dump_json().encode("utf-8"))
+    request_socket.send(encode_msg(ReadyRequest()))
     response = request_socket.recv()
     print(f"Finished model load {response}")
@@ -327,9 +349,7 @@ class ModelParallelProcessGroup:
     def stop(self):
         assert self.started, "process group not started"
         if self.process.is_alive():
-            self.request_socket.send(
-                EndSentinel().model_dump_json().encode("utf-8"), zmq.NOBLOCK
-            )
+            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
             self.process.join()
         self.started = False
@@ -337,9 +357,7 @@ class ModelParallelProcessGroup:
         assert not self.running, "inference already running"

         self.running = True
-        self.request_socket.send(
-            TaskRequest(task=inference_args).model_dump_json().encode("utf-8")
-        )
+        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
         try:
             while True:
                 obj_json = self.request_socket.recv()
@@ -356,7 +374,7 @@ class ModelParallelProcessGroup:
                     yield obj.result
         except GeneratorExit as e:
-            self.request_socket.send(CancelSentinel().model_dump_json().encode("utf-8"))
+            self.request_socket.send(encode_msg(CancelSentinel()))
             while True:
                 obj_json = self.request_socket.send()
                 obj = parse_message(obj_json)
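
Taken together, the remaining hunks route every client-side send through `encode_msg`: the ready handshake, `stop`'s `EndSentinel`, `run_inference`'s `TaskRequest`, and the `CancelSentinel` sent when the generator is closed early. A hedged usage sketch of that flow (constructor arguments, `inference_args`, and the `TokenResult.text` field are assumptions not shown in this diff):

```python
# Hypothetical driver loop for the classes touched above; constructor
# arguments and the InferenceArgs instance are elided because this
# diff does not show them.
group = ModelParallelProcessGroup(...)  # placeholder arguments
group.start()  # presumably spawns the worker and completes the READY handshake
try:
    # run_inference streams TokenResult objects back from the worker
    for token in group.run_inference(inference_args):
        print(token.text, end="", flush=True)  # TokenResult.text assumed
finally:
    group.stop()  # sends EndSentinel with zmq.NOBLOCK and joins the worker
```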