Update types in parallel_utils for meta-reference-gpu impl

Ashwin Bharambe 2024-12-19 13:58:20 -08:00
parent b33086d632
commit f19eb8eee3


@@ -34,7 +34,10 @@ from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 from typing_extensions import Annotated
 
-from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
+)
 
 from .generation import TokenResult
@@ -79,7 +82,7 @@ class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = (
         ProcessingMessageName.task_request
     )
-    task: Union[CompletionRequest, ChatCompletionRequest]
+    task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
 
 
 class TaskResponse(BaseModel):
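The task payload rides on the same tagged-message scheme as the other ProcessingMessage types: the literal type field acts as a pydantic discriminator, so the receiving side can decode a frame straight into the right model. A minimal, self-contained sketch of that pattern follows; the MessageName enum and the Ping/Task models are illustrative stand-ins, not the actual ProcessingMessageName values from this file.

    # Sketch of the tagged-union message pattern (hypothetical names, not the real ones).
    from enum import Enum
    from typing import Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter
    from typing_extensions import Annotated


    class MessageName(str, Enum):
        ping = "ping"
        task = "task"


    class Ping(BaseModel):
        type: Literal[MessageName.ping] = MessageName.ping


    class Task(BaseModel):
        type: Literal[MessageName.task] = MessageName.task
        payload: str


    # The `type` field discriminates the union, so decoding picks the right model.
    Message = Annotated[Union[Ping, Task], Field(discriminator="type")]

    msg = TypeAdapter(Message).validate_json('{"type": "task", "payload": "hello"}')
    assert isinstance(msg, Task)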
@@ -264,9 +267,6 @@ def launch_dist_group(
     init_model_cb: Callable,
     **kwargs,
 ) -> None:
-    id = uuid.uuid4().hex
-    dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
-
     with tempfile.TemporaryDirectory() as tmpdir:
         # TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
         launch_config = LaunchConfig(
@@ -315,7 +315,7 @@ def start_model_parallel_process(
     # wait until the model is loaded; rank 0 will send a message to indicate it's ready
 
     request_socket.send(encode_msg(ReadyRequest()))
-    response = request_socket.recv()
+    _response = request_socket.recv()
     log.info("Loaded model...")
 
     return request_socket, process
@@ -349,7 +349,10 @@ class ModelParallelProcessGroup:
         self.started = False
 
     def run_inference(
-        self, req: Union[CompletionRequest, ChatCompletionRequest]
+        self,
+        req: Union[
+            CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
+        ],
     ) -> Generator:
         assert not self.running, "inference already running"
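As a rough caller-side sketch of the new run_inference signature: the process group now only accepts the *WithRawContent request variants, so any conversion from the public ChatCompletionRequest/CompletionRequest types has to happen before the request crosses the process boundary. The stream_tokens helper below is hypothetical and only illustrates the call shape; nothing beyond the run_inference signature comes from this diff.

    # Hypothetical caller-side sketch; only the run_inference signature comes from this change.
    from typing import Generator, Union

    from llama_stack.providers.utils.inference.prompt_adapter import (
        ChatCompletionRequestWithRawContent,
        CompletionRequestWithRawContent,
    )


    def stream_tokens(
        group,  # a started ModelParallelProcessGroup
        req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
    ) -> Generator:
        # run_inference streams generation results (TokenResult) back from the worker ranks.
        yield from group.run_inference(req)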