Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 07:14:20 +00:00

commit 3adf1dc20e (parent d6924327ce)
address feedback

1 changed file with 56 additions and 38 deletions
@@ -10,7 +10,8 @@ import os
 import tempfile
 import time
 import uuid
-from typing import Any, Callable, Generator, List, Optional, Union
+from enum import Enum
+from typing import Any, Callable, Generator, List, Literal, Optional, Union

 import torch

@@ -24,9 +25,10 @@ from fairscale.nn.model_parallel.initialize import (
 )

 from llama_models.llama3.api.datatypes import Message, ToolPromptFormat

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from typing_extensions import Annotated

 from .generation import TokenResult
@@ -39,40 +41,67 @@ class InferenceArgs(BaseModel):
     logprobs: bool
     tool_prompt_format: ToolPromptFormat

     # avoid converting llama_models types to BaseModel for now
     class Config:
         arbitrary_types_allowed = True


+class ProcessingMessageName(str, Enum):
+    ready_request = "ready_request"
+    ready_response = "ready_response"
+    end_sentinel = "end_sentinel"
+    cancel_sentinel = "cancel_sentinel"
+    task_request = "task_request"
+    task_response = "task_response"
+    exception_response = "exception_response"
+
+
 class ReadyRequest(BaseModel):
+    type: Literal[ProcessingMessageName.ready_request] = (
+        ProcessingMessageName.ready_request
+    )
     message: str = "READY?"


 class ReadyResponse(BaseModel):
+    type: Literal[ProcessingMessageName.ready_response] = (
+        ProcessingMessageName.ready_response
+    )
     message: str = "YES READY"


 class EndSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.end_sentinel] = (
+        ProcessingMessageName.end_sentinel
+    )
     message: str = "__end_sentinel__"


 class CancelSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.cancel_sentinel] = (
+        ProcessingMessageName.cancel_sentinel
+    )
     message: str = "__cancel_sentinel__"


 class TaskRequest(BaseModel):
+    type: Literal[ProcessingMessageName.task_request] = (
+        ProcessingMessageName.task_request
+    )
     task: InferenceArgs


 class TaskResponse(BaseModel):
+    type: Literal[ProcessingMessageName.task_response] = (
+        ProcessingMessageName.task_response
+    )
     result: TokenResult


 class ExceptionResponse(BaseModel):
+    type: Literal[ProcessingMessageName.exception_response] = (
+        ProcessingMessageName.exception_response
+    )
     error: str


-Message = Union[
+ProcessingMessage = Union[
     ReadyRequest,
     ReadyResponse,
     EndSentinel,
@@ -83,10 +112,21 @@ Message = Union[
 ]


+class ProcessingMessageWrapper(BaseModel):
+    payload: Annotated[
+        ProcessingMessage,
+        Field(discriminator="type"),
+    ]
+
+
 def mp_rank_0() -> bool:
     return get_model_parallel_rank() == 0


+def encode_msg(msg: ProcessingMessage) -> bytes:
+    return msg.model_dump_json().encode("utf-8")
+
+
 def retrieve_requests(reply_socket_url: str):
     if mp_rank_0():
         context = zmq.Context()
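The two hunks above are the heart of the change: every message class now carries a type tag, and ProcessingMessageWrapper uses pydantic's discriminated-union support to dispatch on it. A minimal, self-contained sketch of the pattern (pydantic v2 assumed; only two of the seven variants are reproduced, and the round-trip at the end is illustrative rather than copied from the diff):

import json
from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class ProcessingMessageName(str, Enum):
    ready_request = "ready_request"
    ready_response = "ready_response"


class ReadyRequest(BaseModel):
    # the Literal type plus a default means the tag is always present on the wire
    type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
    message: str = "READY?"


class ReadyResponse(BaseModel):
    type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
    message: str = "YES READY"


ProcessingMessage = Union[ReadyRequest, ReadyResponse]


class ProcessingMessageWrapper(BaseModel):
    # pydantic inspects the "type" field to pick the right class from the union
    payload: Annotated[ProcessingMessage, Field(discriminator="type")]


wire = ReadyResponse().model_dump_json()  # '{"type":"ready_response","message":"YES READY"}'
parsed = ProcessingMessageWrapper(payload=json.loads(wire)).payload
assert isinstance(parsed, ReadyResponse)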
@@ -100,13 +140,11 @@ def retrieve_requests(reply_socket_url: str):
                 continue

             ready_response = ReadyResponse()
-            reply_socket.send_multipart(
-                [client_id, ready_response.model_dump_json().encode("utf-8")]
-            )
+            reply_socket.send_multipart([client_id, encode_msg(ready_response)])
             break

-        def send_obj(obj: Union[Message, BaseModel]):
-            reply_socket.send_multipart([client_id, obj.model_dump_json().encode("utf-8")])
+        def send_obj(obj: Union[ProcessingMessage, BaseModel]):
+            reply_socket.send_multipart([client_id, encode_msg(obj)])

         while True:
             tasks = [None]
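For context on the transport, the READY exchange above runs over a ZeroMQ ROUTER socket, where every reply must be a multipart frame that leads with the requester's identity. A hedged sketch of that handshake with both ends in one process (the diff does not show the client's socket type; DEALER is one standard pairing, and the inproc:// endpoint is made up so the snippet runs standalone):

import zmq

ctx = zmq.Context()
server = ctx.socket(zmq.ROUTER)        # plays the role of reply_socket
server.bind("inproc://ready-handshake")
client = ctx.socket(zmq.DEALER)        # hypothetical counterpart of request_socket
client.connect("inproc://ready-handshake")

client.send(b'{"type": "ready_request", "message": "READY?"}')
client_id, raw = server.recv_multipart()   # ROUTER prepends the sender identity
server.send_multipart([client_id, b'{"type": "ready_response", "message": "YES READY"}'])
print(client.recv())                       # the identity frame is stripped off again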
@@ -183,7 +221,7 @@ def maybe_get_work(sock: zmq.Socket):
     return client_id, message


-def maybe_parse_message(maybe_json: Optional[str]) -> Optional[Message]:
+def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
     if maybe_json is None:
         return None
     try:
@@ -194,25 +232,9 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[Message]:
         return None


-def parse_message(json_str: str) -> Message:
+def parse_message(json_str: str) -> ProcessingMessage:
     data = json.loads(json_str)
-    if "message" in data:
-        if data["message"] == "READY?":
-            return ReadyRequest.model_validate(data)
-        elif data["message"] == "YES READY":
-            return ReadyResponse.model_validate(data)
-        elif data["message"] == "__end_sentinel__":
-            return EndSentinel.model_validate(data)
-        elif data["message"] == "__cancel_sentinel__":
-            return CancelSentinel.model_validate(data)
-    elif "task" in data:
-        return TaskRequest.model_validate(data)
-    elif "result" in data:
-        return TaskResponse.model_validate(data)
-    elif "error" in data:
-        return ExceptionResponse.model_validate(data)
-    else:
-        raise ValueError(f"Unknown message type: {data}")
+    return ProcessingMessageWrapper(**data).payload


 def worker_process_entrypoint(
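One behavioral note on the rewritten parse_message: the hand-rolled dispatch raised ValueError for an unknown message, while validation through the wrapper surfaces a pydantic ValidationError, which callers such as maybe_parse_message may need to account for. Continuing the sketch above, with the wrapper built from an explicit payload key rather than the diff's **data spread:

from pydantic import ValidationError

def parse_message_sketch(json_str: str) -> ProcessingMessage:
    return ProcessingMessageWrapper(payload=json.loads(json_str)).payload

try:
    parse_message_sketch('{"type": "no_such_message"}')
except ValidationError as exc:
    print(f"rejected with {exc.error_count()} validation error(s)")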
@@ -297,7 +319,7 @@ def start_model_parallel_process(

     # wait until the model is loaded; rank 0 will send a message to indicate it's ready
-    request_socket.send(ReadyRequest().model_dump_json().encode("utf-8"))
+    request_socket.send(encode_msg(ReadyRequest()))
     response = request_socket.recv()
     print(f"Finished model load {response}")

@@ -327,9 +349,7 @@ class ModelParallelProcessGroup:
     def stop(self):
         assert self.started, "process group not started"
         if self.process.is_alive():
-            self.request_socket.send(
-                EndSentinel().model_dump_json().encode("utf-8"), zmq.NOBLOCK
-            )
+            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
             self.process.join()
         self.started = False

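The zmq.NOBLOCK flag kept above matters on shutdown: if the worker process has already died, a blocking send could hang stop() forever. A small sketch of the failure mode it avoids (the PAIR socket and endpoint name are illustrative, not from the diff):

import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.PAIR)
sock.bind("inproc://shutdown-demo")   # nothing ever connects, as with a dead worker

try:
    sock.send(b"__end_sentinel__", zmq.NOBLOCK)
except zmq.Again:
    print("no live peer; returning instead of blocking forever")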
@@ -337,9 +357,7 @@ class ModelParallelProcessGroup:
         assert not self.running, "inference already running"

         self.running = True
-        self.request_socket.send(
-            TaskRequest(task=inference_args).model_dump_json().encode("utf-8")
-        )
+        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
         try:
             while True:
                 obj_json = self.request_socket.recv()
@@ -356,7 +374,7 @@ class ModelParallelProcessGroup:
                 yield obj.result

         except GeneratorExit as e:
-            self.request_socket.send(CancelSentinel().model_dump_json().encode("utf-8"))
+            self.request_socket.send(encode_msg(CancelSentinel()))
             while True:
                 obj_json = self.request_socket.recv()
                 obj = parse_message(obj_json)
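The final hunk completes the cancellation path: when the consumer abandons the generator, GeneratorExit triggers a CancelSentinel, and the loop then drains stale responses until the worker acknowledges. A sketch of that drain step with hypothetical helper parameters (the diff itself passes frames through parse_message and, presumably, stops on an EndSentinel):

def drain_until_end(sock, parse, is_end_sentinel):
    """Discard in-flight worker responses until the end sentinel arrives,
    so no stale token frames are left queued on the socket."""
    while True:
        obj = parse(sock.recv())
        if is_end_sentinel(obj):
            break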