diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py
index fe7be5ea9..4eb14ea6a 100644
--- a/llama_stack/models/llama/llama3/generation.py
+++ b/llama_stack/models/llama/llama3/generation.py
@@ -17,14 +17,15 @@ import sys
 import time
 from collections.abc import Callable, Generator
 from pathlib import Path
+from typing import TYPE_CHECKING
 
-import torch
-import torch.nn.functional as F
-from fairscale.nn.model_parallel.initialize import (
+import torch  # type: ignore
+import torch.nn.functional as F  # type: ignore
+from fairscale.nn.model_parallel.initialize import (  # type: ignore
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from termcolor import cprint
+from termcolor import cprint  # type: ignore
 
 from ..checkpoint import maybe_reshard_state_dict
 from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
@@ -34,6 +35,16 @@ from .model import Transformer
 from .multimodal.model import CrossAttentionTransformer
 from .tokenizer import Tokenizer
 
+if TYPE_CHECKING:
+    from .quantization.loader import convert_to_quantized_model
+else:
+    # Defer the real import until the first call so that importing this
+    # module does not trigger a circular import of .quantization.loader.
+    def convert_to_quantized_model(*args, **kwargs):
+        from .quantization.loader import convert_to_quantized_model as _convert
+
+        return _convert(*args, **kwargs)
+
 
 class Llama3:
     @staticmethod
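The `TYPE_CHECKING` block above gives mypy the real signature of `convert_to_quantized_model` while the runtime branch defers the import until the function is actually called. A minimal, self-contained sketch of the same pattern, using `json.loads` as a hypothetical stand-in for the circularly-imported name:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from json import loads  # static type checkers see the real signature
else:

    def loads(*args, **kwargs):
        # The import runs on first call, not at module load time, so a
        # circular import cannot fire while this module is initializing.
        from json import loads as _impl

        return _impl(*args, **kwargs)


print(loads('{"a": 1}'))  # -> {'a': 1}
```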
@@ -44,19 +55,19 @@ class Llama3:
         world_size: int | None = None,
         quantization_mode: QuantizationMode | None = None,
         seed: int = 1,
-        device: str = "cuda",
+        device: str | torch.device = "cuda",
     ):
-        device = torch.device(device)
+        device_obj = torch.device(device)
         if (
-            device.type == "cuda"
+            device_obj.type == "cuda"
             and not torch.cuda.is_available()
-            or device.type == "xpu"
+            or device_obj.type == "xpu"
             and not torch.xpu.is_available()
         ):
-            raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
+            raise RuntimeError(f"PyTorch backend for {device_obj.type} device type is not available")
 
         if not torch.distributed.is_initialized():
-            if device.type == "cuda":
+            if device_obj.type == "cuda":
                 torch.distributed.init_process_group("nccl")
             else:
                 torch.distributed.init_process_group("gloo")
@@ -67,9 +78,9 @@ class Llama3:
             initialize_model_parallel(world_size)
 
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        if device.type == "cuda":
+        if device_obj.type == "cuda":
             torch.cuda.set_device(local_rank)
-        elif device.type == "xpu":
+        elif device_obj.type == "xpu":
             torch.xpu.set_device(local_rank)
 
         torch.manual_seed(seed)
@@ -102,29 +113,27 @@
         def build_model():
             if model_args.vision_chunk_size > 0:
                 model = CrossAttentionTransformer(model_args)
-                model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
+                model.setup_cache(model_args.max_batch_size, device=device_obj, dtype=torch.get_default_dtype())
             else:
                 model = Transformer(model_args)
             return model
 
-        if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
-            from .quantization.loader import convert_to_quantized_model
-
+        if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
             torch.set_default_tensor_type(torch.BFloat16Tensor)
             model = build_model()
             print("Loading state dict...")
             model.load_state_dict(state_dict, strict=False)
             print("Done...")
-            model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
-            torch.set_default_device(device)
+            model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device_obj)
+            torch.set_default_device(device_obj)
         else:
-            print(f"Setting default device to {device}")
-            if device.type == "cuda":
+            print(f"Setting default device to {device_obj}")
+            if device_obj.type == "cuda":
                 if torch.cuda.is_bf16_supported():
                     torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
                 else:
                     torch.set_default_tensor_type(torch.cuda.Float16Tensor)
-            elif device.type == "xpu":
+            elif device_obj.type == "xpu":
                 if torch.xpu.is_bf16_supported():
                     torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
                 else:
@@ -133,7 +142,7 @@ class Llama3:
         model = build_model()
         print("Loading state dict...")
         model.load_state_dict(state_dict, strict=True)
-        model.to(device)
+        model.to(device_obj)
         print("Done...")
 
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -212,6 +221,11 @@ class Llama3:
                 total_len=total_len,
                 device=tokens.device,
             )
+        else:
+            # Define dummy values for non-vision models to satisfy mypy
+            xattn_caches = torch.tensor([])
+            cross_attention_masks = torch.tensor([])
+            full_text_row_masked_out_mask = torch.tensor([])
 
         eos_reached = torch.tensor([False] * bsz)
         input_text_mask = tokens != pad_id
@@ -240,16 +254,33 @@
             if is_vision:
                 position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
                 text_only_inference = all(inp.vision is None for inp in llm_inputs)
-                logits = self.model.forward(
-                    position_ids,
-                    tokens,
-                    cross_attention_masks,
-                    full_text_row_masked_out_mask,
-                    xattn_caches,
-                    text_only_inference,
-                )
+                # Type narrowing for mypy
+                if isinstance(self.model, CrossAttentionTransformer):
+                    logits = self.model.forward(
+                        position_ids=position_ids,
+                        tokens=tokens,
+                        cross_attention_masks=cross_attention_masks,
+                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                        xattn_caches=xattn_caches,
+                        text_only_inference=text_only_inference,
+                    )
+                else:
+                    # This should not happen when is_vision=True, but for type safety
+                    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
             else:
-                logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+                # Type narrowing for mypy
+                if isinstance(self.model, Transformer):
+                    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+                else:
+                    # This should not happen when is_vision=False, but for type safety
+                    logits = self.model.forward(
+                        position_ids=torch.arange(prev_pos, cur_pos, dtype=torch.long),
+                        tokens=tokens,
+                        cross_attention_masks=cross_attention_masks,
+                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                        xattn_caches=xattn_caches,
+                        text_only_inference=False,
+                    )
 
             if logits_processor is not None:
                 logits = logits_processor(tokens[:, :cur_pos], logits)
@@ -316,7 +347,7 @@ class Llama3:
     ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_content(c) for c in contents]
         for result in self.generate(
-            model_inputs=model_inputs,
+            llm_inputs=model_inputs,
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
@@ -339,7 +370,7 @@ class Llama3:
     ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
         for result in self.generate(
-            model_inputs=model_inputs,
+            llm_inputs=model_inputs,
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
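The `isinstance` checks above exist purely so mypy can narrow `self.model` from the union of model types down to the branch-specific class before `forward` is called with that class's signature. A minimal runnable sketch of the same narrowing technique, using toy stand-ins (these classes are hypothetical, not the real `Transformer`/`CrossAttentionTransformer`):

```python
class TextModel:
    def forward(self, tokens: list[int], start_pos: int) -> list[int]:
        return tokens[start_pos:]


class VisionModel:
    def forward(self, *, tokens: list[int], text_only: bool) -> list[int]:
        return tokens if text_only else []


def step(model: TextModel | VisionModel, tokens: list[int]) -> list[int]:
    if isinstance(model, VisionModel):
        # mypy narrows `model` to VisionModel here, so the keyword-only
        # signature type-checks.
        return model.forward(tokens=tokens, text_only=True)
    # ...and narrows it to TextModel here, so the positional call type-checks.
    return model.forward(tokens, 0)


print(step(TextModel(), [1, 2, 3]))    # [1, 2, 3]
print(step(VisionModel(), [1, 2, 3]))  # [1, 2, 3]
```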
"^llama_stack/providers/inline/datasetio/localfs/", "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$", "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$", - "^llama_stack/models/llama/llama3/generation\\.py$", "^llama_stack/models/llama/llama3/multimodal/model\\.py$", "^llama_stack/models/llama/llama4/", "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",