mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 10:54:19 +00:00)
Update types in parallel_utils for meta-reference-gpu impl
commit f19eb8eee3
parent b33086d632

1 changed file with 10 additions and 7 deletions
@@ -34,7 +34,10 @@ from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 from typing_extensions import Annotated
 
-from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
+)
 
 from .generation import TokenResult
 
@@ -79,7 +82,7 @@ class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = (
         ProcessingMessageName.task_request
     )
-    task: Union[CompletionRequest, ChatCompletionRequest]
+    task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
 
 
 class TaskResponse(BaseModel):
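
The Literal-typed field above is what makes TaskRequest routable: every ProcessingMessage variant carries a constant type tag, and pydantic can dispatch on it when decoding a message off the wire. Below is a minimal, self-contained sketch of that discriminated-union pattern, assuming pydantic v2; MessageName, Envelope, and the string-typed task field are illustrative stand-ins rather than the repo's actual types.

from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class MessageName(str, Enum):
    # stand-ins for ProcessingMessageName values
    task_request = "task_request"
    task_response = "task_response"


class TaskRequest(BaseModel):
    # the Literal default doubles as a wire-format discriminator tag
    type: Literal[MessageName.task_request] = MessageName.task_request
    task: str  # the real field is a Union of the *WithRawContent request types


class TaskResponse(BaseModel):
    type: Literal[MessageName.task_response] = MessageName.task_response
    result: str


# Annotated + Field(discriminator=...) tells pydantic to dispatch on "type"
ProcessingMessage = Annotated[
    Union[TaskRequest, TaskResponse], Field(discriminator="type")
]


class Envelope(BaseModel):
    payload: ProcessingMessage


msg = Envelope.model_validate({"payload": {"type": "task_request", "task": "hi"}})
assert isinstance(msg.payload, TaskRequest)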
@@ -264,9 +267,6 @@ def launch_dist_group(
     init_model_cb: Callable,
     **kwargs,
 ) -> None:
-    id = uuid.uuid4().hex
-    dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
-
     with tempfile.TemporaryDirectory() as tmpdir:
         # TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
         launch_config = LaunchConfig(
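
The deleted lines built a file:// rendezvous URL by hand, presumably redundant because torchelastic can manage rendezvous through LaunchConfig itself. A minimal single-node sketch of the elastic_launch flow follows; the worker entrypoint, process count, and rendezvous settings are illustrative values, not the ones launch_dist_group actually passes.

import os

from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def worker(greeting: str) -> None:
    # each spawned rank runs this entrypoint; the real code would invoke
    # init_model_cb here to load its model shard
    print(f"{greeting} from rank {os.environ['RANK']} of {os.environ['WORLD_SIZE']}")


if __name__ == "__main__":
    config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=2,  # e.g. the model-parallel world size
        rdzv_backend="c10d",  # torchelastic handles rendezvous itself
        rdzv_endpoint="localhost:0",  # port 0: pick any free port
        max_restarts=0,
        monitor_interval=1,
    )
    # elastic_launch returns a callable; invoking it blocks until all ranks exit
    elastic_launch(config, worker)("hello")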
@@ -315,7 +315,7 @@ def start_model_parallel_process(
     # wait until the model is loaded; rank 0 will send a message to indicate it's ready
 
     request_socket.send(encode_msg(ReadyRequest()))
-    response = request_socket.recv()
+    _response = request_socket.recv()
     log.info("Loaded model...")
 
     return request_socket, process
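
The only change here renames response to _response: the reply's content is never inspected, and the leading underscore marks the value as intentionally unused for readers and linters. For context, this handshake is a plain ZeroMQ request/reply exchange; here is a minimal sketch assuming pyzmq, with the socket URL and message bytes as stand-ins for the repo's encode_msg wire format.

import zmq


def wait_until_ready(url: str) -> zmq.Socket:
    # client side: ask the worker group whether the model is loaded
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.REQ)
    sock.connect(url)
    sock.send(b"ready?")  # analogous to encode_msg(ReadyRequest())
    _response = sock.recv()  # reply is only a signal; content unused
    return sock


def answer_ready(url: str) -> None:
    # server side: rank 0 binds a REP socket and acknowledges the probe
    sock = zmq.Context.instance().socket(zmq.REP)
    sock.bind(url)
    sock.recv()
    sock.send(b"ready!")

Running answer_ready in one process and wait_until_ready in another against the same URL (for example ipc:///tmp/ready.sock) completes the exchange.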
@@ -349,7 +349,10 @@ class ModelParallelProcessGroup:
         self.started = False
 
     def run_inference(
-        self, req: Union[CompletionRequest, ChatCompletionRequest]
+        self,
+        req: Union[
+            CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
+        ],
     ) -> Generator:
         assert not self.running, "inference already running"
 
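
run_inference still returns a Generator, so callers stream TokenResult objects out of it as they are produced. The sketch below shows one common shape for such a method, refusing concurrent calls and clearing the flag when the generator is exhausted; only the assert line comes from the diff, the rest is an illustrative assumption.

from typing import Generator, Iterable


class InferenceGroup:
    def __init__(self) -> None:
        self.running = False

    def run_inference(self, items: Iterable[str]) -> Generator[str, None, None]:
        assert not self.running, "inference already running"
        self.running = True
        try:
            for item in items:
                yield item  # the real method yields TokenResult objects
        finally:
            # runs even if the consumer abandons the generator early
            self.running = False


group = InferenceGroup()
for tok in group.run_inference(["hello", "world"]):
    print(tok)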