Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)
make inference server load checkpoints for fp8 inference
- introduce quantization-related args for inference config
- also kill GeneratorArgs
parent: 7d2c0b14b8
commit: ad62e2e1f3
10 changed files with 249 additions and 155 deletions
@@ -7,14 +7,7 @@ from hydra.core.config_store import ConfigStore
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
-
-@dataclass
-class GeneratorArgs:
-    ckpt_dir: str
-    tokenizer_path: str
-    model_parallel_size: Optional[int] = None
-    max_seq_len: int = 2048
-    max_batch_size: int = 4
+from .datatypes import QuantizationConfig
 
 
 class ImplType(Enum):
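The removed GeneratorArgs fields move into the Pydantic config models below, and the new import pulls QuantizationConfig from the sibling datatypes module. Its definition is not part of this diff; a minimal sketch of the shape the code below appears to assume (a tagged union of per-scheme configs, with fp8 as one variant; all names here are hypothetical reconstructions) might look like:

from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated

# Hypothetical reconstruction of .datatypes; not shown in this diff.
class QuantizationType(Enum):
    bf16 = "bf16"
    fp8 = "fp8"

class Bf16QuantizationConfig(BaseModel):
    type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value

class Fp8QuantizationConfig(BaseModel):
    type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value

# Discriminated union, mirroring how the checkpoint types are unioned below.
QuantizationConfig = Annotated[
    Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
    Field(discriminator="type"),
]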
@@ -27,6 +20,17 @@ class CheckpointType(Enum):
     huggingface = "huggingface"
 
 
+# This enum represents the format in which weights are specified
+# This does not necessarily always equal what quantization is desired
+# at runtime since there can be on-the-fly conversions done
+class CheckpointQuantizationFormat(Enum):
+    # default format
+    bf16 = "bf16"
+
+    # used for enabling fp8_rowwise inference, some weights are bf16
+    fp8_mixed = "fp8_mixed"
+
+
 class PytorchCheckpoint(BaseModel):
     checkpoint_type: Literal[CheckpointType.pytorch.value] = (
         CheckpointType.pytorch.value
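As the comments say, the enum deliberately separates how weights are stored on disk from how the model is run, so a bf16 checkpoint can still back fp8 inference via on-the-fly conversion. A small self-contained illustration of the enum itself (values copied verbatim from the hunk above):

from enum import Enum

class CheckpointQuantizationFormat(Enum):
    bf16 = "bf16"
    fp8_mixed = "fp8_mixed"

# Round-trip by value, as happens when the format arrives from a YAML/JSON config:
fmt = CheckpointQuantizationFormat("fp8_mixed")
assert fmt is CheckpointQuantizationFormat.fp8_mixed
assert fmt.value == "fp8_mixed"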
@@ -34,6 +38,9 @@ class PytorchCheckpoint(BaseModel):
     checkpoint_dir: str
     tokenizer_path: str
     model_parallel_size: int
+    quantization_format: CheckpointQuantizationFormat = (
+        CheckpointQuantizationFormat.bf16
+    )
 
 
 class HuggingFaceCheckpoint(BaseModel):
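With the new field, describing an fp8-formatted checkpoint is a one-argument change. A hedged sketch, reproducing only the fields visible in the hunk (the "pytorch" enum value and the paths are assumptions):

from enum import Enum
from typing import Literal

from pydantic import BaseModel

class CheckpointType(Enum):
    pytorch = "pytorch"        # assumed value; only "huggingface" is visible above
    huggingface = "huggingface"

class CheckpointQuantizationFormat(Enum):
    bf16 = "bf16"
    fp8_mixed = "fp8_mixed"

class PytorchCheckpoint(BaseModel):
    checkpoint_type: Literal[CheckpointType.pytorch.value] = (
        CheckpointType.pytorch.value
    )
    checkpoint_dir: str
    tokenizer_path: str
    model_parallel_size: int
    quantization_format: CheckpointQuantizationFormat = (
        CheckpointQuantizationFormat.bf16
    )

ckpt = PytorchCheckpoint(
    checkpoint_dir="/path/to/checkpoint",       # hypothetical path
    tokenizer_path="/path/to/tokenizer.model",  # hypothetical path
    model_parallel_size=1,
    quantization_format=CheckpointQuantizationFormat.fp8_mixed,
)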
@@ -42,6 +49,9 @@ class HuggingFaceCheckpoint(BaseModel):
     )
     repo_id: str  # or model_name ?
     model_parallel_size: int
+    quantization_format: CheckpointQuantizationFormat = (
+        CheckpointQuantizationFormat.bf16
+    )
 
 
 class ModelCheckpointConfig(BaseModel):
@@ -51,10 +61,11 @@ class ModelCheckpointConfig(BaseModel):
     ]
 
 
 # NOTE: this same config will be used when instantiating an inference server naturally
 class InlineImplConfig(BaseModel):
     impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
     checkpoint_config: ModelCheckpointConfig
+    quantization: Optional[QuantizationConfig] = None
     torch_seed: Optional[int] = None
     max_seq_len: int
     max_batch_size: int = 1
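Taken together, this is the knob the commit adds: checkpoint_config says what format the weights are stored in, while quantization says how to run them. A sketch of a full fp8 server config, assuming the models defined in the hunks and sketches above are in scope (Fp8QuantizationConfig is the hypothetical variant from the first sketch, and the checkpoint field name inside ModelCheckpointConfig is not visible in this diff):

# Sketch only: assumes the classes from the diff and previous snippets are in scope.
config = InlineImplConfig(
    checkpoint_config=ModelCheckpointConfig(
        checkpoint=PytorchCheckpoint(           # "checkpoint" field name is assumed
            checkpoint_dir="/path/to/fp8/checkpoint",
            tokenizer_path="/path/to/tokenizer.model",
            model_parallel_size=1,
            quantization_format=CheckpointQuantizationFormat.fp8_mixed,
        )
    ),
    quantization=Fp8QuantizationConfig(),       # hypothetical variant, see above
    max_seq_len=2048,
    max_batch_size=1,
)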
@@ -86,6 +97,7 @@ class InlineImplHydraConfig:
     model_parallel_size: int
     max_seq_len: int
     max_batch_size: int = 1
+    quantization: Optional[QuantizationConfig] = None
     # TODO: huggingface checkpoint required args
 
     def convert_to_inline_impl_config(self):
@@ -99,6 +111,7 @@ class InlineImplHydraConfig:
                 model_parallel_size=self.model_parallel_size,
             )
         ),
+        quantization=self.quantization,
         max_seq_len=self.max_seq_len,
         max_batch_size=self.max_batch_size,
     )
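Finally, the Hydra-side config simply forwards its quantization field through convert_to_inline_impl_config, so nothing extra is needed when going from a Hydra YAML to the runtime config. A hedged usage sketch (field names beyond those visible in the hunks are assumptions):

# Sketch: construct the Hydra-facing config, then convert it.
hydra_cfg = InlineImplHydraConfig(
    checkpoint_dir="/path/to/fp8/checkpoint",   # assumed field name
    tokenizer_path="/path/to/tokenizer.model",  # assumed field name
    model_parallel_size=1,
    max_seq_len=2048,
    max_batch_size=1,
    quantization=Fp8QuantizationConfig(),       # hypothetical, see first sketch
)
inline_cfg = hydra_cfg.convert_to_inline_impl_config()
assert inline_cfg.quantization == hydra_cfg.quantization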