make inference server load checkpoints for fp8 inference

- introduce quantization-related args for the inference config
- also kill GeneratorArgs
Ashwin Bharambe, 2024-07-20 21:10:17 -07:00
parent 7d2c0b14b8
commit ad62e2e1f3
10 changed files with 249 additions and 155 deletions
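
For context on the GeneratorArgs removal: its fields have equivalents on the Pydantic config models shown below, which is presumably why the dataclass can be dropped. A rough mapping, inferred from the hunks in this diff rather than stated in the commit message:

# Inferred field mapping (not stated explicitly in the commit):
#   GeneratorArgs.ckpt_dir            -> PytorchCheckpoint.checkpoint_dir
#   GeneratorArgs.tokenizer_path      -> PytorchCheckpoint.tokenizer_path
#   GeneratorArgs.model_parallel_size -> PytorchCheckpoint.model_parallel_size
#   GeneratorArgs.max_seq_len         -> InlineImplConfig.max_seq_len
#   GeneratorArgs.max_batch_size      -> InlineImplConfig.max_batch_size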


@@ -7,14 +7,7 @@ from hydra.core.config_store import ConfigStore
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
-@dataclass
-class GeneratorArgs:
-    ckpt_dir: str
-    tokenizer_path: str
-    model_parallel_size: Optional[int] = None
-    max_seq_len: int = 2048
-    max_batch_size: int = 4
+from .datatypes import QuantizationConfig
 
 
 class ImplType(Enum):
@@ -27,6 +20,17 @@ class CheckpointType(Enum):
     huggingface = "huggingface"
 
 
+# This enum represents the format in which weights are specified
+# This does not necessarily always equal what quantization is desired
+# at runtime since there can be on-the-fly conversions done
+class CheckpointQuantizationFormat(Enum):
+    # default format
+    bf16 = "bf16"
+
+    # used for enabling fp8_rowwise inference, some weights are bf16
+    fp8_mixed = "fp8_mixed"
+
+
 class PytorchCheckpoint(BaseModel):
     checkpoint_type: Literal[CheckpointType.pytorch.value] = (
         CheckpointType.pytorch.value
@@ -34,6 +38,9 @@ class PytorchCheckpoint(BaseModel):
     checkpoint_dir: str
     tokenizer_path: str
     model_parallel_size: int
+    quantization_format: CheckpointQuantizationFormat = (
+        CheckpointQuantizationFormat.bf16
+    )
 
 
 class HuggingFaceCheckpoint(BaseModel):
@@ -42,6 +49,9 @@ class HuggingFaceCheckpoint(BaseModel):
     )
     repo_id: str  # or model_name ?
     model_parallel_size: int
+    quantization_format: CheckpointQuantizationFormat = (
+        CheckpointQuantizationFormat.bf16
+    )
 
 
 class ModelCheckpointConfig(BaseModel):
@@ -51,10 +61,11 @@ class ModelCheckpointConfig(BaseModel):
     ]
 
 
 # NOTE: this same config will be used when instantiating an inference server naturally
 class InlineImplConfig(BaseModel):
     impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
     checkpoint_config: ModelCheckpointConfig
+    quantization: Optional[QuantizationConfig] = None
     torch_seed: Optional[int] = None
     max_seq_len: int
     max_batch_size: int = 1
@@ -86,6 +97,7 @@ class InlineImplHydraConfig:
     model_parallel_size: int
     max_seq_len: int
     max_batch_size: int = 1
+    quantization: Optional[QuantizationConfig] = None
     # TODO: huggingface checkpoint required args
 
     def convert_to_inline_impl_config(self):
@@ -99,6 +111,7 @@ class InlineImplHydraConfig:
                     model_parallel_size=self.model_parallel_size,
                 )
             ),
+            quantization=self.quantization,
             max_seq_len=self.max_seq_len,
             max_batch_size=self.max_batch_size,
         )
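
With these changes in place, an fp8 run is expressed purely through config rather than GeneratorArgs. A minimal sketch of such a config object, using the class and field names from the hunks above; the import paths and the `checkpoint` field name on ModelCheckpointConfig are assumptions, since the diff does not show them:

# Hedged sketch: an fp8 inference config built from the new fields.
from .config import (  # assumed module path
    CheckpointQuantizationFormat,
    InlineImplConfig,
    ModelCheckpointConfig,
    PytorchCheckpoint,
)
from .datatypes import Fp8QuantizationConfig  # assumed module path

fp8_config = InlineImplConfig(
    checkpoint_config=ModelCheckpointConfig(
        checkpoint=PytorchCheckpoint(  # `checkpoint` field name is an assumption
            checkpoint_dir="/path/to/ckpt",
            tokenizer_path="/path/to/tokenizer.model",
            model_parallel_size=8,
            # weights on disk are stored fp8_mixed for fp8_rowwise inference
            quantization_format=CheckpointQuantizationFormat.fp8_mixed,
        )
    ),
    # quantization desired at runtime; per the comment above, this need not
    # match the checkpoint format since on-the-fly conversion is possible
    quantization=Fp8QuantizationConfig(),
    max_seq_len=2048,
    max_batch_size=1,
)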


@@ -21,19 +21,19 @@ class QuantizationType(Enum):
 
 @json_schema_type
 class Fp8QuantizationConfig(BaseModel):
-    quantization_type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
+    type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
 
 
 @json_schema_type
 class Bf16QuantizationConfig(BaseModel):
-    quantization_type: Literal[QuantizationType.bf16.value] = (
+    type: Literal[QuantizationType.bf16.value] = (
         QuantizationType.bf16.value
     )
 
 
 QuantizationConfig = Annotated[
     Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
-    Field(discriminator="quantization_type"),
+    Field(discriminator="type"),
 ]
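
The rename from quantization_type to type also changes how the discriminated union is resolved: payloads now select a variant via a "type" key. A small sketch of that behavior, assuming Pydantic v2's TypeAdapter (on v1, parse_obj_as would be the rough analogue) and an assumed module path for the imports:

# Hedged sketch: resolving the discriminated union after the rename.
from pydantic import TypeAdapter

from .datatypes import (  # assumed module path
    Bf16QuantizationConfig,
    Fp8QuantizationConfig,
    QuantizationConfig,
    QuantizationType,
)

adapter = TypeAdapter(QuantizationConfig)

cfg = adapter.validate_python({"type": QuantizationType.fp8.value})
assert isinstance(cfg, Fp8QuantizationConfig)

cfg = adapter.validate_python({"type": QuantizationType.bf16.value})
assert isinstance(cfg, Bf16QuantizationConfig)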