Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 07:14:20 +00:00
chore(api): add mypy coverage to meta_llama3_generation
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
parent 81ebaf6e9a
commit 6a11fdb57b

2 changed files with 64 additions and 34 deletions
llama_stack/models/llama/llama3/generation.py

@@ -17,14 +17,15 @@ import sys
 import time
 from collections.abc import Callable, Generator
 from pathlib import Path
+from typing import TYPE_CHECKING

-import torch
-import torch.nn.functional as F
-from fairscale.nn.model_parallel.initialize import (
+import torch # type: ignore
+import torch.nn.functional as F # type: ignore
+from fairscale.nn.model_parallel.initialize import ( # type: ignore
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from termcolor import cprint
+from termcolor import cprint # type: ignore

 from ..checkpoint import maybe_reshard_state_dict
 from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
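Note: the blanket `# type: ignore` comments keep mypy from reporting errors that originate in these third-party imports. A narrower variant (not what this patch does, and assuming a mypy new enough to support per-error-code ignores) would scope the suppression to the untyped-import error only:

    # Hypothetical variant: ignore only the "untyped import" diagnostic rather
    # than every error on the line.
    from fairscale.nn.model_parallel.initialize import (  # type: ignore[import-untyped]
        initialize_model_parallel,
    )
    from termcolor import cprint  # type: ignore[import-untyped]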
@@ -34,6 +35,16 @@ from .model import Transformer
 from .multimodal.model import CrossAttentionTransformer
 from .tokenizer import Tokenizer

+if TYPE_CHECKING:
+    from .quantization.loader import convert_to_quantized_model
+else:
+    # Import at runtime to avoid circular dependencies
+    def _get_convert_to_quantized_model():
+        from .quantization.loader import convert_to_quantized_model
+        return convert_to_quantized_model
+
+    convert_to_quantized_model = _get_convert_to_quantized_model()
+

 class Llama3:
     @staticmethod
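Note: the `if TYPE_CHECKING:` branch gives mypy a real import to resolve, while at runtime the import is deferred into a helper so the circular dependency between this module and `quantization/loader.py` is only exercised after both modules have finished loading. A minimal sketch of the same pattern, using hypothetical modules `a.py` and `b.py` that import each other:

    # a.py (hypothetical)
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by the type checker; never executed, so it cannot create
        # a runtime import cycle.
        from b import Helper


    def make_helper() -> "Helper":
        # Deferred to call time, when module b is fully initialized.
        from b import Helper

        return Helper()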
@@ -44,19 +55,19 @@ class Llama3:
         world_size: int | None = None,
         quantization_mode: QuantizationMode | None = None,
         seed: int = 1,
-        device: str = "cuda",
+        device: str | torch.device = "cuda",
     ):
-        device = torch.device(device)
+        device_obj = torch.device(device)
         if (
-            device.type == "cuda"
+            device_obj.type == "cuda"
             and not torch.cuda.is_available()
-            or device.type == "xpu"
+            or device_obj.type == "xpu"
             and not torch.xpu.is_available()
         ):
-            raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
+            raise RuntimeError(f"PyTorch backend for {device_obj.type} device type is not available")

         if not torch.distributed.is_initialized():
-            if device.type == "cuda":
+            if device_obj.type == "cuda":
                 torch.distributed.init_process_group("nccl")
             else:
                 torch.distributed.init_process_group("gloo")
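Note: `torch.device(device)` accepts either a string or an existing `torch.device`, so widening the parameter type is safe. Binding the result to a new name (`device_obj`) rather than reassigning `device` keeps a name whose type is always `torch.device`; one plausible motivation is that mypy does not carry narrowing of a `str | torch.device` parameter into nested functions such as `build_model`. A trimmed, hypothetical illustration:

    import torch


    def build(device: str | torch.device = "cuda") -> None:
        device_obj = torch.device(device)  # always a torch.device

        def helper() -> None:
            # Inside the closure a captured `device` would be seen with its
            # declared str | torch.device type; `device_obj` stays narrow.
            print(device_obj.type)

        helper()


    build("cpu")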
@@ -67,9 +78,9 @@ class Llama3:
             initialize_model_parallel(world_size)

         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        if device.type == "cuda":
+        if device_obj.type == "cuda":
             torch.cuda.set_device(local_rank)
-        elif device.type == "xpu":
+        elif device_obj.type == "xpu":
             torch.xpu.set_device(local_rank)

         torch.manual_seed(seed)
@@ -102,29 +113,27 @@ class Llama3:
         def build_model():
             if model_args.vision_chunk_size > 0:
                 model = CrossAttentionTransformer(model_args)
-                model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
+                model.setup_cache(model_args.max_batch_size, device=device_obj, dtype=torch.get_default_dtype())
             else:
                 model = Transformer(model_args)
             return model

         if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
-            from .quantization.loader import convert_to_quantized_model
-
             torch.set_default_tensor_type(torch.BFloat16Tensor)
             model = build_model()
             print("Loading state dict...")
             model.load_state_dict(state_dict, strict=False)
             print("Done...")
-            model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
-            torch.set_default_device(device)
+            model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device_obj)
+            torch.set_default_device(device_obj)
         else:
-            print(f"Setting default device to {device}")
-            if device.type == "cuda":
+            print(f"Setting default device to {device_obj}")
+            if device_obj.type == "cuda":
                 if torch.cuda.is_bf16_supported():
                     torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
                 else:
                     torch.set_default_tensor_type(torch.cuda.Float16Tensor)
-            elif device.type == "xpu":
+            elif device_obj.type == "xpu":
                 if torch.xpu.is_bf16_supported():
                     torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
                 else:
@@ -133,7 +142,7 @@ class Llama3:
             model = build_model()
             print("Loading state dict...")
             model.load_state_dict(state_dict, strict=True)
-            model.to(device)
+            model.to(device_obj)
             print("Done...")

         print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -212,6 +221,11 @@ class Llama3:
                 total_len=total_len,
                 device=tokens.device,
             )
+        else:
+            # Define dummy values for non-vision models to satisfy mypy
+            xattn_caches = torch.tensor([])
+            cross_attention_masks = torch.tensor([])
+            full_text_row_masked_out_mask = torch.tensor([])

         eos_reached = torch.tensor([False] * bsz)
         input_text_mask = tokens != pad_id
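Note: the added `else:` branch binds `xattn_caches`, `cross_attention_masks`, and `full_text_row_masked_out_mask` on the non-vision path as well, so every code path defines the names before they are referenced later in the method. The idiom in isolation:

    import sys

    import torch

    is_vision = "--vision" in sys.argv

    if is_vision:
        masks = torch.ones(1, 4)
    else:
        masks = torch.tensor([])  # dummy value so every code path binds the name

    print(masks.shape)  # usable regardless of which branch ran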
@@ -240,16 +254,33 @@ class Llama3:
             if is_vision:
                 position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
                 text_only_inference = all(inp.vision is None for inp in llm_inputs)
-                logits = self.model.forward(
-                    position_ids,
-                    tokens,
-                    cross_attention_masks,
-                    full_text_row_masked_out_mask,
-                    xattn_caches,
-                    text_only_inference,
-                )
+                # Type narrowing for mypy
+                if isinstance(self.model, CrossAttentionTransformer):
+                    logits = self.model.forward(
+                        position_ids=position_ids,
+                        tokens=tokens,
+                        cross_attention_masks=cross_attention_masks,
+                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                        xattn_caches=xattn_caches,
+                        text_only_inference=text_only_inference,
+                    )
+                else:
+                    # This should not happen when is_vision=True, but for type safety
+                    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
             else:
-                logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+                # Type narrowing for mypy
+                if isinstance(self.model, Transformer):
+                    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+                else:
+                    # This should not happen when is_vision=False, but for type safety
+                    logits = self.model.forward(
+                        position_ids=torch.arange(prev_pos, cur_pos, dtype=torch.long),
+                        tokens=tokens,
+                        cross_attention_masks=cross_attention_masks,
+                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                        xattn_caches=xattn_caches,
+                        text_only_inference=False,
+                    )

             if logits_processor is not None:
                 logits = logits_processor(tokens[:, :cur_pos], logits)
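Note: `self.model` holds either a `Transformer` or a `CrossAttentionTransformer`, whose `forward` signatures differ. The `isinstance` checks narrow the union so each call is validated against the matching signature, and switching to keyword arguments makes that check stricter. A self-contained sketch of the narrowing, with hypothetical stand-in classes:

    class TextModel:
        def forward(self, tokens: list[int], start_pos: int) -> list[float]:
            return [0.0]


    class VisionModel:
        def forward(self, tokens: list[int], masks: list[int]) -> list[float]:
            return [0.0]


    def run(model: TextModel | VisionModel) -> list[float]:
        if isinstance(model, VisionModel):
            # Narrowed to VisionModel: mypy checks the vision signature here.
            return model.forward(tokens=[1, 2], masks=[1])
        # Narrowed to TextModel on this path.
        return model.forward(tokens=[1, 2], start_pos=0)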
@@ -316,7 +347,7 @@ class Llama3:
     ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_content(c) for c in contents]
         for result in self.generate(
-            model_inputs=model_inputs,
+            llm_inputs=model_inputs,
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
@@ -339,7 +370,7 @@ class Llama3:
     ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
         for result in self.generate(
-            model_inputs=model_inputs,
+            llm_inputs=model_inputs,
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
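Note: both call sites switch the keyword from `model_inputs=` to `llm_inputs=`, matching the parameter name that `generate` evidently expects; once the file is type-checked, mypy catches this class of mismatch at the call site. In miniature, with a hypothetical signature:

    def generate(llm_inputs: list[str], temperature: float = 0.6) -> list[str]:
        return llm_inputs


    # generate(model_inputs=["hi"])  # mypy: unexpected keyword argument "model_inputs"
    generate(llm_inputs=["hi"])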
pyproject.toml

@@ -248,7 +248,6 @@ exclude = [
     "^llama_stack/providers/inline/datasetio/localfs/",
     "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
-    "^llama_stack/models/llama/llama3/generation\\.py$",
     "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
     "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
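Note: removing the `generation\.py` pattern from the mypy exclude list is what actually brings the file under type checking; the `# type: ignore` comments, the TYPE_CHECKING shim, and the isinstance narrowing above exist so that the now-enabled check passes. Locally, something like `mypy llama_stack/models/llama/llama3/generation.py` (invoked however this repo normally runs mypy, e.g. via pre-commit) should now cover the file.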