Make all API methods async def again

This commit is contained in:
Ashwin Bharambe 2024-10-18 16:50:57 -07:00
parent 95a96afe34
commit 627edaf407
17 changed files with 120 additions and 145 deletions

View file

@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolPromptFormat,
)
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@ -38,7 +33,11 @@ from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
@ -297,15 +296,11 @@ class Llama:
if all(eos_reached):
break
def text_completion(
def completion(
self,
content: InterleavedTextMedia,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
request: CompletionRequest,
) -> Generator:
sampling_params = request.sampling_params
if (
max_gen_len is None
or max_gen_len == 0
@ -313,26 +308,24 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
model_input = self.formatter.encode_content(request.content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
echo=False,
)
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
request: ChatCompletionRequest,
) -> Generator:
messages = chat_completion_request_to_messages(request)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
@ -343,12 +336,12 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
tool_prompt_format,
request.tool_prompt_format,
),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
)

View file

@ -13,9 +13,6 @@ from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
@ -58,7 +55,18 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
if self.config.create_distributed_process_group:
self.generator.stop()
def completion(
def check_model(self, request) -> None:
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -66,9 +74,19 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
def chat_completion(
request = CompletionRequest(
model=model,
content=content,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -93,16 +111,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stream=stream,
logprobs=logprobs,
)
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
self.check_model(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
@ -111,26 +120,17 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
if request.stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
def impl():
messages = chat_completion_request_to_messages(request)
tokens = []
logprobs = []
stop_reason = None
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
@ -170,8 +170,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
self, request: ChatCompletionRequest
) -> AsyncGenerator:
def impl():
messages = chat_completion_request_to_messages(request)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
@ -184,14 +182,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stop_reason = None
ipython = False
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):

View file

@ -7,16 +7,17 @@
import os
from copy import deepcopy
from functools import partial
from typing import Generator, List, Optional
from typing import Any, Generator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
from .parallel_utils import ModelParallelProcessGroup
class ModelRunner:
@ -24,15 +25,13 @@ class ModelRunner:
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: InferenceArgs):
return self.llama.chat_completion(
task.messages,
task.temperature,
task.top_p,
task.max_gen_len,
task.logprobs,
task.tool_prompt_format,
)
def __call__(self, req: Any):
if isinstance(req, ChatCompletionRequest):
return self.llama.chat_completion(req)
elif isinstance(req, CompletionRequest):
return self.llama.completion(req)
else:
raise ValueError(f"Unexpected task type {type(req)}")
def init_model_cb(config: MetaReferenceInferenceConfig):
@ -77,23 +76,18 @@ class LlamaModelParallelGenerator:
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def chat_completion(
def completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
request: CompletionRequest,
) -> Generator:
req_obj = InferenceArgs(
messages=deepcopy(messages),
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs or False,
tool_prompt_format=tool_prompt_format,
)
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
yield from gen
def chat_completion(
self,
request: ChatCompletionRequest,
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
yield from gen

View file

@ -4,6 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import multiprocessing
import os
@ -11,10 +17,9 @@ import tempfile
import time
import uuid
from enum import Enum
from typing import Callable, Generator, List, Literal, Optional, Union
from typing import Callable, Generator, Literal, Optional, Union
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
@ -23,25 +28,16 @@ from fairscale.nn.model_parallel.initialize import (
get_model_parallel_src_rank,
)
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .generation import TokenResult
class InferenceArgs(BaseModel):
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
tool_prompt_format: ToolPromptFormat
class ProcessingMessageName(str, Enum):
ready_request = "ready_request"
ready_response = "ready_response"
@ -80,7 +76,7 @@ class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
task: InferenceArgs
task: Union[CompletionRequest, ChatCompletionRequest]
class TaskResponse(BaseModel):
@ -349,11 +345,13 @@ class ModelParallelProcessGroup:
self.process.join()
self.started = False
def run_inference(self, inference_args: InferenceArgs) -> Generator:
def run_inference(
self, req: Union[CompletionRequest, ChatCompletionRequest]
) -> Generator:
assert not self.running, "inference already running"
self.running = True
self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
self.request_socket.send(encode_msg(TaskRequest(task=req)))
try:
while True:
obj_json = self.request_socket.recv()