Support for Llama3.2 models and Swift SDK (#98)

Author: Ashwin Bharambe, 2024-09-25 10:29:58 -07:00 (committed by GitHub)
parent 95abbf576b
commit 56aed59eb4
56 changed files with 3745 additions and 630 deletions


@@ -24,21 +24,31 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolPromptFormat,
)
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from termcolor import cprint
from llama_stack.apis.inference import QuantizationType
from llama_stack.distribution.utils.model_utils import model_local_dir
from termcolor import cprint
from .config import MetaReferenceImplConfig
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
@@ -134,7 +144,11 @@ class Llama:
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_quantized_model(model, config)
else:
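For context on the comment in this hunk (not part of the diff): the fp8 path relies on PyTorch's default tensor type so that model parameters materialize as bfloat16 on CPU. A toy check of that behavior, assuming a PyTorch version where set_default_tensor_type is still available:

import torch

# With the CPU bf16 default installed, freshly created parameters come out as
# bfloat16, so a later fp8 conversion never meets an unexpected fp32 tensor.
torch.set_default_tensor_type(torch.BFloat16Tensor)
weight = torch.nn.Linear(4, 4).weight
assert weight.dtype == torch.bfloat16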
@@ -142,7 +156,11 @@ class Llama:
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -167,7 +185,11 @@ class Llama:
) -> Generator:
params = self.model.params
# cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
# input_tokens = [
# self.formatter.vision_token if t == 128256 else t
# for t in model_input.tokens
# ]
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
prompt_tokens = [model_input.tokens]
bsz = 1
@@ -183,6 +205,21 @@ class Llama:
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = isinstance(self.model, CrossAttentionTransformer)
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = (
self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
@@ -206,7 +243,19 @@ class Llama:
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if is_vision:
position_ids = torch.arange(
prev_pos, cur_pos, dtype=torch.long, device="cuda"
)
logits = self.model.forward(
position_ids,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
)
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
@@ -222,6 +271,18 @@ class Llama:
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if is_vision:
# the logits space (num_classes) is designed to never contain a media_token
# however our input token stream does contain them. we need to nuke them here
# or else the CUDA kernels will crash with an illegal memory access
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
masks = [target.eq(t) for t in vision_tokens]
if len(masks) > 1:
mask = torch.logical_or(*masks)
else:
mask = masks[0]
target[mask] = 0
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
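The vision-token masking above matters because the model's output vocabulary has no class for the image token id, so any target position holding it must be zeroed before the cross-entropy call. A toy illustration with hypothetical values (not part of the diff):

import torch

IMAGE_TOKEN_ID = 128256  # id the input stream uses for <|image|>
target = torch.tensor([[15, IMAGE_TOKEN_ID, 42]])
target[target.eq(IMAGE_TOKEN_ID)] = 0
# target is now tensor([[15, 0, 42]]) and can safely be fed to F.cross_entropy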
@@ -248,7 +309,7 @@ class Llama:
def text_completion(
self,
prompt: str,
content: InterleavedTextMedia,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
@@ -262,10 +323,10 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
model_input = self.formatter.encode_content(content)
yield from self.generate(
model_input=ModelInput(tokens=prompt_tokens),
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
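With prompt: str replaced by content: InterleavedTextMedia, callers now hand text_completion an interleaved text/media value rather than a raw string. A sketch of hypothetical caller-side usage after this change, assuming a constructed Llama instance named llama and that a plain string is valid InterleavedTextMedia content:

# Hypothetical usage sketch; text_completion is a generator, so results are
# consumed by iterating over it.
for result in llama.text_completion(
    content="The capital of France is",
    temperature=0.6,
    top_p=0.9,
    max_gen_len=64,
):
    print(result)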