make inference server load checkpoints for fp8 inference

- introduce quantization-related args for the inference config
- also kill GeneratorArgs
Ashwin Bharambe 2024-07-20 21:10:17 -07:00
parent 7d2c0b14b8
commit ad62e2e1f3
10 changed files with 249 additions and 155 deletions

@@ -18,6 +18,12 @@ from fp8.fp8_impls import ffn_swiglu
 from torch import nn
 
 
+@dataclass
+class QuantizationArgs:
+    fp8_rowwise: bool = False
+    convert_from_bf16: bool = False
+
+
 @dataclass
 class ModelArgs:
     dim: int = 4096
@@ -31,6 +37,8 @@ class ModelArgs:
     rope_theta: float = 500000
     use_scaled_rope: bool = False
 
+    quantization: Optional[QuantizationArgs] = None
+
     max_batch_size: int = 32
     max_seq_len: int = 2048
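
To make the intent of the new fields concrete, here is a minimal usage sketch. It assumes the QuantizationArgs and ModelArgs dataclasses shown in the hunk above are importable; the module path and the surrounding loader wiring are assumptions for illustration, not part of this diff.

# Hypothetical usage of the new quantization args (illustration only; the
# import path and any checkpoint-loading code are assumptions, not part of
# this commit).
from model import ModelArgs, QuantizationArgs

args = ModelArgs(
    dim=4096,
    max_batch_size=32,
    max_seq_len=2048,
    quantization=QuantizationArgs(
        fp8_rowwise=True,        # request fp8 rowwise inference kernels
        convert_from_bf16=True,  # quantize bf16 checkpoint weights at load time
    ),
)

# A model built from these args can then check `args.quantization` and pick the
# fp8 feed-forward path (e.g. the ffn_swiglu import above) instead of the bf16 one.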