address feedback

dltn 2024-10-09 11:14:02 -07:00
parent d6924327ce
commit 3adf1dc20e


@@ -10,7 +10,8 @@ import os
import tempfile
import time
import uuid
from typing import Any, Callable, Generator, List, Optional, Union
from enum import Enum
from typing import Any, Callable, Generator, List, Literal, Optional, Union
import torch
@@ -24,9 +25,10 @@ from fairscale.nn.model_parallel.initialize import (
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from pydantic import BaseModel
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from .generation import TokenResult
@@ -39,40 +41,67 @@ class InferenceArgs(BaseModel):
    logprobs: bool
    tool_prompt_format: ToolPromptFormat

    # avoid converting llama_models types to BaseModel for now
    class Config:
        arbitrary_types_allowed = True

class ProcessingMessageName(str, Enum):
    ready_request = "ready_request"
    ready_response = "ready_response"
    end_sentinel = "end_sentinel"
    cancel_sentinel = "cancel_sentinel"
    task_request = "task_request"
    task_response = "task_response"
    exception_response = "exception_response"

class ReadyRequest(BaseModel):
    type: Literal[ProcessingMessageName.ready_request] = (
        ProcessingMessageName.ready_request
    )
    message: str = "READY?"

class ReadyResponse(BaseModel):
    type: Literal[ProcessingMessageName.ready_response] = (
        ProcessingMessageName.ready_response
    )
    message: str = "YES READY"

class EndSentinel(BaseModel):
    type: Literal[ProcessingMessageName.end_sentinel] = (
        ProcessingMessageName.end_sentinel
    )
    message: str = "__end_sentinel__"

class CancelSentinel(BaseModel):
    type: Literal[ProcessingMessageName.cancel_sentinel] = (
        ProcessingMessageName.cancel_sentinel
    )
    message: str = "__cancel_sentinel__"

class TaskRequest(BaseModel):
    type: Literal[ProcessingMessageName.task_request] = (
        ProcessingMessageName.task_request
    )
    task: InferenceArgs

class TaskResponse(BaseModel):
    type: Literal[ProcessingMessageName.task_response] = (
        ProcessingMessageName.task_response
    )
    result: TokenResult

class ExceptionResponse(BaseModel):
    type: Literal[ProcessingMessageName.exception_response] = (
        ProcessingMessageName.exception_response
    )
    error: str

Message = Union[
ProcessingMessage = Union[
    ReadyRequest,
    ReadyResponse,
    EndSentinel,
@@ -83,10 +112,21 @@ Message = Union[
]

class ProcessingMessageWrapper(BaseModel):
    payload: Annotated[
        ProcessingMessage,
        Field(discriminator="type"),
    ]

def mp_rank_0() -> bool:
    return get_model_parallel_rank() == 0

def encode_msg(msg: ProcessingMessage) -> bytes:
    # wrap the payload so parse_message() can dispatch on the discriminated field
    return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")

def retrieve_requests(reply_socket_url: str):
    if mp_rank_0():
        context = zmq.Context()
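
For context, a minimal sketch of how the discriminated union above resolves a wire message back to a concrete class, using the Pydantic v2 API this file already relies on (model_dump_json / model_validate):

# Illustrative only: Field(discriminator="type") makes Pydantic pick the
# concrete ProcessingMessage class from the Literal "type" value.
wrapped = ProcessingMessageWrapper(payload=ReadyRequest())
raw = wrapped.model_dump_json()  # roughly {"payload": {"type": "ready_request", "message": "READY?"}}
restored = ProcessingMessageWrapper.model_validate_json(raw).payload
assert isinstance(restored, ReadyRequest)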
@@ -100,13 +140,11 @@ def retrieve_requests(reply_socket_url: str):
                continue

            ready_response = ReadyResponse()
            reply_socket.send_multipart(
                [client_id, ready_response.model_dump_json().encode("utf-8")]
            )
            reply_socket.send_multipart([client_id, encode_msg(ready_response)])
            break

    def send_obj(obj: Union[Message, BaseModel]):
        reply_socket.send_multipart([client_id, obj.model_dump_json().encode("utf-8")])
    def send_obj(obj: Union[ProcessingMessage, BaseModel]):
        reply_socket.send_multipart([client_id, encode_msg(obj)])

    while True:
        tasks = [None]
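
The multipart framing above follows the standard ZMQ ROUTER pattern: the socket prepends an identity frame for the peer, and every reply must echo that frame ahead of the payload. A rough sketch of one request/reply turn, reusing the helpers defined earlier (the endpoint, the bind/connect direction, and a DEALER-style peer sending a single payload frame are assumptions, not taken from this file):

import zmq

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
router.bind("tcp://127.0.0.1:5555")  # hypothetical endpoint

client_id, raw = router.recv_multipart()      # [identity frame, payload bytes]
request = parse_message(raw)                  # e.g. a ReadyRequest
router.send_multipart([client_id, encode_msg(ReadyResponse())])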
@@ -183,7 +221,7 @@ def maybe_get_work(sock: zmq.Socket):
    return client_id, message

def maybe_parse_message(maybe_json: Optional[str]) -> Optional[Message]:
def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
    if maybe_json is None:
        return None
    try:
@@ -194,25 +232,9 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[Message]:
        return None

def parse_message(json_str: str) -> Message:
def parse_message(json_str: str) -> ProcessingMessage:
    data = json.loads(json_str)
    if "message" in data:
        if data["message"] == "READY?":
            return ReadyRequest.model_validate(data)
        elif data["message"] == "YES READY":
            return ReadyResponse.model_validate(data)
        elif data["message"] == "__end_sentinel__":
            return EndSentinel.model_validate(data)
        elif data["message"] == "__cancel_sentinel__":
            return CancelSentinel.model_validate(data)
    elif "task" in data:
        return TaskRequest.model_validate(data)
    elif "result" in data:
        return TaskResponse.model_validate(data)
    elif "error" in data:
        return ExceptionResponse.model_validate(data)
    else:
        raise ValueError(f"Unknown message type: {data}")
    return ProcessingMessageWrapper(**data).payload

def worker_process_entrypoint(
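
With the wrapper in place, the hand-rolled field sniffing above collapses into a single validation, and adding a new message kind only needs a new enum member plus model. A small round-trip sketch (this assumes the sender emits the wrapped form, as encode_msg does above):

raw = '{"payload": {"type": "cancel_sentinel", "message": "__cancel_sentinel__"}}'
msg = parse_message(raw)
assert isinstance(msg, CancelSentinel)
# An unrecognized "type" raises a ValidationError instead of silently mis-parsing.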
@@ -297,7 +319,7 @@ def start_model_parallel_process(
    # wait until the model is loaded; rank 0 will send a message to indicate it's ready
    request_socket.send(ReadyRequest().model_dump_json().encode("utf-8"))
    request_socket.send(encode_msg(ReadyRequest()))

    response = request_socket.recv()
    print(f"Finished model load {response}")
@@ -327,9 +349,7 @@ class ModelParallelProcessGroup:
    def stop(self):
        assert self.started, "process group not started"
        if self.process.is_alive():
            self.request_socket.send(
                EndSentinel().model_dump_json().encode("utf-8"), zmq.NOBLOCK
            )
            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
            self.process.join()
        self.started = False
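
For orientation, the control messages sent here and below map to worker-side behavior roughly as follows: EndSentinel ends the worker loop, CancelSentinel aborts an in-flight generation, and TaskRequest carries the next InferenceArgs. A minimal dispatch sketch (the raw_json stand-in and the branches are illustrative, not this file's worker loop):

raw_json = encode_msg(EndSentinel())      # stand-in for a frame read off the socket
msg = parse_message(raw_json)
if isinstance(msg, EndSentinel):
    print("worker loop exits")
elif isinstance(msg, CancelSentinel):
    print("abort the in-flight generation")
elif isinstance(msg, TaskRequest):
    inference_args = msg.task             # typed access, no dict sniffing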
@@ -337,9 +357,7 @@ class ModelParallelProcessGroup:
        assert not self.running, "inference already running"
        self.running = True

        self.request_socket.send(
            TaskRequest(task=inference_args).model_dump_json().encode("utf-8")
        )
        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
        try:
            while True:
                obj_json = self.request_socket.recv()
@@ -356,7 +374,7 @@ class ModelParallelProcessGroup:
                yield obj.result

        except GeneratorExit as e:
            self.request_socket.send(CancelSentinel().model_dump_json().encode("utf-8"))
            self.request_socket.send(encode_msg(CancelSentinel()))
            while True:
                obj_json = self.request_socket.recv()
                obj = parse_message(obj_json)
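
A hypothetical caller's view of run_inference ties the pieces together: iterating the generator streams TokenResult objects, and closing it early raises GeneratorExit inside run_inference, which sends CancelSentinel and then drains the remaining replies from the worker. The group and inference_args names below are assumed, not defined in this diff:

max_tokens_wanted = 16                       # hypothetical early-stop budget
gen = group.run_inference(inference_args)    # group: a started ModelParallelProcessGroup
for i, token in enumerate(gen):              # token is a TokenResult yielded above
    print(token.text, end="", flush=True)
    if i + 1 >= max_tokens_wanted:
        gen.close()                          # -> GeneratorExit -> CancelSentinel path
        break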