mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Update types in parallel_utils for meta-refernece-gpu impl
This commit is contained in:
parent
b33086d632
commit
f19eb8eee3
1 changed files with 10 additions and 7 deletions
|
@ -34,7 +34,10 @@ from pydantic import BaseModel, Field
|
|||
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .generation import TokenResult
|
||||
|
||||
|
@ -79,7 +82,7 @@ class TaskRequest(BaseModel):
|
|||
type: Literal[ProcessingMessageName.task_request] = (
|
||||
ProcessingMessageName.task_request
|
||||
)
|
||||
task: Union[CompletionRequest, ChatCompletionRequest]
|
||||
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
|
@ -264,9 +267,6 @@ def launch_dist_group(
|
|||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
id = uuid.uuid4().hex
|
||||
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
|
||||
launch_config = LaunchConfig(
|
||||
|
@ -315,7 +315,7 @@ def start_model_parallel_process(
|
|||
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
|
||||
|
||||
request_socket.send(encode_msg(ReadyRequest()))
|
||||
response = request_socket.recv()
|
||||
_response = request_socket.recv()
|
||||
log.info("Loaded model...")
|
||||
|
||||
return request_socket, process
|
||||
|
@ -349,7 +349,10 @@ class ModelParallelProcessGroup:
|
|||
self.started = False
|
||||
|
||||
def run_inference(
|
||||
self, req: Union[CompletionRequest, ChatCompletionRequest]
|
||||
self,
|
||||
req: Union[
|
||||
CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
|
||||
],
|
||||
) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue