From 76004eacb493a7e0ddf1b230861ac45c8a214a5b Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 7 Apr 2025 12:57:58 -0700
Subject: [PATCH] rename quant types to use _mixed naming

---
 llama_stack/apis/inference/inference.py           | 12 ++++++------
 .../models/llama/llama4/quantization/loader.py    |  2 +-
 .../inline/inference/meta_reference/generators.py |  8 ++++----
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 216935ede..e59132e33 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -97,18 +97,18 @@ class QuantizationType(Enum):
     """Type of model quantization to run inference with.
 
     :cvar bf16: BFloat16 typically this means _no_ quantization
-    :cvar fp8: 8-bit floating point quantization
-    :cvar int4: 4-bit integer quantization
+    :cvar fp8_mixed: 8-bit floating point quantization with mixed precision
+    :cvar int4_mixed: 4-bit integer quantization with mixed precision
     """
 
     bf16 = "bf16"
-    fp8 = "fp8"
-    int4 = "int4"
+    fp8_mixed = "fp8_mixed"
+    int4_mixed = "int4_mixed"
 
 
 @json_schema_type
 class Fp8QuantizationConfig(BaseModel):
-    type: Literal["fp8"] = "fp8"
+    type: Literal["fp8_mixed"] = "fp8_mixed"
 
 
 @json_schema_type
@@ -124,7 +124,7 @@ class Int4QuantizationConfig(BaseModel):
     :param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation"
     """
 
-    type: Literal["int4"] = "int4"
+    type: Literal["int4_mixed"] = "int4_mixed"
     scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
 
 
diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py
index f11d83c60..b50432896 100644
--- a/llama_stack/models/llama/llama4/quantization/loader.py
+++ b/llama_stack/models/llama/llama4/quantization/loader.py
@@ -91,7 +91,7 @@ def convert_to_quantized_model(
         log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
 
         def apply_quantization(_, weight):
-            return quantize_int4(weight, output_device=torch.device("cuda"))
+            return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
 
     else:
         fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
diff --git a/llama_stack/providers/inline/inference/meta_reference/generators.py b/llama_stack/providers/inline/inference/meta_reference/generators.py
index c2baed905..65bed4d8c 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generators.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generators.py
@@ -133,9 +133,9 @@ class Llama4Generator:
         ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
 
         if config.quantization:
-            if config.quantization.type == "fp8":
+            if config.quantization.type == "fp8_mixed":
                 quantization_mode = QuantizationMode.fp8_mixed
-            elif config.quantization.type == "int4":
+            elif config.quantization.type == "int4_mixed":
                 quantization_mode = QuantizationMode.int4_mixed
             elif config.quantization.type == "bf16":
                 quantization_mode = None
@@ -226,9 +226,9 @@ class Llama3Generator:
         ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
 
         if config.quantization:
-            if config.quantization.type == "fp8":
+            if config.quantization.type == "fp8_mixed":
                 quantization_mode = QuantizationMode.fp8_mixed
-            elif config.quantization.type == "int4":
+            elif config.quantization.type == "int4_mixed":
                 quantization_mode = QuantizationMode.int4_mixed
             elif config.quantization.type == "bf16":
                 quantization_mode = None
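
A minimal usage sketch of the renamed "_mixed" quantization types (not part of the
patch; it assumes QuantizationType, Fp8QuantizationConfig, and Int4QuantizationConfig
remain importable from llama_stack.apis.inference.inference as shown in the diff):

    from llama_stack.apis.inference.inference import (
        Fp8QuantizationConfig,
        Int4QuantizationConfig,
        QuantizationType,
    )

    # The Literal defaults now carry the _mixed suffix.
    fp8_cfg = Fp8QuantizationConfig()
    assert fp8_cfg.type == "fp8_mixed"

    int4_cfg = Int4QuantizationConfig()
    assert int4_cfg.type == "int4_mixed"
    assert int4_cfg.scheme == "int4_weight_int8_dynamic_activation"

    # Enum values mirror the config discriminators, so the string comparisons
    # updated in generators.py select the matching QuantizationMode.
    assert QuantizationType.fp8_mixed.value == fp8_cfg.type
    assert QuantizationType.int4_mixed.value == int4_cfg.type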