Mirror of https://github.com/meta-llama/llama-stack.git
chore(api): add mypy coverage to meta_llama3_generation
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
commit 6a11fdb57b (parent 81ebaf6e9a)

2 changed files with 64 additions and 34 deletions
llama_stack/models/llama/llama3/generation.py

@@ -17,14 +17,15 @@ import sys
 import time
 from collections.abc import Callable, Generator
 from pathlib import Path
+from typing import TYPE_CHECKING
 
-import torch
-import torch.nn.functional as F
-from fairscale.nn.model_parallel.initialize import (
+import torch  # type: ignore
+import torch.nn.functional as F  # type: ignore
+from fairscale.nn.model_parallel.initialize import (  # type: ignore
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from termcolor import cprint
+from termcolor import cprint  # type: ignore
 
 from ..checkpoint import maybe_reshard_state_dict
 from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
@@ -34,6 +35,16 @@ from .model import Transformer
 from .multimodal.model import CrossAttentionTransformer
 from .tokenizer import Tokenizer
 
+if TYPE_CHECKING:
+    from .quantization.loader import convert_to_quantized_model
+else:
+    # Import at runtime to avoid circular dependencies
+    def _get_convert_to_quantized_model():
+        from .quantization.loader import convert_to_quantized_model
+        return convert_to_quantized_model
+
+    convert_to_quantized_model = _get_convert_to_quantized_model()
+
 
 class Llama3:
     @staticmethod
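For reference, a minimal self-contained sketch of the TYPE_CHECKING / lazy-import pattern the hunk above introduces. The imported name here (json.dumps) is only a stand-in for convert_to_quantized_model, chosen so the snippet runs on its own:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static checkers such as mypy evaluate this branch, so they learn the
    # symbol's type without the import ever running.
    from json import dumps
else:
    # At runtime the import is deferred into a function call, which is how the
    # hunk above avoids a circular import on .quantization.loader.
    def _load_dumps():
        from json import dumps

        return dumps

    dumps = _load_dumps()

print(dumps({"pattern": "TYPE_CHECKING + lazy runtime import"}))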
@@ -44,19 +55,19 @@ class Llama3:
         world_size: int | None = None,
         quantization_mode: QuantizationMode | None = None,
         seed: int = 1,
-        device: str = "cuda",
+        device: str | torch.device = "cuda",
     ):
-        device = torch.device(device)
+        device_obj = torch.device(device)
         if (
-            device.type == "cuda"
+            device_obj.type == "cuda"
             and not torch.cuda.is_available()
-            or device.type == "xpu"
+            or device_obj.type == "xpu"
             and not torch.xpu.is_available()
         ):
-            raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
+            raise RuntimeError(f"PyTorch backend for {device_obj.type} device type is not available")
 
         if not torch.distributed.is_initialized():
-            if device.type == "cuda":
+            if device_obj.type == "cuda":
                 torch.distributed.init_process_group("nccl")
             else:
                 torch.distributed.init_process_group("gloo")
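A brief sketch of why this hunk widens the parameter to str | torch.device and binds a new name (device_obj) instead of rebinding device: reassigning a parameter to a value of a different type is what mypy objects to, while torch.device() happily accepts either form. The function name pick_backend is illustrative, not from the repo:

import torch


def pick_backend(device: str | torch.device = "cpu") -> str:
    # Normalize once: torch.device() accepts either a string or a device object.
    device_obj = torch.device(device)
    if device_obj.type == "cuda" and not torch.cuda.is_available():
        raise RuntimeError(f"PyTorch backend for {device_obj.type} device type is not available")
    return device_obj.type


print(pick_backend("cpu"))                # string input
print(pick_backend(torch.device("cpu")))  # device-object input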
@@ -67,9 +78,9 @@ class Llama3:
             initialize_model_parallel(world_size)
 
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        if device.type == "cuda":
+        if device_obj.type == "cuda":
             torch.cuda.set_device(local_rank)
-        elif device.type == "xpu":
+        elif device_obj.type == "xpu":
             torch.xpu.set_device(local_rank)
 
         torch.manual_seed(seed)
@@ -102,29 +113,27 @@ class Llama3:
         def build_model():
             if model_args.vision_chunk_size > 0:
                 model = CrossAttentionTransformer(model_args)
-                model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
+                model.setup_cache(model_args.max_batch_size, device=device_obj, dtype=torch.get_default_dtype())
             else:
                 model = Transformer(model_args)
             return model
 
         if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
-            from .quantization.loader import convert_to_quantized_model
-
             torch.set_default_tensor_type(torch.BFloat16Tensor)
             model = build_model()
             print("Loading state dict...")
             model.load_state_dict(state_dict, strict=False)
             print("Done...")
-            model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
-            torch.set_default_device(device)
+            model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device_obj)
+            torch.set_default_device(device_obj)
         else:
-            print(f"Setting default device to {device}")
-            if device.type == "cuda":
+            print(f"Setting default device to {device_obj}")
+            if device_obj.type == "cuda":
                 if torch.cuda.is_bf16_supported():
                     torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
                 else:
                     torch.set_default_tensor_type(torch.cuda.Float16Tensor)
-            elif device.type == "xpu":
+            elif device_obj.type == "xpu":
                 if torch.xpu.is_bf16_supported():
                     torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
                 else:
@@ -133,7 +142,7 @@ class Llama3:
             model = build_model()
             print("Loading state dict...")
             model.load_state_dict(state_dict, strict=True)
-            model.to(device)
+            model.to(device_obj)
             print("Done...")
 
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -212,6 +221,11 @@ class Llama3:
                 total_len=total_len,
                 device=tokens.device,
             )
+        else:
+            # Define dummy values for non-vision models to satisfy mypy
+            xattn_caches = torch.tensor([])
+            cross_attention_masks = torch.tensor([])
+            full_text_row_masked_out_mask = torch.tensor([])
 
         eos_reached = torch.tensor([False] * bsz)
         input_text_mask = tokens != pad_id
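A small sketch of the definite-assignment issue the new else branch addresses: if the names were bound only on the vision path, mypy would flag their later use as possibly undefined, so the non-vision path assigns harmless placeholders. The function and variable names below are illustrative:

import torch


def prepare_caches(is_vision: bool) -> tuple[torch.Tensor, torch.Tensor]:
    if is_vision:
        xattn_caches = torch.ones(2, 2)
        cross_attention_masks = torch.ones(2, 2)
    else:
        # Placeholders keep the names bound on every path, so the return below
        # type-checks even though the non-vision path never reads them.
        xattn_caches = torch.tensor([])
        cross_attention_masks = torch.tensor([])
    return xattn_caches, cross_attention_masks


print(prepare_caches(False))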
@@ -240,16 +254,33 @@ class Llama3:
             if is_vision:
                 position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
                 text_only_inference = all(inp.vision is None for inp in llm_inputs)
-                logits = self.model.forward(
-                    position_ids,
-                    tokens,
-                    cross_attention_masks,
-                    full_text_row_masked_out_mask,
-                    xattn_caches,
-                    text_only_inference,
-                )
-            else:
-                logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+                # Type narrowing for mypy
+                if isinstance(self.model, CrossAttentionTransformer):
+                    logits = self.model.forward(
+                        position_ids=position_ids,
+                        tokens=tokens,
+                        cross_attention_masks=cross_attention_masks,
+                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                        xattn_caches=xattn_caches,
+                        text_only_inference=text_only_inference,
+                    )
+                else:
+                    # This should not happen when is_vision=True, but for type safety
+                    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+            else:
+                # Type narrowing for mypy
+                if isinstance(self.model, Transformer):
+                    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+                else:
+                    # This should not happen when is_vision=False, but for type safety
+                    logits = self.model.forward(
+                        position_ids=torch.arange(prev_pos, cur_pos, dtype=torch.long),
+                        tokens=tokens,
+                        cross_attention_masks=cross_attention_masks,
+                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                        xattn_caches=xattn_caches,
+                        text_only_inference=False,
+                    )
 
             if logits_processor is not None:
                 logits = logits_processor(tokens[:, :cur_pos], logits)
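A compact sketch of the isinstance() narrowing the hunk above relies on: when an attribute is typed as a union of two model classes, mypy only accepts each class-specific call signature inside the matching isinstance branch. The toy classes below are stand-ins for Transformer and CrossAttentionTransformer, not the real models:

import torch
from torch import nn


class TextModel(nn.Module):
    def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
        return tokens[:, start_pos:].float()


class VisionModel(nn.Module):
    def forward(self, tokens: torch.Tensor, text_only_inference: bool = False) -> torch.Tensor:
        return tokens.float()


class Runner:
    def __init__(self, model: TextModel | VisionModel) -> None:
        self.model = model

    def step(self, tokens: torch.Tensor) -> torch.Tensor:
        if isinstance(self.model, VisionModel):
            # Inside this branch mypy narrows self.model to VisionModel, so the
            # call with text_only_inference type-checks.
            return self.model.forward(tokens, text_only_inference=True)
        # Otherwise the type is narrowed to TextModel.
        return self.model.forward(tokens, start_pos=0)


print(Runner(TextModel()).step(torch.zeros(1, 3, dtype=torch.long)))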
@@ -316,7 +347,7 @@ class Llama3:
     ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_content(c) for c in contents]
         for result in self.generate(
-            model_inputs=model_inputs,
+            llm_inputs=model_inputs,
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
@@ -339,7 +370,7 @@ class Llama3:
     ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
         for result in self.generate(
-            model_inputs=model_inputs,
+            llm_inputs=model_inputs,
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
pyproject.toml

@@ -248,7 +248,6 @@ exclude = [
     "^llama_stack/providers/inline/datasetio/localfs/",
     "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
-    "^llama_stack/models/llama/llama3/generation\\.py$",
     "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
     "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",