Make all methods async def again; add completion() for meta-reference (#270)

PR #201 made several changes while trying to fix issues with the stream=False branches of the inference and agents APIs. As part of this, it made one slightly gratuitous change: turning chat_completion() and its brethren into plain "def" methods instead of "async def".

The rationale was that this allowed callers within llama-stack to write:

```
async for chunk in api.chat_completion(params)
```

However, it caused unnecessary confusion for several folks. Given that clients (e.g., llama-stack-apps) use the SDK methods anyway (which are completely isolated), this choice was not ideal. Let's revert so the call now looks like:

```
async for chunk in await api.chat_completion(params)
```
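
To make the distinction concrete: with a plain "def", calling the method already returns the async generator, whereas with "async def" the call returns a coroutine that must be awaited to obtain the generator. A minimal sketch of the two styles (the class and parameter names here are illustrative, not the actual provider code):

```
from typing import AsyncGenerator


class ExampleInferenceApi:
    # Plain "def": the call itself returns the async generator, so callers
    # write `async for chunk in api.chat_completion_old(params)`.
    def chat_completion_old(self, params) -> AsyncGenerator:
        async def gen():
            yield "chunk"

        return gen()

    # "async def": the call returns a coroutine; awaiting it yields the
    # async generator, so callers write
    # `async for chunk in await api.chat_completion(params)`.
    async def chat_completion(self, params) -> AsyncGenerator:
        async def gen():
            yield "chunk"

        return gen()
```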

Bonus: added a completion() implementation for the meta-reference provider. Technically, this should have been another PR :)
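
Under the same calling convention, consuming the new completion() endpoint from a client would look roughly like this (a hedged sketch; the parameter names below are assumptions modeled on chat_completion(), not a verbatim excerpt of the provider):

```
# Illustrative only: "api" is any object exposing the async completion()
# method; the parameter names (model, content, stream) are assumed to
# mirror chat_completion() and may differ from the actual signature.
async def stream_completion(api, model: str, prompt: str):
    async for chunk in await api.completion(
        model=model,
        content=prompt,
        stream=True,
    ):
        print(chunk)
```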
Ashwin Bharambe 2024-10-18 20:50:59 -07:00 committed by GitHub
parent 95a96afe34
commit 2089427d60
23 changed files with 330 additions and 213 deletions

@@ -4,6 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import multiprocessing
import os
@@ -11,10 +17,9 @@ import tempfile
import time
import uuid
from enum import Enum
from typing import Callable, Generator, List, Literal, Optional, Union
from typing import Callable, Generator, Literal, Optional, Union
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
@@ -23,25 +28,16 @@ from fairscale.nn.model_parallel.initialize import (
get_model_parallel_src_rank,
)
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .generation import TokenResult
class InferenceArgs(BaseModel):
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
tool_prompt_format: ToolPromptFormat
class ProcessingMessageName(str, Enum):
ready_request = "ready_request"
ready_response = "ready_response"
@@ -80,7 +76,7 @@ class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
task: InferenceArgs
task: Union[CompletionRequest, ChatCompletionRequest]
class TaskResponse(BaseModel):
@@ -349,11 +345,13 @@ class ModelParallelProcessGroup:
self.process.join()
self.started = False
def run_inference(self, inference_args: InferenceArgs) -> Generator:
def run_inference(
self, req: Union[CompletionRequest, ChatCompletionRequest]
) -> Generator:
assert not self.running, "inference already running"
self.running = True
self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
self.request_socket.send(encode_msg(TaskRequest(task=req)))
try:
while True:
obj_json = self.request_socket.recv()