mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Support for Llama3.2 models and Swift SDK (#98)
This commit is contained in:
parent
95abbf576b
commit
56aed59eb4
56 changed files with 3745 additions and 630 deletions
|
@ -24,21 +24,31 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
|
||||
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
InterleavedTextMedia,
|
||||
Message,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3.reference_impl.model import Transformer
|
||||
from llama_models.llama3.reference_impl.multimodal.model import (
|
||||
CrossAttentionTransformer,
|
||||
)
|
||||
from llama_models.sku_list import resolve_model
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from termcolor import cprint
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
|
||||
|
||||
def model_checkpoint_dir(model) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model.descriptor()))
|
||||
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
|
||||
|
||||
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
|
||||
if not any(p.exists() for p in paths):
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
assert checkpoint_dir.exists(), (
|
||||
|
@ -134,7 +144,11 @@ class Llama:
|
|||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
model = Transformer(model_args)
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
|
||||
else:
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_quantized_model(model, config)
|
||||
else:
|
||||
|
@ -142,7 +156,11 @@ class Llama:
|
|||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
model = Transformer(model_args)
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
|
||||
else:
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
|
@ -167,7 +185,11 @@ class Llama:
|
|||
) -> Generator:
|
||||
params = self.model.params
|
||||
|
||||
# cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
|
||||
# input_tokens = [
|
||||
# self.formatter.vision_token if t == 128256 else t
|
||||
# for t in model_input.tokens
|
||||
# ]
|
||||
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
|
||||
prompt_tokens = [model_input.tokens]
|
||||
|
||||
bsz = 1
|
||||
|
@ -183,6 +205,21 @@ class Llama:
|
|||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
|
||||
is_vision = isinstance(self.model, CrossAttentionTransformer)
|
||||
if is_vision:
|
||||
images = model_input.vision.images if model_input.vision is not None else []
|
||||
mask = model_input.vision.mask if model_input.vision is not None else []
|
||||
|
||||
# the method works for bsz > 1 so add a batch dimension
|
||||
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = (
|
||||
self.model.compute_vision_tokens_masks(
|
||||
batch_images=[images],
|
||||
batch_masks=[mask],
|
||||
total_len=total_len,
|
||||
)
|
||||
)
|
||||
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
|
@ -206,7 +243,19 @@ class Llama:
|
|||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
|
||||
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
if is_vision:
|
||||
position_ids = torch.arange(
|
||||
prev_pos, cur_pos, dtype=torch.long, device="cuda"
|
||||
)
|
||||
logits = self.model.forward(
|
||||
position_ids,
|
||||
tokens,
|
||||
cross_attention_masks,
|
||||
full_text_row_masked_out_mask,
|
||||
xattn_caches,
|
||||
)
|
||||
else:
|
||||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
|
@ -222,6 +271,18 @@ class Llama:
|
|||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if is_vision:
|
||||
# the logits space (num_classes) is designed to never contain a media_token
|
||||
# however our input token stream does contain them. we need to nuke them here
|
||||
# or else the CUDA kernels will crash with an illegal memory access
|
||||
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
|
||||
masks = [target.eq(t) for t in vision_tokens]
|
||||
if len(masks) > 1:
|
||||
mask = torch.logical_or(*masks)
|
||||
else:
|
||||
mask = masks[0]
|
||||
target[mask] = 0
|
||||
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
|
@ -248,7 +309,7 @@ class Llama:
|
|||
|
||||
def text_completion(
|
||||
self,
|
||||
prompt: str,
|
||||
content: InterleavedTextMedia,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
|
@ -262,10 +323,10 @@ class Llama:
|
|||
):
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
|
||||
model_input = self.formatter.encode_content(content)
|
||||
|
||||
yield from self.generate(
|
||||
model_input=ModelInput(tokens=prompt_tokens),
|
||||
model_input=model_input,
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue