chore(api): add mypy coverage to meta_llama3_generation

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-08 23:27:10 +02:00
parent 81ebaf6e9a
commit 6a11fdb57b
2 changed files with 64 additions and 34 deletions

llama_stack/models/llama/llama3/generation.py

@@ -17,14 +17,15 @@ import sys
 import time
 from collections.abc import Callable, Generator
 from pathlib import Path
+from typing import TYPE_CHECKING

-import torch
-import torch.nn.functional as F
-from fairscale.nn.model_parallel.initialize import (
+import torch  # type: ignore
+import torch.nn.functional as F  # type: ignore
+from fairscale.nn.model_parallel.initialize import (  # type: ignore
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from termcolor import cprint
+from termcolor import cprint  # type: ignore

 from ..checkpoint import maybe_reshard_state_dict
 from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
@@ -34,6 +35,16 @@ from .model import Transformer
 from .multimodal.model import CrossAttentionTransformer
 from .tokenizer import Tokenizer

+if TYPE_CHECKING:
+    from .quantization.loader import convert_to_quantized_model
+else:
+    # Import at runtime to avoid circular dependencies
+    def _get_convert_to_quantized_model():
+        from .quantization.loader import convert_to_quantized_model
+
+        return convert_to_quantized_model
+
+    convert_to_quantized_model = _get_convert_to_quantized_model()

 class Llama3:
     @staticmethod
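
For illustration, a minimal standalone sketch of the TYPE_CHECKING plus deferred-import pattern used in the hunk above. It is not taken from the commit; json.dumps stands in for convert_to_quantized_model so the snippet runs on its own:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers such as mypy; never executed at
    # runtime, so it cannot participate in an import cycle.
    from json import dumps
else:
    def _get_dumps():
        # Deferred runtime import: resolved only when this helper is called.
        from json import dumps

        return dumps

    dumps = _get_dumps()

print(dumps({"pattern": "TYPE_CHECKING + deferred import"}))
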
@@ -44,19 +55,19 @@ class Llama3:
         world_size: int | None = None,
         quantization_mode: QuantizationMode | None = None,
         seed: int = 1,
-        device: str = "cuda",
+        device: str | torch.device = "cuda",
     ):
-        device = torch.device(device)
+        device_obj = torch.device(device)
         if (
-            device.type == "cuda"
+            device_obj.type == "cuda"
             and not torch.cuda.is_available()
-            or device.type == "xpu"
+            or device_obj.type == "xpu"
             and not torch.xpu.is_available()
         ):
-            raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
+            raise RuntimeError(f"PyTorch backend for {device_obj.type} device type is not available")

         if not torch.distributed.is_initialized():
-            if device.type == "cuda":
+            if device_obj.type == "cuda":
                 torch.distributed.init_process_group("nccl")
             else:
                 torch.distributed.init_process_group("gloo")
@@ -67,9 +78,9 @@ class Llama3:
             initialize_model_parallel(world_size)

         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        if device.type == "cuda":
+        if device_obj.type == "cuda":
             torch.cuda.set_device(local_rank)
-        elif device.type == "xpu":
+        elif device_obj.type == "xpu":
             torch.xpu.set_device(local_rank)

         torch.manual_seed(seed)
@@ -102,29 +113,27 @@ class Llama3:
         def build_model():
             if model_args.vision_chunk_size > 0:
                 model = CrossAttentionTransformer(model_args)
-                model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
+                model.setup_cache(model_args.max_batch_size, device=device_obj, dtype=torch.get_default_dtype())
             else:
                 model = Transformer(model_args)
             return model

         if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
-            from .quantization.loader import convert_to_quantized_model
-
             torch.set_default_tensor_type(torch.BFloat16Tensor)
             model = build_model()
             print("Loading state dict...")
             model.load_state_dict(state_dict, strict=False)
             print("Done...")
-            model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
-            torch.set_default_device(device)
+            model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device_obj)
+            torch.set_default_device(device_obj)
         else:
-            print(f"Setting default device to {device}")
-            if device.type == "cuda":
+            print(f"Setting default device to {device_obj}")
+            if device_obj.type == "cuda":
                 if torch.cuda.is_bf16_supported():
                     torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
                 else:
                     torch.set_default_tensor_type(torch.cuda.Float16Tensor)
-            elif device.type == "xpu":
+            elif device_obj.type == "xpu":
                 if torch.xpu.is_bf16_supported():
                     torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
                 else:
@@ -133,7 +142,7 @@ class Llama3:
             model = build_model()
             print("Loading state dict...")
             model.load_state_dict(state_dict, strict=True)
-            model.to(device)
+            model.to(device_obj)
             print("Done...")

         print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -212,6 +221,11 @@ class Llama3:
                 total_len=total_len,
                 device=tokens.device,
             )
+        else:
+            # Define dummy values for non-vision models to satisfy mypy
+            xattn_caches = torch.tensor([])
+            cross_attention_masks = torch.tensor([])
+            full_text_row_masked_out_mask = torch.tensor([])

         eos_reached = torch.tensor([False] * bsz)
         input_text_mask = tokens != pad_id
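
The added else branch binds the cross-attention variables on every code path so the later forward calls can pass them unconditionally; the commit's comment says the dummy values are there to satisfy mypy. A trimmed sketch of that pattern with an illustrative flag:

import torch

is_vision = False  # illustrative flag

if is_vision:
    xattn_caches = torch.zeros(1, 4)
else:
    # Placeholder tensor so the name is bound on every path; without this
    # branch the use below would raise NameError when is_vision is False.
    xattn_caches = torch.tensor([])

print(xattn_caches.shape)
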
@@ -240,16 +254,33 @@
             if is_vision:
                 position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
                 text_only_inference = all(inp.vision is None for inp in llm_inputs)
-                logits = self.model.forward(
-                    position_ids,
-                    tokens,
-                    cross_attention_masks,
-                    full_text_row_masked_out_mask,
-                    xattn_caches,
-                    text_only_inference,
-                )
+                # Type narrowing for mypy
+                if isinstance(self.model, CrossAttentionTransformer):
+                    logits = self.model.forward(
+                        position_ids=position_ids,
+                        tokens=tokens,
+                        cross_attention_masks=cross_attention_masks,
+                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                        xattn_caches=xattn_caches,
+                        text_only_inference=text_only_inference,
+                    )
+                else:
+                    # This should not happen when is_vision=True, but for type safety
+                    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
             else:
-                logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+                # Type narrowing for mypy
+                if isinstance(self.model, Transformer):
+                    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+                else:
+                    # This should not happen when is_vision=False, but for type safety
+                    logits = self.model.forward(
+                        position_ids=torch.arange(prev_pos, cur_pos, dtype=torch.long),
+                        tokens=tokens,
+                        cross_attention_masks=cross_attention_masks,
+                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                        xattn_caches=xattn_caches,
+                        text_only_inference=False,
+                    )

             if logits_processor is not None:
                 logits = logits_processor(tokens[:, :cur_pos], logits)
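
A minimal, self-contained sketch of the isinstance narrowing used above, with placeholder classes standing in for Transformer and CrossAttentionTransformer rather than the project's real models:

class TextModel:
    def forward(self, tokens: list[int], prev_pos: int) -> list[float]:
        return [0.0] * len(tokens)


class VisionModel:
    def forward(self, tokens: list[int], prev_pos: int, text_only: bool = False) -> list[float]:
        return [0.0] * len(tokens)


def run_step(model: TextModel | VisionModel, tokens: list[int]) -> list[float]:
    # isinstance() narrows the union type, so a checker such as mypy knows
    # which forward() signature each branch is calling.
    if isinstance(model, VisionModel):
        return model.forward(tokens, 0, text_only=True)
    return model.forward(tokens, 0)


print(run_step(TextModel(), [1, 2, 3]))
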
@@ -316,7 +347,7 @@ class Llama3:
     ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_content(c) for c in contents]
         for result in self.generate(
-            model_inputs=model_inputs,
+            llm_inputs=model_inputs,
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
@@ -339,7 +370,7 @@ class Llama3:
     ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
         for result in self.generate(
-            model_inputs=model_inputs,
+            llm_inputs=model_inputs,
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,

mypy configuration (exclude list)

@@ -248,7 +248,6 @@ exclude = [
     "^llama_stack/providers/inline/datasetio/localfs/",
     "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
-    "^llama_stack/models/llama/llama3/generation\\.py$",
     "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
     "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",