forked from phoenix-oss/llama-stack-mirror

Support for Llama3.2 models and Swift SDK (#98)

parent 95abbf576b
commit 56aed59eb4

56 changed files with 3745 additions and 630 deletions
@@ -16,7 +16,7 @@ from pydantic import BaseModel, Field, field_validator

 class MetaReferenceImplConfig(BaseModel):
     model: str = Field(
-        default="Meta-Llama3.1-8B-Instruct",
+        default="Llama3.1-8B-Instruct",
         description="Model descriptor from `llama model list`",
     )
     quantization: Optional[QuantizationConfig] = None

@@ -30,7 +30,7 @@ class MetaReferenceImplConfig(BaseModel):
         permitted_models = [
             m.descriptor()
             for m in all_registered_models()
-            if m.model_family == ModelFamily.llama3_1
+            if m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
             or m.core_model_id == CoreModelId.llama_guard_3_8b
         ]
         if model not in permitted_models:

@@ -42,14 +42,9 @@ class MetaReferenceImplConfig(BaseModel):

     @property
     def model_parallel_size(self) -> int:
-        # HUGE HACK ALERT: this will be fixed when we move inference configuration
+        # HACK ALERT: this will be fixed when we move inference configuration
         # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
         # as configuration there
-        gpu_count = 1
         resolved = resolve_model(self.model)
         assert resolved is not None
-        descriptor = resolved.descriptor().lower()
-        if "-70b" in descriptor or "-405b" in descriptor:
-            gpu_count = 8
-
-        return gpu_count
+        return resolved.pth_file_count
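Taken together, these config changes do two things: Llama 3.2 descriptors now pass the permitted-model check, and the model-parallel degree is read from the number of checkpoint shards instead of a name-based GPU heuristic. A minimal standalone sketch of the same logic follows; the function names are illustrative, and the import locations are assumptions rather than part of this diff (resolve_model is imported from llama_models.sku_list elsewhere in the commit).

# Illustrative sketch only -- mirrors the validator and model_parallel_size above.
from llama_models.sku_list import all_registered_models, resolve_model  # assumed module
from llama_models.datatypes import CoreModelId, ModelFamily  # assumed module


def is_permitted(descriptor: str) -> bool:
    # Llama 3.1, Llama 3.2, and Llama Guard 3 8B descriptors are accepted.
    permitted = {
        m.descriptor()
        for m in all_registered_models()
        if m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
        or m.core_model_id == CoreModelId.llama_guard_3_8b
    }
    return descriptor in permitted


def model_parallel_size(descriptor: str) -> int:
    # One model-parallel rank per .pth shard, replacing the "-70b/-405b => 8 GPUs" heuristic.
    resolved = resolve_model(descriptor)
    assert resolved is not None
    return resolved.pth_file_count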
@@ -24,21 +24,31 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
+from llama_models.llama3.api.datatypes import (
+    InterleavedTextMedia,
+    Message,
+    ToolPromptFormat,
+)
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
+from llama_models.llama3.reference_impl.multimodal.model import (
+    CrossAttentionTransformer,
+)
 from llama_models.sku_list import resolve_model
+from termcolor import cprint

 from llama_stack.apis.inference import QuantizationType

 from llama_stack.distribution.utils.model_utils import model_local_dir
-from termcolor import cprint

 from .config import MetaReferenceImplConfig


 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
-    if not Path(checkpoint_dir / "consolidated.00.pth").exists():
+
+    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
+    if not any(p.exists() for p in paths):
         checkpoint_dir = checkpoint_dir / "original"

     assert checkpoint_dir.exists(), (
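The checkpoint lookup now tolerates both a single-file layout (consolidated.pth) and a sharded layout (consolidated.00.pth) before falling back to the original/ subdirectory. A self-contained sketch of the same resolution, assuming only pathlib; the function name and assert message are illustrative:

from pathlib import Path


def find_checkpoint_dir(checkpoint_dir: Path) -> Path:
    # Accept either a single-file or a sharded checkpoint layout.
    candidates = [checkpoint_dir / f"consolidated.{ext}" for ext in ["pth", "00.pth"]]
    if not any(p.exists() for p in candidates):
        # Neither layout present: fall back to the "original" subdirectory, as above.
        checkpoint_dir = checkpoint_dir / "original"
    assert checkpoint_dir.exists(), f"{checkpoint_dir} does not exist"
    return checkpoint_dir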
@@ -134,7 +144,11 @@ class Llama:
             # load on CPU in bf16 so that fp8 conversion does not find an
             # unexpected (fp32, e.g.) datatype
             torch.set_default_tensor_type(torch.BFloat16Tensor)
-            model = Transformer(model_args)
+            if model_args.vision_chunk_size > 0:
+                model = CrossAttentionTransformer(model_args)
+                model.setup_cache(model_args.max_batch_size, torch.bfloat16)
+            else:
+                model = Transformer(model_args)
             model.load_state_dict(state_dict, strict=False)
             model = convert_to_quantized_model(model, config)
         else:
@@ -142,7 +156,11 @@ class Llama:
                 torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
             else:
                 torch.set_default_tensor_type(torch.cuda.HalfTensor)
-            model = Transformer(model_args)
+            if model_args.vision_chunk_size > 0:
+                model = CrossAttentionTransformer(model_args)
+                model.setup_cache(model_args.max_batch_size, torch.bfloat16)
+            else:
+                model = Transformer(model_args)
             model.load_state_dict(state_dict, strict=False)

         print(f"Loaded in {time.time() - start_time:.2f} seconds")
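Both load paths (quantized and non-quantized) now pick the model class the same way, so the selection can be read as one small helper. The following is only a restatement of the two hunks above under the same imports; the name build_model is not from the diff:

import torch
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import CrossAttentionTransformer


def build_model(model_args: ModelArgs):
    # vision_chunk_size > 0 marks a multimodal checkpoint; only the cross-attention
    # variant needs its cache set up before weights are loaded.
    if model_args.vision_chunk_size > 0:
        model = CrossAttentionTransformer(model_args)
        model.setup_cache(model_args.max_batch_size, torch.bfloat16)
    else:
        model = Transformer(model_args)
    return model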
@@ -167,7 +185,11 @@ class Llama:
     ) -> Generator:
         params = self.model.params

-        # cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
+        # input_tokens = [
+        #     self.formatter.vision_token if t == 128256 else t
+        #     for t in model_input.tokens
+        # ]
+        # cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
         prompt_tokens = [model_input.tokens]

         bsz = 1
@@ -183,6 +205,21 @@ class Llama:
                 return

         total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

+        is_vision = isinstance(self.model, CrossAttentionTransformer)
+        if is_vision:
+            images = model_input.vision.images if model_input.vision is not None else []
+            mask = model_input.vision.mask if model_input.vision is not None else []
+
+            # the method works for bsz > 1 so add a batch dimension
+            xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = (
+                self.model.compute_vision_tokens_masks(
+                    batch_images=[images],
+                    batch_masks=[mask],
+                    total_len=total_len,
+                )
+            )
+
         pad_id = self.tokenizer.pad_id
         tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
         for k, t in enumerate(prompt_tokens):
@@ -206,7 +243,19 @@ class Llama:
         stop_tokens = torch.tensor(self.tokenizer.stop_tokens)

         for cur_pos in range(min_prompt_len, total_len):
-            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+            if is_vision:
+                position_ids = torch.arange(
+                    prev_pos, cur_pos, dtype=torch.long, device="cuda"
+                )
+                logits = self.model.forward(
+                    position_ids,
+                    tokens,
+                    cross_attention_masks,
+                    full_text_row_masked_out_mask,
+                    xattn_caches,
+                )
+            else:
+                logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
@@ -222,6 +271,18 @@ class Llama:
             tokens[:, cur_pos] = next_token

+            target = tokens[:, prev_pos + 1 : cur_pos + 1]
+            if is_vision:
+                # the logits space (num_classes) is designed to never contain a media_token
+                # however our input token stream does contain them. we need to nuke them here
+                # or else the CUDA kernels will crash with an illegal memory access
+                vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
+                masks = [target.eq(t) for t in vision_tokens]
+                if len(masks) > 1:
+                    mask = torch.logical_or(*masks)
+                else:
+                    mask = masks[0]
+                target[mask] = 0
+
             if logprobs:
                 token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                     input=logits.transpose(1, 2),
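The masking above only rewrites the logprob target, never the generated tokens. A tiny self-contained illustration of the tensor operations involved; token ids 8 and 9 below are made-up stand-ins for the media token ids:

import torch

target = torch.tensor([[1, 8, 2, 9, 3]])  # pretend target slice containing media tokens
vision_tokens = [8, 9]

masks = [target.eq(t) for t in vision_tokens]
mask = torch.logical_or(*masks) if len(masks) > 1 else masks[0]
target[mask] = 0  # zero them so cross_entropy never sees an out-of-range class index

print(target)  # tensor([[1, 0, 2, 0, 3]])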
@@ -248,7 +309,7 @@ class Llama:

     def text_completion(
         self,
-        prompt: str,
+        content: InterleavedTextMedia,
         temperature: float = 0.6,
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
@@ -262,10 +323,10 @@ class Llama:
         ):
             max_gen_len = self.model.params.max_seq_len - 1

-        prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
+        model_input = self.formatter.encode_content(content)

         yield from self.generate(
-            model_input=ModelInput(tokens=prompt_tokens),
+            model_input=model_input,
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
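With this change text_completion no longer tokenizes a raw string itself; it hands an InterleavedTextMedia value to the formatter's encode_content and streams generate over the resulting ModelInput. A hedged usage sketch: llama is assumed to be an already-built Llama instance, and passing a plain string as content is an assumption about what InterleavedTextMedia admits.

# Hypothetical caller; `llama` is assumed to be a constructed Llama instance.
content = "The capital of France is"  # assumed: a bare string is valid InterleavedTextMedia

for result in llama.text_completion(
    content=content,
    temperature=0.6,
    top_p=0.9,
    max_gen_len=64,
):
    print(result)  # consume streamed generation results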
@@ -21,7 +21,9 @@ from llama_stack.apis.inference import (
     ToolCallDelta,
     ToolCallParseStatus,
 )
-from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
+from llama_stack.providers.utils.inference.augment_messages import (
+    augment_messages_for_tools,
+)

 from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator

@@ -57,7 +59,7 @@ class MetaReferenceInferenceImpl(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = [],
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,

@@ -70,14 +72,14 @@ class MetaReferenceInferenceImpl(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
         )

-        messages = prepare_messages(request)
+        messages = augment_messages_for_tools(request)
         model = resolve_model(request.model)
         if model is None:
             raise RuntimeError(
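One common motivation for changing tools from a [] default to None (normalized later with tools or []) is Python's mutable-default-argument pitfall: a default list is created once at function definition time and shared across calls. A minimal illustration, unrelated to any llama-stack API:

def risky(items=[]):       # the same list object is reused for every call
    items.append("x")
    return items


def safe(items=None):      # a fresh list per call
    items = items or []
    items.append("x")
    return items


print(risky())  # ['x']
print(risky())  # ['x', 'x']  <- state leaked from the first call
print(safe())   # ['x']
print(safe())   # ['x']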