mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

address feedback

This commit is contained in:
parent d6924327ce
commit 3adf1dc20e

1 changed file with 56 additions and 38 deletions
@@ -10,7 +10,8 @@ import os
 import tempfile
 import time
 import uuid
-from typing import Any, Callable, Generator, List, Optional, Union
+from enum import Enum
+from typing import Any, Callable, Generator, List, Literal, Optional, Union

 import torch

@@ -24,9 +25,10 @@ from fairscale.nn.model_parallel.initialize import (

 from llama_models.llama3.api.datatypes import Message, ToolPromptFormat

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from typing_extensions import Annotated

 from .generation import TokenResult

@@ -39,40 +41,67 @@ class InferenceArgs(BaseModel):
     logprobs: bool
     tool_prompt_format: ToolPromptFormat

-    # avoid converting llama_models types to BaseModel for now
-    class Config:
-        arbitrary_types_allowed = True
+
+class ProcessingMessageName(str, Enum):
+    ready_request = "ready_request"
+    ready_response = "ready_response"
+    end_sentinel = "end_sentinel"
+    cancel_sentinel = "cancel_sentinel"
+    task_request = "task_request"
+    task_response = "task_response"
+    exception_response = "exception_response"


 class ReadyRequest(BaseModel):
+    type: Literal[ProcessingMessageName.ready_request] = (
+        ProcessingMessageName.ready_request
+    )
     message: str = "READY?"


 class ReadyResponse(BaseModel):
+    type: Literal[ProcessingMessageName.ready_response] = (
+        ProcessingMessageName.ready_response
+    )
     message: str = "YES READY"


 class EndSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.end_sentinel] = (
+        ProcessingMessageName.end_sentinel
+    )
     message: str = "__end_sentinel__"


 class CancelSentinel(BaseModel):
+    type: Literal[ProcessingMessageName.cancel_sentinel] = (
+        ProcessingMessageName.cancel_sentinel
+    )
     message: str = "__cancel_sentinel__"


 class TaskRequest(BaseModel):
+    type: Literal[ProcessingMessageName.task_request] = (
+        ProcessingMessageName.task_request
+    )
     task: InferenceArgs


 class TaskResponse(BaseModel):
+    type: Literal[ProcessingMessageName.task_response] = (
+        ProcessingMessageName.task_response
+    )
     result: TokenResult


 class ExceptionResponse(BaseModel):
+    type: Literal[ProcessingMessageName.exception_response] = (
+        ProcessingMessageName.exception_response
+    )
     error: str


-Message = Union[
+ProcessingMessage = Union[
     ReadyRequest,
     ReadyResponse,
     EndSentinel,
@@ -83,10 +112,21 @@ Message = Union[
 ]


+class ProcessingMessageWrapper(BaseModel):
+    payload: Annotated[
+        ProcessingMessage,
+        Field(discriminator="type"),
+    ]
+
+
 def mp_rank_0() -> bool:
     return get_model_parallel_rank() == 0


+def encode_msg(msg: ProcessingMessage) -> bytes:
+    return msg.model_dump_json().encode("utf-8")
+
+
 def retrieve_requests(reply_socket_url: str):
     if mp_rank_0():
         context = zmq.Context()
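The new message classes form a tagged union: each one pins a `type` field with a `Literal` default drawn from `ProcessingMessageName`, and `ProcessingMessageWrapper` declares `Field(discriminator="type")` so pydantic selects the concrete class from that tag. A minimal round-trip sketch of the pattern, assuming pydantic v2 and using simplified stand-in models (the real `TaskRequest`/`TaskResponse` carry `InferenceArgs` and `TokenResult`), with the decoded dict validated as the wrapper's payload:

# Sketch only: stand-in names, not the module's actual classes or wire format.
import json
from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class MsgName(str, Enum):
    ready_request = "ready_request"
    task_response = "task_response"


class Ready(BaseModel):
    type: Literal[MsgName.ready_request] = MsgName.ready_request
    message: str = "READY?"


class Result(BaseModel):
    type: Literal[MsgName.task_response] = MsgName.task_response
    result: str


Msg = Union[Ready, Result]


class Wrapper(BaseModel):
    payload: Annotated[Msg, Field(discriminator="type")]


def encode(msg: Msg) -> bytes:
    # like encode_msg: serialize the message itself, tag included
    return msg.model_dump_json().encode("utf-8")


def decode(raw: bytes) -> Msg:
    # validate the decoded dict against the discriminated union;
    # the "type" tag picks the class, no if/elif chain needed
    return Wrapper(payload=json.loads(raw)).payload


assert isinstance(decode(encode(Result(result="tok"))), Result)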
@@ -100,13 +140,11 @@ def retrieve_requests(reply_socket_url: str):
                    continue

                ready_response = ReadyResponse()
-                reply_socket.send_multipart(
-                    [client_id, ready_response.model_dump_json().encode("utf-8")]
-                )
+                reply_socket.send_multipart([client_id, encode_msg(ready_response)])
                break

-        def send_obj(obj: Union[Message, BaseModel]):
-            reply_socket.send_multipart([client_id, obj.model_dump_json().encode("utf-8")])
+        def send_obj(obj: Union[ProcessingMessage, BaseModel]):
+            reply_socket.send_multipart([client_id, encode_msg(obj)])

        while True:
            tasks = [None]
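For context on the `send_multipart` change: the reply socket in `retrieve_requests` is a ZMQ ROUTER, so every outgoing message must be framed as `[routing_id, payload]`; the refactor keeps that framing and only swaps the hand-rolled `model_dump_json().encode(...)` for `encode_msg`. A small hypothetical sketch of the framing, using inproc sockets and raw JSON bytes rather than the module's actual setup:

# Hypothetical demo of ROUTER framing; not the module's sockets or URLs.
import zmq

ctx = zmq.Context.instance()
router = ctx.socket(zmq.ROUTER)
router.bind("inproc://framing-demo")
dealer = ctx.socket(zmq.DEALER)
dealer.connect("inproc://framing-demo")

dealer.send(b'{"type": "ready_request", "message": "READY?"}')
client_id, raw = router.recv_multipart()  # ROUTER prepends the peer's routing id
router.send_multipart([client_id, b'{"type": "ready_response", "message": "YES READY"}'])
print(dealer.recv())  # reply is routed back to the same peer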
@@ -183,7 +221,7 @@ def maybe_get_work(sock: zmq.Socket):
     return client_id, message


-def maybe_parse_message(maybe_json: Optional[str]) -> Optional[Message]:
+def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
     if maybe_json is None:
         return None
     try:
@@ -194,25 +232,9 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[Message]:
         return None


-def parse_message(json_str: str) -> Message:
+def parse_message(json_str: str) -> ProcessingMessage:
     data = json.loads(json_str)
-    if "message" in data:
-        if data["message"] == "READY?":
-            return ReadyRequest.model_validate(data)
-        elif data["message"] == "YES READY":
-            return ReadyResponse.model_validate(data)
-        elif data["message"] == "__end_sentinel__":
-            return EndSentinel.model_validate(data)
-        elif data["message"] == "__cancel_sentinel__":
-            return CancelSentinel.model_validate(data)
-    elif "task" in data:
-        return TaskRequest.model_validate(data)
-    elif "result" in data:
-        return TaskResponse.model_validate(data)
-    elif "error" in data:
-        return ExceptionResponse.model_validate(data)
-    else:
-        raise ValueError(f"Unknown message type: {data}")
+    return ProcessingMessageWrapper(**data).payload


 def worker_process_entrypoint(
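A note on the parse path: the old `parse_message` dispatched on payload keys ("message", "task", "result", "error") and on sentinel strings, ending in a hand-written `ValueError` for anything unrecognized. With the wrapper, dispatch is driven entirely by the `type` tag, so adding a new message kind only requires a new model plus a union member, and malformed or unknown payloads are rejected by pydantic validation instead of falling through the if/elif chain.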
@@ -297,7 +319,7 @@ def start_model_parallel_process(


    # wait until the model is loaded; rank 0 will send a message to indicate it's ready
-    request_socket.send(ReadyRequest().model_dump_json().encode("utf-8"))
+    request_socket.send(encode_msg(ReadyRequest()))
    response = request_socket.recv()
    print(f"Finished model load {response}")

@@ -327,9 +349,7 @@ class ModelParallelProcessGroup:
     def stop(self):
         assert self.started, "process group not started"
         if self.process.is_alive():
-            self.request_socket.send(
-                EndSentinel().model_dump_json().encode("utf-8"), zmq.NOBLOCK
-            )
+            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
         self.process.join()
         self.started = False
@@ -337,9 +357,7 @@ class ModelParallelProcessGroup:
         assert not self.running, "inference already running"

         self.running = True
-        self.request_socket.send(
-            TaskRequest(task=inference_args).model_dump_json().encode("utf-8")
-        )
+        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
         try:
             while True:
                 obj_json = self.request_socket.recv()
@@ -356,7 +374,7 @@ class ModelParallelProcessGroup:
                     yield obj.result

         except GeneratorExit as e:
-            self.request_socket.send(CancelSentinel().model_dump_json().encode("utf-8"))
+            self.request_socket.send(encode_msg(CancelSentinel()))
             while True:
                 obj_json = self.request_socket.send()
                 obj = parse_message(obj_json)
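Follow-up note on the last hunk: on GeneratorExit the client now sends the cancel sentinel through the same `encode_msg` path and then drains replies from the worker. The unchanged context line `obj_json = self.request_socket.send()` inside that drain loop reads as though `recv()` was intended, but that line is outside the scope of this change.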