forked from phoenix-oss/llama-stack-mirror
API Updates (#73)
* API Keys passed from Client instead of distro configuration * delete distribution registry * Rename the "package" word away * Introduce a "Router" layer for providers Some providers need to be factorized and considered as thin routing layers on top of other providers. Consider two examples: - The inference API should be a routing layer over inference providers, routed using the "model" key - The memory banks API is another instance where various memory bank types will be provided by independent providers (e.g., a vector store is served by Chroma while a keyvalue memory can be served by Redis or PGVector) This commit introduces a generalized routing layer for this purpose. * update `apis_to_serve` * llama_toolchain -> llama_stack * Codemod from llama_toolchain -> llama_stack - added providers/registry - cleaned up api/ subdirectories and moved impls away - restructured api/api.py - from llama_stack.apis.<api> import foo should work now - update imports to do llama_stack.apis.<api> - update many other imports - added __init__, fixed some registry imports - updated registry imports - create_agentic_system -> create_agent - AgenticSystem -> Agent * Moved some stuff out of common/; re-generated OpenAPI spec * llama-toolchain -> llama-stack (hyphens) * add control plane API * add redis adapter + sqlite provider * move core -> distribution * Some more toolchain -> stack changes * small naming shenanigans * Removing custom tool and agent utilities and moving them client side * Move control plane to distribution server for now * Remove control plane from API list * no codeshield dependency randomly plzzzzz * Add "fire" as a dependency * add back event loggers * stack configure fixes * use brave instead of bing in the example client * add init file so it gets packaged * add init files so it gets packaged * Update MANIFEST * bug fix --------- Co-authored-by: Hardik Shah <hjshah@fb.com> Co-authored-by: Xi Yan <xiyan@meta.com> Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
parent
f294eac5f5
commit
9487ad8294
213 changed files with 1725 additions and 1204 deletions
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import MetaReferenceImplConfig # noqa
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
assert isinstance(
|
||||
config, MetaReferenceImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.datatypes import ModelFamily
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from llama_models.sku_list import all_registered_models, resolve_model
|
||||
|
||||
from llama_stack.apis.inference import QuantizationConfig
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetaReferenceImplConfig(BaseModel):
|
||||
model: str = Field(
|
||||
default="Meta-Llama3.1-8B-Instruct",
|
||||
description="Model descriptor from `llama model list`",
|
||||
)
|
||||
quantization: Optional[QuantizationConfig] = None
|
||||
torch_seed: Optional[int] = None
|
||||
max_seq_len: int
|
||||
max_batch_size: int = 1
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = [
|
||||
m.descriptor()
|
||||
for m in all_registered_models()
|
||||
if m.model_family == ModelFamily.llama3_1
|
||||
]
|
||||
if model not in permitted_models:
|
||||
model_list = "\n\t".join(permitted_models)
|
||||
raise ValueError(
|
||||
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def model_parallel_size(self) -> int:
|
||||
# HUGE HACK ALERT: this will be fixed when we move inference configuration
|
||||
# to ModelsRegistry and we can explicitly ask for `model_parallel_size`
|
||||
# as configuration there
|
||||
gpu_count = 1
|
||||
resolved = resolve_model(self.model)
|
||||
assert resolved is not None
|
||||
descriptor = resolved.descriptor().lower()
|
||||
if "-70b" in descriptor or "-405b" in descriptor:
|
||||
gpu_count = 8
|
||||
|
||||
return gpu_count
|
|
@ -0,0 +1,327 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_rank,
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
|
||||
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3.reference_impl.model import Transformer
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from termcolor import cprint
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
|
||||
|
||||
def model_checkpoint_dir(model) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model.descriptor()))
|
||||
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoint dir: {checkpoint_dir}."
|
||||
f"Please download model using `llama download {model.descriptor()}`"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenResult:
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
|
||||
class Llama:
|
||||
@staticmethod
|
||||
def build(config: MetaReferenceImplConfig):
|
||||
"""
|
||||
Build a Llama instance by initializing and loading a model checkpoint.
|
||||
|
||||
Note:
|
||||
This method initializes the distributed process group, sets the device to CUDA,
|
||||
and loads the pre-trained model and tokenizer.
|
||||
"""
|
||||
model = resolve_model(config.model)
|
||||
|
||||
if (
|
||||
config.quantization
|
||||
and config.quantization.type == QuantizationType.fp8.value
|
||||
):
|
||||
from .quantization.loader import is_fbgemm_available
|
||||
|
||||
if not is_fbgemm_available():
|
||||
raise ImportError("fbgemm-gpu is required for FP8 quantization")
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
|
||||
model_parallel_size = config.model_parallel_size
|
||||
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(model_parallel_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# seed must be the same in all processes
|
||||
if config.torch_seed is not None:
|
||||
torch.manual_seed(config.torch_seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
ckpt_dir = model_checkpoint_dir(model)
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert model_parallel_size == len(
|
||||
checkpoints
|
||||
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
if "model" in params:
|
||||
params = params["model"]
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
**params,
|
||||
)
|
||||
|
||||
tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
|
||||
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||
|
||||
assert (
|
||||
model_args.vocab_size == tokenizer.n_words
|
||||
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
|
||||
fp8 = (
|
||||
config.quantization
|
||||
and config.quantization.type == QuantizationType.fp8.value
|
||||
)
|
||||
|
||||
if fp8:
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_quantized_model(model, config)
|
||||
else:
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
return Llama(model, tokenizer, model_args)
|
||||
|
||||
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
model_input: ModelInput,
|
||||
max_gen_len: int,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
include_stop_token: bool = False,
|
||||
) -> Generator:
|
||||
params = self.model.params
|
||||
|
||||
# cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
|
||||
prompt_tokens = [model_input.tokens]
|
||||
|
||||
bsz = 1
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
cprint(
|
||||
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
|
||||
)
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
||||
if logprobs:
|
||||
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
||||
|
||||
prev_pos = 0
|
||||
eos_reached = torch.tensor([False] * bsz, device="cuda")
|
||||
input_text_mask = tokens != pad_id
|
||||
if min_prompt_len == total_len:
|
||||
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
|
||||
logits = self.model.forward(tokens, prev_pos)
|
||||
token_logprobs = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=tokens,
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
|
||||
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||
|
||||
next_token = next_token.reshape(-1)
|
||||
# only replace token if prompt has already been generated
|
||||
next_token = torch.where(
|
||||
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
|
||||
)
|
||||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=tokens[:, prev_pos + 1 : cur_pos + 1],
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (
|
||||
torch.isin(next_token, stop_tokens)
|
||||
)
|
||||
yield TokenResult(
|
||||
token=next_token[0].item(),
|
||||
text=self.tokenizer.decode(next_token.tolist()),
|
||||
logprobs=(
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
|
||||
if logprobs
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def text_completion(
|
||||
self,
|
||||
prompt: str,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator:
|
||||
if (
|
||||
max_gen_len is None
|
||||
or max_gen_len == 0
|
||||
or max_gen_len >= self.model.params.max_seq_len
|
||||
):
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
|
||||
|
||||
yield from self.generate(
|
||||
model_input=ModelInput(tokens=prompt_tokens),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages: List[Message],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
) -> Generator:
|
||||
if (
|
||||
max_gen_len is None
|
||||
or max_gen_len == 0
|
||||
or max_gen_len >= self.model.params.max_seq_len
|
||||
):
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
yield from self.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(
|
||||
messages,
|
||||
tool_prompt_format,
|
||||
),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=logprobs,
|
||||
include_stop_token=True,
|
||||
)
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
||||
"""
|
||||
Perform top-p (nucleus) sampling on a probability distribution.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability distribution tensor.
|
||||
p (float): Probability threshold for top-p sampling.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sampled token indices.
|
||||
|
||||
Note:
|
||||
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort[mask] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
|
|
@ -0,0 +1,215 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
from llama_models.llama3.api.datatypes import StopReason
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
Inference,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
# there's a single model parallel process running serving the model. for now,
|
||||
# we don't support multiple concurrent requests to this process.
|
||||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(Inference):
|
||||
def __init__(self, config: MetaReferenceImplConfig) -> None:
|
||||
self.config = config
|
||||
model = resolve_model(config.model)
|
||||
if model is None:
|
||||
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
||||
self.model = model
|
||||
# verify that the checkpoint actually is for this model lol
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.generator = LlamaModelParallelGenerator(self.config)
|
||||
self.generator.start()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.generator.stop()
|
||||
|
||||
# hm, when stream=False, we should not be doing SSE :/ which is what the
|
||||
# top-level server is going to do. make the typing more specific here
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncIterator[
|
||||
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
|
||||
]:
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = prepare_messages(request)
|
||||
model = resolve_model(request.model)
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
f"Unknown model: {request.model}, Run `llama model list`"
|
||||
)
|
||||
elif model.descriptor() != self.model.descriptor():
|
||||
raise RuntimeError(
|
||||
f"Model mismatch: {request.model} != {self.model.descriptor()}"
|
||||
)
|
||||
|
||||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
|
||||
async with SEMAPHORE:
|
||||
if request.stream:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
|
||||
tokens = []
|
||||
logprobs = []
|
||||
|
||||
stop_reason = None
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
|
||||
for token_result in self.generator.chat_completion(
|
||||
messages=messages,
|
||||
temperature=request.sampling_params.temperature,
|
||||
top_p=request.sampling_params.top_p,
|
||||
max_gen_len=request.sampling_params.max_tokens,
|
||||
logprobs=request.logprobs,
|
||||
tool_prompt_format=request.tool_prompt_format,
|
||||
):
|
||||
buffer += token_result.text
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if not ipython and buffer.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer = buffer[len("<|python_tag|>") :]
|
||||
continue
|
||||
|
||||
if not request.stream:
|
||||
if request.logprobs:
|
||||
logprobs.append(token_result.logprob)
|
||||
|
||||
continue
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if ipython:
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
else:
|
||||
delta = text
|
||||
|
||||
if stop_reason is None:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
# TODO(ashwin): parse tool calls separately here and report errors?
|
||||
# if someone breaks the iteration before coming here we are toast
|
||||
message = self.generator.formatter.decode_assistant_message(
|
||||
tokens, stop_reason
|
||||
)
|
||||
if request.stream:
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(ashwin): what else do we need to send out here when everything finishes?
|
||||
else:
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=message,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Generator, List, Optional
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
from .generation import Llama, model_checkpoint_dir
|
||||
from .parallel_utils import ModelParallelProcessGroup
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceArgs:
|
||||
messages: List[Message]
|
||||
temperature: float
|
||||
top_p: float
|
||||
max_gen_len: int
|
||||
logprobs: bool
|
||||
tool_prompt_format: ToolPromptFormat
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
def __init__(self, llama):
|
||||
self.llama = llama
|
||||
|
||||
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
|
||||
def __call__(self, task: InferenceArgs):
|
||||
return self.llama.chat_completion(
|
||||
task.messages,
|
||||
task.temperature,
|
||||
task.top_p,
|
||||
task.max_gen_len,
|
||||
task.logprobs,
|
||||
task.tool_prompt_format,
|
||||
)
|
||||
|
||||
|
||||
def init_model_cb(config: MetaReferenceImplConfig):
|
||||
llama = Llama.build(config)
|
||||
return ModelRunner(llama)
|
||||
|
||||
|
||||
class LlamaModelParallelGenerator:
|
||||
"""
|
||||
This abstraction exists so
|
||||
- we can run model parallel code without needing to run the CLIs via torchrun
|
||||
- this also enables use model parallel code within a notebook context.
|
||||
|
||||
A Context Manager is used to ensure that the model parallel process is started and stopped
|
||||
correctly. This does make the ergonomics a little awkward, because it isn't immediately
|
||||
clear at the callsite why we need to use a context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MetaReferenceImplConfig):
|
||||
self.config = config
|
||||
self.model = resolve_model(self.config.model)
|
||||
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
|
||||
# while the tool-use loop is going
|
||||
checkpoint_dir = model_checkpoint_dir(self.model)
|
||||
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
|
||||
self.formatter = ChatFormat(Tokenizer(tokenizer_path))
|
||||
|
||||
def start(self):
|
||||
self.__enter__()
|
||||
|
||||
def stop(self):
|
||||
self.__exit__(None, None, None)
|
||||
|
||||
def __enter__(self):
|
||||
self.group = ModelParallelProcessGroup(
|
||||
self.config.model_parallel_size,
|
||||
init_model_cb=partial(init_model_cb, self.config),
|
||||
)
|
||||
self.group.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
self.group.stop()
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages: List[Message],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
) -> Generator:
|
||||
req_obj = InferenceArgs(
|
||||
messages=deepcopy(messages),
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
)
|
||||
|
||||
gen = self.group.run_inference(req_obj)
|
||||
yield from gen
|
|
@ -0,0 +1,265 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from typing import Callable, Generator
|
||||
|
||||
import torch
|
||||
|
||||
import zmq
|
||||
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_group,
|
||||
get_model_parallel_rank,
|
||||
get_model_parallel_src_rank,
|
||||
)
|
||||
|
||||
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
|
||||
|
||||
|
||||
_END_SENTINEL = "__end_sentinel__"
|
||||
_CANCEL_SENTINEL = "__cancel_sentinel__"
|
||||
|
||||
|
||||
def mp_rank_0() -> bool:
|
||||
return get_model_parallel_rank() == 0
|
||||
|
||||
|
||||
def retrieve_requests(reply_socket_url: str):
|
||||
if mp_rank_0():
|
||||
context = zmq.Context()
|
||||
reply_socket = context.socket(zmq.ROUTER)
|
||||
reply_socket.connect(reply_socket_url)
|
||||
|
||||
while True:
|
||||
client_id, obj = maybe_get_work(reply_socket)
|
||||
if obj is None:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
|
||||
break
|
||||
|
||||
def send_obj(obj):
|
||||
reply_socket.send_multipart([client_id, pickle.dumps(obj)])
|
||||
|
||||
while True:
|
||||
tasks = [None]
|
||||
if mp_rank_0():
|
||||
client_id, task = maybe_get_work(reply_socket)
|
||||
# there is still an unknown unclean GeneratorExit happening resulting in a
|
||||
# cancel sentinel getting queued _after_ we have finished sending everything :/
|
||||
# kind of a hack this is :/
|
||||
if task != _CANCEL_SENTINEL:
|
||||
tasks = [task]
|
||||
|
||||
torch.distributed.broadcast_object_list(
|
||||
tasks,
|
||||
src=get_model_parallel_src_rank(),
|
||||
group=get_model_parallel_group(),
|
||||
)
|
||||
|
||||
task = tasks[0]
|
||||
if task is None:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
try:
|
||||
out = yield task
|
||||
if out is None:
|
||||
break
|
||||
|
||||
for obj in out:
|
||||
updates = [None]
|
||||
if mp_rank_0():
|
||||
_, update = maybe_get_work(reply_socket)
|
||||
if update == _CANCEL_SENTINEL:
|
||||
updates = [update]
|
||||
else:
|
||||
# only send the update if it's not cancelled otherwise the object sits in the socket
|
||||
# and gets pulled in the next request lol
|
||||
send_obj(obj)
|
||||
|
||||
torch.distributed.broadcast_object_list(
|
||||
updates,
|
||||
src=get_model_parallel_src_rank(),
|
||||
group=get_model_parallel_group(),
|
||||
)
|
||||
if updates[0] == _CANCEL_SENTINEL:
|
||||
print("quitting generation loop because request was cancelled")
|
||||
break
|
||||
|
||||
if mp_rank_0():
|
||||
send_obj(_END_SENTINEL)
|
||||
except Exception as e:
|
||||
print(f"[debug] got exception {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
if mp_rank_0():
|
||||
send_obj(e)
|
||||
|
||||
if mp_rank_0():
|
||||
send_obj("DONE")
|
||||
|
||||
|
||||
def maybe_get_work(sock: zmq.Socket):
|
||||
message = None
|
||||
client_id = None
|
||||
try:
|
||||
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
|
||||
message = pickle.loads(obj)
|
||||
except zmq.ZMQError as e:
|
||||
if e.errno != zmq.EAGAIN:
|
||||
raise e
|
||||
|
||||
return client_id, message
|
||||
|
||||
|
||||
def worker_process_entrypoint(
|
||||
reply_socket_url: str,
|
||||
init_model_cb: Callable,
|
||||
) -> None:
|
||||
model = init_model_cb()
|
||||
torch.distributed.barrier()
|
||||
time.sleep(1)
|
||||
|
||||
# run the requests co-routine which retrieves requests from the socket
|
||||
# and sends responses (we provide) back to the caller
|
||||
req_gen = retrieve_requests(reply_socket_url)
|
||||
result = None
|
||||
while True:
|
||||
try:
|
||||
task = req_gen.send(result)
|
||||
if isinstance(task, str) and task == _END_SENTINEL:
|
||||
break
|
||||
|
||||
result = model(task)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
print("[debug] worker process done")
|
||||
|
||||
|
||||
def launch_dist_group(
|
||||
reply_socket_url: str,
|
||||
model_parallel_size: int,
|
||||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
id = uuid.uuid4().hex
|
||||
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
|
||||
launch_config = LaunchConfig(
|
||||
max_nodes=1,
|
||||
min_nodes=1,
|
||||
nproc_per_node=model_parallel_size,
|
||||
start_method="fork",
|
||||
rdzv_backend="c10d",
|
||||
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
|
||||
rdzv_configs={"store_type": "file", "timeout": 90},
|
||||
max_restarts=0,
|
||||
monitor_interval=1,
|
||||
run_id=str(uuid.uuid4()),
|
||||
)
|
||||
elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
|
||||
reply_socket_url,
|
||||
init_model_cb,
|
||||
)
|
||||
|
||||
|
||||
def start_model_parallel_process(
|
||||
model_parallel_size: int,
|
||||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
context = zmq.Context()
|
||||
request_socket = context.socket(zmq.DEALER)
|
||||
|
||||
# Binding the request socket to a random port
|
||||
request_socket.bind("tcp://127.0.0.1:0")
|
||||
|
||||
main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
|
||||
|
||||
ctx = multiprocessing.get_context("fork")
|
||||
process = ctx.Process(
|
||||
target=launch_dist_group,
|
||||
args=(
|
||||
main_process_url,
|
||||
model_parallel_size,
|
||||
init_model_cb,
|
||||
),
|
||||
kwargs=kwargs,
|
||||
)
|
||||
process.start()
|
||||
|
||||
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
|
||||
|
||||
request_socket.send_pyobj("READY?")
|
||||
response = request_socket.recv_pyobj()
|
||||
print(f"Finished model load {response}")
|
||||
|
||||
return request_socket, process
|
||||
|
||||
|
||||
class ModelParallelProcessGroup:
|
||||
def __init__(
|
||||
self,
|
||||
model_parallel_size: int,
|
||||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
self.model_parallel_size = model_parallel_size
|
||||
self.init_model_cb = init_model_cb
|
||||
self.started = False
|
||||
self.running = False
|
||||
|
||||
def start(self):
|
||||
assert not self.started, "process group already started"
|
||||
self.request_socket, self.process = start_model_parallel_process(
|
||||
self.model_parallel_size,
|
||||
self.init_model_cb,
|
||||
)
|
||||
self.started = True
|
||||
|
||||
def stop(self):
|
||||
assert self.started, "process group not started"
|
||||
if self.process.is_alive():
|
||||
self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
|
||||
self.process.join()
|
||||
self.started = False
|
||||
|
||||
def run_inference(self, request) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
||||
self.running = True
|
||||
self.request_socket.send_pyobj(request)
|
||||
try:
|
||||
while True:
|
||||
obj = self.request_socket.recv_pyobj()
|
||||
if obj == _END_SENTINEL:
|
||||
break
|
||||
|
||||
if isinstance(obj, Exception):
|
||||
print(f"[debug] got exception {obj}")
|
||||
raise obj
|
||||
|
||||
yield obj
|
||||
except GeneratorExit as e:
|
||||
self.request_socket.send_pyobj(_CANCEL_SENTINEL)
|
||||
while True:
|
||||
obj = self.request_socket.recv_pyobj()
|
||||
if obj == _END_SENTINEL:
|
||||
break
|
||||
finally:
|
||||
self.running = False
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,184 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import collections
|
||||
from typing import Optional, Type
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
print("Using efficient FP8 operators in FBGEMM.")
|
||||
except ImportError:
|
||||
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
||||
raise
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
class Fp8ScaledWeights:
|
||||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Fp8Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
def grad_fn(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
||||
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
||||
class Fp8RowwiseWeights(
|
||||
Fp8ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Fp8RowwiseWeights",
|
||||
["weight", "scale", "shape", "activation_scale_ub"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def ffn_swiglu(
|
||||
x: Tensor,
|
||||
w1: Fp8RowwiseWeights,
|
||||
w3: Fp8RowwiseWeights,
|
||||
w2: Fp8RowwiseWeights,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
if (
|
||||
isinstance(w1, Fp8ScaledWeights)
|
||||
and isinstance(w3, Fp8ScaledWeights)
|
||||
and isinstance(w2, Fp8ScaledWeights)
|
||||
):
|
||||
return ffn_swiglu_fp8_dynamic(
|
||||
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
|
||||
)
|
||||
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
assert isinstance(w1, Tensor)
|
||||
assert isinstance(w3, Tensor)
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
x2 = x.view(B * T, D) @ w3.T
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
assert isinstance(w2, Tensor)
|
||||
return (z @ w2.T).view(B, T, D)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def quantize_fp8(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Quantize [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device="cuda",
|
||||
)
|
||||
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
|
||||
del w
|
||||
return Fp8RowwiseWeights(
|
||||
weight=wq,
|
||||
scale=w_scale,
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_fp8(
|
||||
w: Tensor,
|
||||
w_scale: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Load FP8 [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input FP8.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device="cuda",
|
||||
)
|
||||
return Fp8RowwiseWeights(
|
||||
weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
|
||||
scale=w_scale.to(device="cuda"),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
def fc_fp8_dynamic(
|
||||
x: Tensor,
|
||||
w: Fp8RowwiseWeights,
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Single w8a8 fc layer with dynamic row-wise scaling.
|
||||
"""
|
||||
if isinstance(w, Fp8RowwiseWeights):
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
x, num_tokens, activation_scale_ub
|
||||
)
|
||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
xq, w.weight, x_scale, w.scale, use_fast_accum=True
|
||||
)
|
||||
del xq
|
||||
return y
|
||||
|
||||
|
||||
def ffn_swiglu_fp8_dynamic(
|
||||
x: Tensor,
|
||||
w1: Fp8RowwiseWeights,
|
||||
w3: Fp8RowwiseWeights,
|
||||
w2: Fp8RowwiseWeights,
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
HD_L = w1.shape[0] # noqa: N806
|
||||
assert HD_L == w3.shape[0]
|
||||
x1 = fc_fp8_dynamic(
|
||||
x.view(B * T, D),
|
||||
w1,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
x2 = fc_fp8_dynamic(
|
||||
x.view(B * T, D),
|
||||
w3,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
|
||||
z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
return z_.view(B, T, D)
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from llama_models.llama3.api.model import Transformer, TransformerBlock
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from llama_stack.apis.inference.config import (
|
||||
CheckpointQuantizationFormat,
|
||||
MetaReferenceImplConfig,
|
||||
)
|
||||
|
||||
from termcolor import cprint
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def is_fbgemm_available() -> bool:
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def swiglu_wrapper(
|
||||
self,
|
||||
x: Tensor,
|
||||
):
|
||||
from .fp8_impls import ffn_swiglu
|
||||
|
||||
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
|
||||
return reduce_from_model_parallel_region(out)
|
||||
|
||||
|
||||
def convert_to_quantized_model(
|
||||
model: Transformer,
|
||||
config: MetaReferenceImplConfig,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
) -> Transformer:
|
||||
if config.quantization.type == QuantizationType.bf16.value:
|
||||
return model
|
||||
|
||||
elif config.quantization.type != QuantizationType.fp8.value:
|
||||
raise ValueError("Only FP8 quantization is supported")
|
||||
|
||||
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
|
||||
|
||||
checkpoint = config.checkpoint_config.checkpoint
|
||||
# Move weights to GPU with quantization
|
||||
if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
|
||||
cprint("Loading fp8 scales...", "yellow")
|
||||
fp8_scales_path = os.path.join(
|
||||
checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||
)
|
||||
assert os.path.isfile(
|
||||
fp8_scales_path
|
||||
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
|
||||
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
|
||||
|
||||
for block in model.layers:
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(block.feed_forward, key)
|
||||
param.weight = load_fp8(
|
||||
param.weight,
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
|
||||
],
|
||||
fp8_activation_scale_ub,
|
||||
)
|
||||
else:
|
||||
cprint("Quantizing fp8 weights from bf16...", "yellow")
|
||||
for block in model.layers:
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(block.feed_forward, key)
|
||||
param.weight = quantize_fp8(
|
||||
param.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
for _, parameter in model.named_parameters():
|
||||
if not isinstance(parameter, Fp8ScaledWeights):
|
||||
parameter.data = parameter.to(device="cuda")
|
||||
return model
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,30 @@
|
|||
#!/bin/bash
|
||||
|
||||
if [[ $# -ne 1 ]]; then
|
||||
echo "Error: Please provide the name of CONDA environment you wish to create"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ENV_NAME=$1
|
||||
|
||||
set -eu
|
||||
eval "$(conda shell.bash hook)"
|
||||
|
||||
echo "Will build env (or overwrite) named '$ENV_NAME'"
|
||||
|
||||
set -x
|
||||
|
||||
run_build() {
|
||||
# Set up the conda environment
|
||||
yes | conda remove --name $ENV_NAME --all
|
||||
yes | conda create -n $ENV_NAME python=3.10
|
||||
conda activate $ENV_NAME
|
||||
|
||||
# PT nightly
|
||||
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
|
||||
|
||||
# install dependencies for `llama-agentic-system`
|
||||
pip install -r fp8_requirements.txt
|
||||
}
|
||||
|
||||
run_build
|
|
@ -0,0 +1,161 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_rank,
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
|
||||
|
||||
from llama.model import ModelArgs, Transformer, TransformerBlock
|
||||
from llama.tokenizer import Tokenizer
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
def main(
|
||||
ckpt_dir: str,
|
||||
tokenizer_path: str,
|
||||
quantized_ckpt_dir: str,
|
||||
max_seq_len: Optional[int] = 512,
|
||||
max_batch_size: Optional[int] = 4,
|
||||
model_parallel_size: Optional[int] = None,
|
||||
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
seed: int = 1,
|
||||
):
|
||||
""" """
|
||||
if not os.path.exists(quantized_ckpt_dir):
|
||||
os.makedirs(quantized_ckpt_dir)
|
||||
shutil.copy(
|
||||
os.path.join(ckpt_dir, "params.json"),
|
||||
os.path.join(quantized_ckpt_dir, "params.json"),
|
||||
)
|
||||
shutil.copy(
|
||||
os.path.join(ckpt_dir, "tokenizer.model"),
|
||||
os.path.join(quantized_ckpt_dir, "tokenizer.model"),
|
||||
)
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
if not model_parallel_is_initialized():
|
||||
if model_parallel_size is None:
|
||||
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
initialize_model_parallel(model_parallel_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# seed must be the same in all processes
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert model_parallel_size == len(
|
||||
checkpoints
|
||||
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
**params,
|
||||
)
|
||||
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||
assert (
|
||||
model_args.vocab_size == tokenizer.n_words
|
||||
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
|
||||
print(ckpt_path)
|
||||
assert (
|
||||
quantized_ckpt_dir is not None
|
||||
), "QUantized checkpoint directory should not be None"
|
||||
fp8_scales = {}
|
||||
for block in model.layers:
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
|
||||
fp8_weight = quantize_fp8(
|
||||
block.feed_forward.w1.weight,
|
||||
fp8_activation_scale_ub,
|
||||
ffn_quantize_mode,
|
||||
output_device=torch.device("cpu"),
|
||||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
|
||||
] = fp8_weight.scale
|
||||
|
||||
fp8_weight = quantize_fp8(
|
||||
block.feed_forward.w3.weight,
|
||||
fp8_activation_scale_ub,
|
||||
ffn_quantize_mode,
|
||||
output_device=torch.device("cpu"),
|
||||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
|
||||
] = fp8_weight.scale
|
||||
|
||||
fp8_weight = quantize_fp8(
|
||||
block.feed_forward.w2.weight,
|
||||
fp8_activation_scale_ub,
|
||||
ffn_quantize_mode,
|
||||
output_device=torch.device("cpu"),
|
||||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[
|
||||
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
|
||||
] = fp8_weight.scale
|
||||
|
||||
fp8_scales_path = os.path.join(
|
||||
quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||
)
|
||||
torch.save(fp8_scales, fp8_scales_path)
|
||||
|
||||
ckpt_path = os.path.join(
|
||||
quantized_ckpt_dir,
|
||||
"consolidated.{:02d}.pth".format(get_model_parallel_rank()),
|
||||
)
|
||||
torch.save(model.state_dict(), ckpt_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
|
@ -0,0 +1,31 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
set -x
|
||||
|
||||
cd $(git rev-parse --show-toplevel)
|
||||
|
||||
MASTER_HOST=$1
|
||||
RUN_ID=$2
|
||||
CKPT_DIR=$3
|
||||
QUANT_CKPT_DIR=$4
|
||||
TOKENIZER_PATH=$5
|
||||
NNODES=$6
|
||||
NPROC=$7
|
||||
|
||||
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
|
||||
|
||||
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
|
||||
torchrun \
|
||||
--nnodes=$NNODES --nproc_per_node=$NPROC \
|
||||
--rdzv_id=$RUN_ID \
|
||||
--rdzv_conf='timeout=120' \
|
||||
--rdzv_backend=c10d \
|
||||
--rdzv_endpoint="${MASTER_HOST}:29502" \
|
||||
quantize_checkpoint.py $CKPT_DIR $TOKENIZER_PATH $QUANT_CKPT_DIR
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
|
||||
from hypothesis import given, settings, strategies as st
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available()
|
||||
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
|
||||
"Skip when H100 is not available",
|
||||
)
|
||||
class FP8Tests(unittest.TestCase):
|
||||
@settings(deadline=None)
|
||||
@given(
|
||||
D=st.sampled_from([4096, 8192]),
|
||||
HD_L=st.sampled_from([1280, 2560]),
|
||||
B=st.sampled_from([1, 2]),
|
||||
T=st.sampled_from([2048, 4096]),
|
||||
UB=st.sampled_from([1000, 10000]),
|
||||
)
|
||||
def test_fp8_ffn(
|
||||
self,
|
||||
D: int, # noqa
|
||||
HD_L: int,
|
||||
B: int,
|
||||
T: int,
|
||||
UB: float,
|
||||
) -> None:
|
||||
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
|
||||
x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
|
||||
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
x2 = x.view(B * T, D) @ w3.T
|
||||
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
|
||||
|
||||
v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)
|
||||
|
||||
# Fake quant
|
||||
x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
|
||||
w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
|
||||
w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
|
||||
w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)
|
||||
|
||||
v_ref = ref_ffn(x, w1, w3, w2)
|
||||
|
||||
torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Add table
Add a link
Reference in a new issue