Add toolchain from agentic system here

This commit is contained in:
Ashwin Bharambe 2024-07-19 12:30:35 -07:00
parent f6b2b2fb39
commit 95781ec85d
71 changed files with 11899 additions and 0 deletions

View file

View file

@ -0,0 +1,2 @@
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,146 @@
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Optional, Union
from hydra.core.config_store import ConfigStore
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@dataclass
class GeneratorArgs:
ckpt_dir: str
tokenizer_path: str
model_parallel_size: Optional[int] = None
max_seq_len: int = 2048
max_batch_size: int = 4
class ImplType(Enum):
inline = "inline"
remote = "remote"
class CheckpointType(Enum):
pytorch = "pytorch"
huggingface = "huggingface"
class PytorchCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
CheckpointType.pytorch.value
)
checkpoint_dir: str
tokenizer_path: str
model_parallel_size: int
class HuggingFaceCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.huggingface.value] = (
CheckpointType.huggingface.value
)
repo_id: str # or model_name ?
model_parallel_size: int
class ModelCheckpointConfig(BaseModel):
checkpoint: Annotated[
Union[PytorchCheckpoint, HuggingFaceCheckpoint],
Field(discriminator="checkpoint_type"),
]
# NOTE: this same config will be used when instantiating an inference server naturally
class InlineImplConfig(BaseModel):
impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
checkpoint_config: ModelCheckpointConfig
max_seq_len: int
max_batch_size: int = 1
class RemoteImplConfig(BaseModel):
impl_type: Literal[ImplType.remote.value] = ImplType.remote.value
url: str = Field(..., description="The URL of the remote module")
class ModelInferenceConfig(BaseModel):
impl_config: Annotated[
Union[InlineImplConfig, RemoteImplConfig],
Field(discriminator="impl_type"),
]
# Hydra does not like unions of containers and
# Pydantic does not like Literals
# Adding a simple dataclass with custom coversion
# to config classes
@dataclass
class InlineImplHydraConfig:
checkpoint_type: str # "pytorch" / "HF"
# pytorch checkpoint required args
checkpoint_dir: str
tokenizer_path: str
model_parallel_size: int
max_seq_len: int
max_batch_size: int = 1
# TODO: huggingface checkpoint required args
def convert_to_inline_impl_config(self):
if self.checkpoint_type == "pytorch":
return InlineImplConfig(
checkpoint_config=ModelCheckpointConfig(
checkpoint=PytorchCheckpoint(
checkpoint_type=CheckpointType.pytorch.value,
checkpoint_dir=self.checkpoint_dir,
tokenizer_path=self.tokenizer_path,
model_parallel_size=self.model_parallel_size,
)
),
max_seq_len=self.max_seq_len,
max_batch_size=self.max_batch_size,
)
else:
raise NotImplementedError("HF Checkpoint not supported yet")
@dataclass
class RemoteImplHydraConfig:
url: str
def convert_to_remote_impl_config(self):
return RemoteImplConfig(
url=self.url,
)
@dataclass
class ModelInferenceHydraConfig:
impl_type: str
inline_config: Optional[InlineImplHydraConfig] = None
remote_config: Optional[RemoteImplHydraConfig] = None
def __post_init__(self):
assert self.impl_type in ["inline", "remote"]
if self.impl_type == "inline":
assert self.inline_config is not None
if self.impl_type == "remote":
assert self.remote_config is not None
def convert_to_model_inferene_config(self):
if self.impl_type == "inline":
inline_config = InlineImplHydraConfig(**self.inline_config)
return ModelInferenceConfig(
impl_config=inline_config.convert_to_inline_impl_config()
)
elif self.impl_type == "remote":
remote_config = RemoteImplHydraConfig(**self.remote_config)
return ModelInferenceConfig(
impl_config=remote_config.convert_to_remote_impl_config()
)
cs = ConfigStore.instance()
cs.store(name="model_inference_config", node=ModelInferenceHydraConfig)

View file

@ -0,0 +1,68 @@
from enum import Enum
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated
from models.llama3.datatypes import * # noqa: F403
class LogProbConfig(BaseModel):
top_k: Optional[int] = 0
@json_schema_type
class QuantizationType(Enum):
bf16 = "bf16"
fp8 = "fp8"
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
quantization_type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
quantization_type: Literal[QuantizationType.bf16.value] = (
QuantizationType.bf16.value
)
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
Field(discriminator="quantization_type"),
]
@json_schema_type
class ChatCompletionResponseEventType(Enum):
start = "start"
complete = "complete"
progress = "progress"
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
@json_schema_type
class ToolCallDelta(BaseModel):
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""Chat completion response event."""
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None

View file

@ -0,0 +1,117 @@
from .datatypes import * # noqa: F403
from typing import Optional, Protocol
# this dependency is annoying and we need a forked up version anyway
from pyopenapi import webmethod
@json_schema_type
class CompletionRequest(BaseModel):
model: PretrainedModel
content: InterleavedTextAttachment
sampling_params: Optional[SamplingParams] = SamplingParams()
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class CompletionResponse(BaseModel):
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
"""streamed completion response."""
delta: str
stop_reason: Optional[StopReason] = None
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class BatchCompletionRequest(BaseModel):
model: PretrainedModel
content_batch: List[InterleavedTextAttachment]
sampling_params: Optional[SamplingParams] = SamplingParams()
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class BatchCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
@json_schema_type
class ChatCompletionRequest(BaseModel):
model: InstructModel
messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
"""SSE-stream of these events."""
event: ChatCompletionResponseEvent
@json_schema_type
class ChatCompletionResponse(BaseModel):
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class BatchChatCompletionRequest(BaseModel):
model: InstructModel
messages_batch: List[List[Message]]
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
class ModelInference(Protocol):
@webmethod(route="/inference/completion")
async def completion(
self,
request: CompletionRequest,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@webmethod(route="/inference/chat_completion")
async def chat_completion(
self,
request: ChatCompletionRequest,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
@webmethod(route="/inference/batch_completion")
async def batch_completion(
self,
request: BatchCompletionRequest,
) -> List[CompletionResponse]: ...
@webmethod(route="/inference/batch_chat_completion")
async def batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> List[ChatCompletionResponse]: ...

View file

@ -0,0 +1,12 @@
from .api.config import ImplType, ModelInferenceConfig
async def get_inference_api_instance(config: ModelInferenceConfig):
if config.impl_config.impl_type == ImplType.inline.value:
from .inference import ModelInferenceImpl
return ModelInferenceImpl(config.impl_config)
from .client import ModelInferenceClient
return ModelInferenceClient(config.impl_config.url)

View file

@ -0,0 +1,73 @@
import asyncio
import json
from typing import AsyncGenerator
import fire
import httpx
from .api.endpoints import (
ChatCompletionRequest,
ChatCompletionResponseStreamChunk,
CompletionRequest,
InstructModel,
ModelInference,
)
class ModelInferenceClient(ModelInference):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
f"{self.base_url}/inference/chat_completion",
data=request.json(),
headers={"Content-Type": "application/json"},
timeout=20,
) as response:
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
yield ChatCompletionResponseStreamChunk(**json.loads(data))
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
async def run_main(host: str, port: int):
client = ModelInferenceClient(f"http://{host}:{port}")
message = UserMessage(content="hello world, help me out here")
req = ChatCompletionRequest(
model=InstructModel.llama3_70b_chat,
messages=[message],
stream=True,
)
async for event in client.chat_completion(
ChatCompletionRequest(
model=InstructModel.llama3_70b_chat,
messages=[message],
stream=True,
)
):
print(event)
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,298 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Optional, TypedDict
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from models.llama3.args import ModelArgs
from models.llama3.chat_format import ChatFormat, ModelInput
from models.llama3.datatypes import Message
from models.llama3.model import Transformer
from models.llama3.tokenizer import Tokenizer
from termcolor import cprint
@dataclass
class TokenResult:
token: int
text: str
logprobs: Optional[List[float]] = None
class CompletionPrediction(TypedDict, total=False):
generation: str
tokens: List[str] # not required
logprobs: List[float] # not required
class Llama:
@staticmethod
def build(
ckpt_dir: str,
tokenizer_path: str,
max_seq_len: int,
max_batch_size: int,
model_parallel_size: Optional[int] = None,
seed: int = 1,
) -> "Llama":
"""
Build a Llama instance by initializing and loading a model checkpoint.
Args:
ckpt_dir (str): Path to the directory containing checkpoint files.
tokenizer_path (str): Path to the tokenizer file.
max_seq_len (int): Maximum sequence length for input text.
max_batch_size (int): Maximum batch size for inference.
model_parallel_size (Optional[int], optional): Number of model parallel processes.
If not provided, it's determined from the environment. Defaults to None.
Returns:
Llama: An instance of the Llama class with the loaded model and tokenizer.
Raises:
AssertionError: If there are no checkpoint files in the specified directory,
or if the model parallel size does not match the number of checkpoint files.
Note:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
# TODO(ashwin): this block is so we can load internal checkpoints without additional
# fuss. the final code should _not_ have this blurb
if "model" in params:
params = params["model"]
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
@torch.inference_mode()
def generate(
self,
model_input: ModelInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
include_stop_token: bool = False,
) -> Generator:
params = self.model.params
# cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
prompt_tokens = [model_input.tokens]
bsz = 1
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
)
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
logits = self.model.forward(tokens, prev_pos)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
if logprobs
else None
),
)
prev_pos = cur_pos
if all(eos_reached):
break
def text_completion(
self,
prompt: str,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
max_gen_len = self.model.params.max_seq_len - 1
prompt_tokens = self.tokenizer.encode(x, bos=True, eos=False)
yield from self.generate(
model_input=ModelInput(tokens=prompt_tokens),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
)
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
) -> Generator:
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
max_gen_len = self.model.params.max_seq_len - 1
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(messages),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
include_stop_token=True,
)
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

View file

@ -0,0 +1,173 @@
from typing import AsyncGenerator
from models.llama3.datatypes import StopReason
from .api.config import CheckpointType, GeneratorArgs, InlineImplConfig
from .api.datatypes import (
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ToolCallDelta,
ToolCallParseStatus,
)
from .api.endpoints import (
ChatCompletionRequest,
ChatCompletionResponseStreamChunk,
CompletionRequest,
ModelInference,
)
from .model_parallel import LlamaModelParallelGenerator
def generator_args_from_config(config: InlineImplConfig) -> GeneratorArgs:
if (
config.checkpoint_config.checkpoint.checkpoint_type
== CheckpointType.pytorch.value
):
pt_checkpoint = config.checkpoint_config.checkpoint
return GeneratorArgs(
ckpt_dir=pt_checkpoint.checkpoint_dir,
tokenizer_path=pt_checkpoint.tokenizer_path,
model_parallel_size=pt_checkpoint.model_parallel_size,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
)
else:
raise NotImplementedError("HF Checkpoint not supported yet")
class ModelInferenceImpl(ModelInference):
def __init__(self, config: InlineImplConfig) -> None:
self.config = config
async def initialize(self) -> None:
generator_args = generator_args_from_config(self.config)
self.generator = LlamaModelParallelGenerator(
args=generator_args,
)
self.generator.start()
async def shutdown(self) -> None:
self.generator.stop()
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
tokens = []
logprobs = []
stop_reason = None
buffer = ""
ipython = False
for token_result in self.generator.chat_completion(
messages=request.messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
):
buffer += token_result.text
tokens.append(token_result.token)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if not request.stream:
if request.logprobs:
logprobs.append(token_result.logprob)
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# TODO(ashwin): parse tool calls separately here and report errors?
# if someone breaks the iteration before coming here we are toast
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
if request.stream:
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
content=message.content,
tool_calls=message.tool_calls,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)

View file

@ -0,0 +1,100 @@
from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
from models.llama3.chat_format import ChatFormat
from models.llama3.datatypes import Message
from models.llama3.tokenizer import Tokenizer
from .api.config import GeneratorArgs
from .generation import Llama
from .parallel_utils import ModelParallelProcessGroup
@dataclass
class InferenceArgs:
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
class ModelRunner:
def __init__(self, llama):
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: InferenceArgs):
return self.llama.chat_completion(
task.messages,
task.temperature,
task.top_p,
task.max_gen_len,
task.logprobs,
)
def init_model_cb(args: GeneratorArgs):
llama = Llama.build(
args.ckpt_dir,
args.tokenizer_path,
args.max_seq_len,
args.max_batch_size,
)
return ModelRunner(llama)
class LlamaModelParallelGenerator:
"""
This abstraction exists so
- we can run model parallel code without needing to run the CLIs via torchrun
- this also enables use model parallel code within a notebook context.
A Context Manager is used to ensure that the model parallel process is started and stopped
correctly. This does make the ergonomics a little awkward, because it isn't immediately
clear at the callsite why we need to use a context manager.
"""
def __init__(self, args: GeneratorArgs):
self.args = args
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
# while the tool-use loop is going
self.formatter = ChatFormat(Tokenizer(self.args.tokenizer_path))
def start(self):
self.__enter__()
def stop(self):
self.__exit__(None, None, None)
def __enter__(self):
self.group = ModelParallelProcessGroup(
self.args.model_parallel_size,
init_model_cb=partial(init_model_cb, self.args),
)
self.group.start()
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
) -> Generator:
req_obj = InferenceArgs(
messages=messages,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
)
gen = self.group.run_inference(req_obj)
yield from gen

View file

@ -0,0 +1,259 @@
import multiprocessing
import os
import pickle
import tempfile
import time
import uuid
from typing import Callable, Generator
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_group,
get_model_parallel_rank,
get_model_parallel_src_rank,
)
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
_END_SENTINEL = "__end_sentinel__"
_CANCEL_SENTINEL = "__cancel_sentinel__"
def mp_rank_0() -> bool:
return get_model_parallel_rank() == 0
def retrieve_requests(reply_socket_url: str):
if mp_rank_0():
context = zmq.Context()
reply_socket = context.socket(zmq.ROUTER)
reply_socket.connect(reply_socket_url)
while True:
client_id, obj = maybe_get_work(reply_socket)
if obj is None:
time.sleep(0.01)
continue
reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
break
def send_obj(obj):
reply_socket.send_multipart([client_id, pickle.dumps(obj)])
while True:
tasks = [None]
if mp_rank_0():
client_id, task = maybe_get_work(reply_socket)
# there is still an unknown unclean GeneratorExit happening resulting in a
# cancel sentinel getting queued _after_ we have finished sending everything :/
# kind of a hack this is :/
if task != _CANCEL_SENTINEL:
tasks = [task]
torch.distributed.broadcast_object_list(
tasks,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
task = tasks[0]
if task is None:
time.sleep(0.1)
else:
try:
out = yield task
if out is None:
break
for obj in out:
updates = [None]
if mp_rank_0():
_, update = maybe_get_work(reply_socket)
if update == _CANCEL_SENTINEL:
updates = [update]
else:
# only send the update if it's not cancelled otherwise the object sits in the socket
# and gets pulled in the next request lol
send_obj(obj)
torch.distributed.broadcast_object_list(
updates,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
if updates[0] == _CANCEL_SENTINEL:
print("quitting generation loop because request was cancelled")
break
if mp_rank_0():
send_obj(_END_SENTINEL)
except Exception as e:
print(f"[debug] got exception {e}")
import traceback
traceback.print_exc()
if mp_rank_0():
send_obj(e)
if mp_rank_0():
send_obj("DONE")
def maybe_get_work(sock: zmq.Socket):
message = None
client_id = None
try:
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
message = pickle.loads(obj)
except zmq.ZMQError as e:
if e.errno != zmq.EAGAIN:
raise e
return client_id, message
def worker_process_entrypoint(
reply_socket_url: str,
init_model_cb: Callable,
) -> None:
model = init_model_cb()
torch.distributed.barrier()
time.sleep(1)
# run the requests co-routine which retrieves requests from the socket
# and sends responses (we provide) back to the caller
req_gen = retrieve_requests(reply_socket_url)
result = None
while True:
try:
task = req_gen.send(result)
if isinstance(task, str) and task == _END_SENTINEL:
break
result = model(task)
except StopIteration:
break
print("[debug] worker process done")
def launch_dist_group(
reply_socket_url: str,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
) -> None:
id = uuid.uuid4().hex
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
with tempfile.TemporaryDirectory() as tmpdir:
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
launch_config = LaunchConfig(
max_nodes=1,
min_nodes=1,
nproc_per_node=model_parallel_size,
start_method="fork",
rdzv_backend="c10d",
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
rdzv_configs={"store_type": "file", "timeout": 90},
max_restarts=0,
monitor_interval=1,
run_id=str(uuid.uuid4()),
)
elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
reply_socket_url,
init_model_cb,
)
def start_model_parallel_process(
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
context = zmq.Context()
request_socket = context.socket(zmq.DEALER)
# Binding the request socket to a random port
request_socket.bind("tcp://127.0.0.1:0")
main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
ctx = multiprocessing.get_context("fork")
process = ctx.Process(
target=launch_dist_group,
args=(
main_process_url,
model_parallel_size,
init_model_cb,
),
kwargs=kwargs,
)
process.start()
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
request_socket.send_pyobj("READY?")
response = request_socket.recv_pyobj()
print(f"Finished model load {response}")
return request_socket, process
class ModelParallelProcessGroup:
def __init__(
self,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
self.model_parallel_size = model_parallel_size
self.init_model_cb = init_model_cb
self.started = False
self.running = False
def start(self):
assert not self.started, "process group already started"
self.request_socket, self.process = start_model_parallel_process(
self.model_parallel_size,
self.init_model_cb,
)
self.started = True
def stop(self):
assert self.started, "process group not started"
if self.process.is_alive():
self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
self.process.join()
self.started = False
def run_inference(self, request) -> Generator:
assert not self.running, "inference already running"
self.running = True
self.request_socket.send_pyobj(request)
try:
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
break
if isinstance(obj, Exception):
print(f"[debug] got exception {obj}")
raise obj
yield obj
except GeneratorExit as e:
self.request_socket.send_pyobj(_CANCEL_SENTINEL)
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
break
finally:
self.running = False

View file

@ -0,0 +1,45 @@
#!/bin/bash
if [[ $# -ne 1 ]]; then
echo "Error: Please provide the name of CONDA environment you wish to create"
exit 1
fi
ENV_NAME=$1
set -eu
eval "$(conda shell.bash hook)"
echo "Will build env (or overwrite) named '$ENV_NAME'"
set -x
run_build() {
# Set CUDA 9.0a targets
export CUDA_ARCH_LIST="8.0;9.0a"
export NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a"
export TORCH_CUDA_ARCH_LIST=$CUDA_ARCH_LIST
# Set up the conda environment
yes | conda remove --name $ENV_NAME --all
yes | conda create -n $ENV_NAME python=3.10
conda activate $ENV_NAME
yes | conda install --channel "nvidia/label/cuda-12.1.0" cuda
yes | conda install cuda-nvtx cuda-nvtx-dev conda-forge::nccl
# ############# Hack to get CUDA path #############
ln -s $CONDA_PREFIX/targets/x86_64-linux/include/* $CONDA_PREFIX/include/ || true
export CUDA_HOME=$CONDA_PREFIX
export CUDA_BIN_PATH=$CUDA_HOME
# #################################################
# PT nightly
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu121
# install dependencies for `llama-agentic-system`
pip install -r fp8_requirements.txt
}
run_build

View file

@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import collections
from enum import Enum, unique
from typing import Optional, Type
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
print("Using efficient FP8 operators in FBGEMM.")
except (ImportError, ModuleNotFoundError):
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
import torch
from torch import nn, Tensor
@unique
class FfnQuantizeMode(Enum):
FP8_ROWWISE = "fp8_rowwise"
NONE = "none"
def __str__(self) -> str:
return self.value
class Fp8ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Fp8Weights instance. Do this properly.
@property
def __class__(self) -> Type[nn.parameter.Parameter]:
return nn.Parameter
@property
def grad_fn(self) -> None:
return None
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Fp8RowwiseWeights(
Fp8ScaledWeights,
collections.namedtuple(
"Fp8RowwiseWeights",
["weight", "scale", "shape", "activation_scale_ub"],
),
):
pass
def ffn_swiglu(
x: Tensor,
w1: Fp8RowwiseWeights,
w3: Fp8RowwiseWeights,
w2: Fp8RowwiseWeights,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
if (
isinstance(w1, Fp8ScaledWeights)
and isinstance(w3, Fp8ScaledWeights)
and isinstance(w2, Fp8ScaledWeights)
):
return ffn_swiglu_fp8_dynamic(
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
)
(B, T, D) = x.shape
(HD_L, D_) = w1.shape
assert D_ == D
assert isinstance(w1, Tensor)
assert isinstance(w3, Tensor)
x1 = x.view(B * T, D) @ w1.T
x2 = x.view(B * T, D) @ w3.T
z = torch.nn.functional.silu(x1) * x2
del x1, x2
assert isinstance(w2, Tensor)
return (z @ w2.T).view(B, T, D)
@torch.inference_mode()
def quantize_fp8(
w: Tensor,
fp8_activation_scale_ub: float,
mode: Optional[FfnQuantizeMode] = None,
output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
"""Quantize [n, k] weight tensor.
Args:
w (Tensor): [n, k] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
mode (FfnQuantizeMode): Quantization mode.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device="cuda",
)
if mode is not None and mode == FfnQuantizeMode.FP8_ROWWISE: # rowwise
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
del w
return Fp8RowwiseWeights(
weight=wq,
scale=w_scale,
shape=wq.shape,
activation_scale_ub=activation_scale_ub,
)
def fc_fp8_dynamic(
x: Tensor,
w: Fp8RowwiseWeights,
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
"""
Single w8a8 fc layer with dynamic row-wise scaling.
"""
if isinstance(w, Fp8RowwiseWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x, num_tokens, activation_scale_ub
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(
xq, w.weight, x_scale, w.scale, use_fast_accum=True
)
del xq
return y
def ffn_swiglu_fp8_dynamic(
x: Tensor,
w1: Fp8RowwiseWeights,
w3: Fp8RowwiseWeights,
w2: Fp8RowwiseWeights,
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
(B, T, D) = x.shape
HD_L = w1.shape[0]
assert HD_L == w3.shape[0]
x1 = fc_fp8_dynamic(
x.view(B * T, D),
w1,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
x2 = fc_fp8_dynamic(
x.view(B * T, D),
w3,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
z = torch.nn.functional.silu(x1) * x2
del x1, x2
z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
return z_.view(B, T, D)

View file

@ -0,0 +1,5 @@
fairscale
fire
tiktoken
blobfile
fbgemm-gpu==0.8.0rc4

View file

@ -0,0 +1,455 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple, TypedDict
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from fp8.fp8_impls import (
FfnQuantizeMode,
Fp8ScaledWeights,
load_fp8,
ModelLoadMode,
quantize_fp8,
)
from llama.model import ModelArgs, Transformer, TransformerBlock
from llama.tokenizer import ChatFormat, Dialog, Message, ModelInput, Tokenizer
class CompletionPrediction(TypedDict, total=False):
generation: str
tokens: List[str] # not required
logprobs: List[float] # not required
class ChatPrediction(TypedDict, total=False):
generation: Message
tokens: List[str] # not required
logprobs: List[float] # not required
class Llama:
@staticmethod
def build(
ckpt_dir: str,
tokenizer_path: str,
max_seq_len: int,
max_batch_size: int,
model_parallel_size: Optional[int] = None,
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.NONE,
model_load_mode: Optional[ModelLoadMode] = ModelLoadMode.BF16,
fp8_activation_scale_ub: Optional[float] = 1200.0,
seed: int = 1,
) -> "Llama":
"""
Build a Llama instance by initializing and loading a model checkpoint.
Args:
ckpt_dir (str): Path to the directory containing checkpoint files.
tokenizer_path (str): Path to the tokenizer file.
max_seq_len (int): Maximum sequence length for input text.
max_batch_size (int): Maximum batch size for inference.
model_parallel_size (Optional[int], optional): Number of model parallel processes.
If not provided, it's determined from the environment. Defaults to None.
Returns:
Llama: An instance of the Llama class with the loaded model and tokenizer.
Raises:
AssertionError: If there are no checkpoint files in the specified directory,
or if the model parallel size does not match the number of checkpoint files.
Note:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
print("ffn_quantize_mode: ", ffn_quantize_mode)
if ffn_quantize_mode == FfnQuantizeMode.FP8_ROWWISE:
# Move weights to GPU with quantization
if model_load_mode == ModelLoadMode.FP8:
fp8_scales_path = os.path.join(
ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (
model.n_layers - 1
):
continue
block.feed_forward.w1.weight = load_fp8(
block.feed_forward.w1.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
block.feed_forward.w3.weight = load_fp8(
block.feed_forward.w3.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
block.feed_forward.w2.weight = load_fp8(
block.feed_forward.w2.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
else:
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (
model.n_layers - 1
):
continue
block.feed_forward.w1.weight = quantize_fp8(
block.feed_forward.w1.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cuda"),
)
block.feed_forward.w3.weight = quantize_fp8(
block.feed_forward.w3.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cuda"),
)
block.feed_forward.w2.weight = quantize_fp8(
block.feed_forward.w2.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cuda"),
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
else:
for _, parameter in model.named_parameters():
parameter.data = parameter.to(device="cuda")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
@torch.inference_mode()
def generate(
self,
model_inputs: List[ModelInput],
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
include_stop_token: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
"""
Generate text sequences based on provided prompts using the language generation model.
Args:
prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
max_gen_len (int): Maximum length of the generated text sequence.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
Returns:
Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
Note:
This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
If logprobs is True, token log probabilities are computed for each generated token.
"""
params = self.model.params
prompt_tokens = [m.tokens for m in model_inputs]
bsz = len(prompt_tokens)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
assert max_prompt_len <= params.max_seq_len
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
logits = self.model.forward(tokens, prev_pos)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)
stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))
for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
prev_pos = cur_pos
if all(eos_reached):
break
if logprobs:
token_logprobs = token_logprobs.tolist()
out_tokens, out_logprobs = [], []
for i, toks in enumerate(tokens.tolist()):
# cut to max gen len
start = 0 if echo else len(prompt_tokens[i])
toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
probs = None
if logprobs:
probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
# cut to after eos tok if any
for stop_token in self.tokenizer.stop_tokens:
try:
eos_idx = toks.index(stop_token)
if include_stop_token:
eos_idx += 1
toks = toks[:eos_idx]
probs = probs[:eos_idx] if logprobs else None
except ValueError:
pass
out_tokens.append(toks)
out_logprobs.append(probs)
return (out_tokens, out_logprobs if logprobs else None)
def text_completion(
self,
prompts: List[str],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> List[CompletionPrediction]:
"""
Perform text completion for a list of prompts using the language generation model.
Args:
prompts (List[str]): List of text prompts for completion.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
If not provided, it's set to the model's maximum sequence length minus 1.
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
Returns:
List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.
Note:
This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
If logprobs is True, token log probabilities are computed for each generated token.
"""
if max_gen_len is None:
max_gen_len = self.model.params.max_seq_len - 1
prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
generation_tokens, generation_logprobs = self.generate(
model_inputs=[ModelInput(tokens=pt) for pt in prompt_tokens],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
)
if logprobs:
return [
{
"generation": self.tokenizer.decode(t),
"tokens": [self.tokenizer.decode([x]) for x in t],
"logprobs": logprobs_i,
}
for t, logprobs_i in zip(generation_tokens, generation_logprobs)
]
return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
def chat_completion(
self,
dialogs: List[Dialog],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
) -> List[ChatPrediction]:
"""
Generate assistant responses for a list of conversational dialogs using the language generation model.
Args:
dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
If not provided, it's set to the model's maximum sequence length minus 1.
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
Returns:
List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.
Note:
This method generates assistant responses for the provided conversational dialogs.
It employs nucleus sampling to introduce controlled randomness in text generation.
If logprobs is True, token log probabilities are computed for each generated token.
"""
if max_gen_len is None:
max_gen_len = self.model.params.max_seq_len - 1
model_inputs = [
self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
]
generation_tokens, generation_logprobs = self.generate(
model_inputs=model_inputs,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
include_stop_token=True,
)
if logprobs:
return [
{
"generation": self.formatter.decode_assistant_message(t),
"tokens": [self.tokenizer.decode([x]) for x in t],
"logprobs": logprobs_i,
}
for t, logprobs_i in zip(generation_tokens, generation_logprobs)
]
return [
{
"generation": self.formatter.decode_assistant_message(t),
}
for t in generation_tokens
]
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

View file

@ -0,0 +1,355 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
)
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from fp8.fp8_impls import ffn_swiglu
from torch import nn
@dataclass
class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
rope_theta: float = 500000
use_scaled_rope: bool = False
max_batch_size: int = 32
max_seq_len: int = 2048
def __init__(self, **kwargs):
for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
assert self.n_kv_heads <= self.n_heads
assert self.n_heads % self.n_kv_heads == 0
assert self.dim % self.n_heads == 0
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def apply_scaling(freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
def precompute_freqs_cis(
dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(
keys, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(
values, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(
1, 2
) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(
dim,
hidden_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.w3 = ColumnParallelLinear(
dim,
hidden_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.w2 = RowParallelLinear(
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
)
def forward(self, x):
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
return reduce_from_model_parallel_region(out)
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.tok_embeddings = VocabParallelEmbedding(
params.vocab_size, params.dim, init_method=lambda x: x
)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = ColumnParallelLinear(
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
)
self.freqs_cis = precompute_freqs_cis(
params.dim // params.n_heads,
params.max_seq_len * 2,
params.rope_theta,
params.use_scaled_rope,
)
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack(
[torch.zeros((seqlen, start_pos), device=tokens.device), mask]
).type_as(h)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h).float()
return output

View file

@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import shutil
import sys
from pathlib import Path
from typing import Optional
import fire
import torch
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
from llama.model import ModelArgs, Transformer, TransformerBlock
from llama.tokenizer import Tokenizer
from torch.nn.parameter import Parameter
def main(
ckpt_dir: str,
tokenizer_path: str,
quantized_ckpt_dir: str,
max_seq_len: Optional[int] = 512,
max_batch_size: Optional[int] = 4,
model_parallel_size: Optional[int] = None,
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
fp8_activation_scale_ub: Optional[float] = 1200.0,
seed: int = 1,
):
""" """
if not os.path.exists(quantized_ckpt_dir):
os.makedirs(quantized_ckpt_dir)
shutil.copy(
os.path.join(ckpt_dir, "params.json"),
os.path.join(quantized_ckpt_dir, "params.json"),
)
shutil.copy(
os.path.join(ckpt_dir, "tokenizer.model"),
os.path.join(quantized_ckpt_dir, "tokenizer.model"),
)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
print(ckpt_path)
assert (
quantized_ckpt_dir is not None
), "QUantized checkpoint directory should not be None"
fp8_scales = {}
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
fp8_weight = quantize_fp8(
block.feed_forward.w1.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w3.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w2.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales_path = os.path.join(
quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
torch.save(fp8_scales, fp8_scales_path)
ckpt_path = os.path.join(
quantized_ckpt_dir,
"consolidated.{:02d}.pth".format(get_model_parallel_rank()),
)
torch.save(model.state_dict(), ckpt_path)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,25 @@
#!/bin/bash
set -euo pipefail
set -x
cd $(git rev-parse --show-toplevel)
MASTER_HOST=$1
RUN_ID=$2
CKPT_DIR=$3
QUANT_CKPT_DIR=$4
TOKENIZER_PATH=$5
NNODES=$6
NPROC=$7
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
torchrun \
--nnodes=$NNODES --nproc_per_node=$NPROC \
--rdzv_id=$RUN_ID \
--rdzv_conf='timeout=120' \
--rdzv_backend=c10d \
--rdzv_endpoint="${MASTER_HOST}:29502" \
quantize_checkpoint.py $CKPT_DIR $TOKENIZER_PATH $QUANT_CKPT_DIR

View file

@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import unittest
import torch
from fp8_impls import attn_linear, ffn_swiglu_fp8_dynamic, quantize_fp8
from hypothesis import given, settings, strategies as st
from torch import Tensor
@unittest.skipIf(
not torch.cuda.is_available()
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
"Skip when H100 is not available",
)
class FP8Tests(unittest.TestCase):
@settings(deadline=None)
@given(
D=st.sampled_from([4096, 8192]),
HD_L=st.sampled_from([1280, 2560]),
B=st.sampled_from([1, 2]),
T=st.sampled_from([2048, 4096]),
UB=st.sampled_from([1000, 10000]),
)
def test_fp8_ffn(
self,
D: int,
HD_L: int,
B: int,
T: int,
UB: float,
) -> None:
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
w13 = (
torch.randn(size=(2 * HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
)
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
x_q = quantize_fp8(x, UB)
w13_q = quantize_fp8(w13, UB)
w2_q = quantize_fp8(w2, UB)
def ref_ffn(x: Tensor, w13: Tensor, w2: Tensor) -> Tensor:
(B, T, D) = x.shape
(HD_L_2, D_) = w13.shape
assert D_ == D
HD_L = HD_L_2 // 2
y = x.view(B * T, D) @ w13.T
x1 = y[:, :HD_L]
x2 = y[:, HD_L:]
z = torch.nn.functional.silu(x1) * x2
return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
v = ffn_swiglu_fp8_dynamic(x, w13_q, w2_q)
# Fake quant
x = x_q.weight.bfloat16() * x_q.scale
w13 = w13_q.weight.bfloat16() * w13_q.scale
w2 = w2_q.weight.bfloat16() * w2_q.scale
v_ref = ref_ffn(x, w13, w2)
torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
@settings(deadline=None)
@given(
B_T=st.sampled_from([2048, 4096]),
D=st.sampled_from([128, 256]),
HD_L=st.sampled_from([256, 512]),
UB=st.sampled_from([1000, 10000]),
)
def test_fp8_attn_linear(self, B_T: int, D: int, HD_L: int, UB: int) -> None:
B_T = 4096
D = 256
HD_L = 512
UB = float(UB)
x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
wqkv = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
x_q = quantize_fp8(x, UB)
wqkv_q = quantize_fp8(wqkv, UB)
num_tokens = torch.tensor(B_T, dtype=torch.int64, device="cuda")
y = attn_linear(x, wqkv_q)
y_nt = attn_linear(x, wqkv_q, num_tokens=num_tokens)
# Fake quant
x = x_q.weight.bfloat16() * x_q.scale
wqkv = wqkv_q.weight.bfloat16() * wqkv_q.scale
y_ref = (x @ wqkv.T).to(torch.bfloat16)
torch.testing.assert_close(y_ref, y, atol=1.0e-3, rtol=1.0e-3)
torch.testing.assert_close(y_ref, y_nt, atol=1.0e-3, rtol=1.0e-3)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,117 @@
import asyncio
import signal
import fire
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from omegaconf import OmegaConf
from toolchain.utils import get_config_dir, parse_config
from .api.config import ModelInferenceHydraConfig
from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk
from .api_instance import get_inference_api_instance
load_dotenv()
GLOBAL_CONFIG = None
def get_config():
return GLOBAL_CONFIG
def handle_sigint(*args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully", args)
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
app = FastAPI()
@app.on_event("startup")
async def startup():
global InferenceApiInstance
config = get_config()
hydra_config = ModelInferenceHydraConfig(
**OmegaConf.to_container(config["model_inference_config"], resolve=True)
)
model_inference_config = hydra_config.convert_to_model_inferene_config()
InferenceApiInstance = await get_inference_api_instance(
model_inference_config,
)
await InferenceApiInstance.initialize()
@app.on_event("shutdown")
async def shutdown():
global InferenceApiInstance
print("shutting down")
await InferenceApiInstance.shutdown()
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
semaphore = asyncio.Semaphore(1)
@app.post(
"/inference/chat_completion", response_model=ChatCompletionResponseStreamChunk
)
def chat_completion(request: Request, exec_request: ChatCompletionRequest):
if semaphore.locked():
raise HTTPException(
status_code=429,
detail="Only a single concurrent request allowed right now.",
)
async def sse_generator(event_gen):
try:
async for event in event_gen:
yield f"data: {event.json()}\n\n"
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
finally:
semaphore.release()
async def event_gen():
async for event in InferenceApiInstance.chat_completion(exec_request):
yield event
return StreamingResponse(
sse_generator(event_gen()),
media_type="text/event-stream",
)
def main(config_path: str, port: int = 5000, disable_ipv6: bool = False):
global GLOBAL_CONFIG
config_dir = get_config_dir()
GLOBAL_CONFIG = parse_config(config_dir, config_path)
signal.signal(signal.SIGINT, handle_sigint)
import uvicorn
# FYI this does not do hot-reloads
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
uvicorn.run(app, host=listen_host, port=port)
if __name__ == "__main__":
fire.Fire(main)