make inference server load checkpoints for fp8 inference

- introduce quantization-related args for the inference config (a hypothetical sketch follows below)
- also kill GeneratorArgs
Ashwin Bharambe 2024-07-20 21:10:17 -07:00
parent 7d2c0b14b8
commit ad62e2e1f3
10 changed files with 249 additions and 155 deletions
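
For orientation, a minimal sketch of what the new quantization-related config types might look like. Only the names CheckpointQuantizationFormat and QuantizationConfig appear in the diff below; the enum members, the pydantic base, and all fields are assumptions:

    from enum import Enum
    from typing import Optional

    from pydantic import BaseModel


    class CheckpointQuantizationFormat(Enum):
        # hypothetical members; the diff only names the type
        bf16 = "bf16"
        fp8 = "fp8"


    class QuantizationConfig(BaseModel):
        # hypothetical fields for opting into fp8 inference
        format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
        # optional location of pre-computed fp8 scales, if the checkpoint ships any
        scales_path: Optional[str] = None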

@@ -2,10 +2,15 @@ from typing import AsyncGenerator
 from models.llama3_1.api.datatypes import StopReason

-from .api.config import CheckpointType, GeneratorArgs, InlineImplConfig
+from .api.config import (
+    CheckpointQuantizationFormat,
+    CheckpointType,
+    InlineImplConfig,
+)
 from .api.datatypes import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
+    QuantizationConfig,
     ToolCallDelta,
     ToolCallParseStatus,
 )

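Continuing the sketch above, a hypothetical construction of InlineImplConfig with fp8 enabled. The checkpoint fields (checkpoint_type, checkpoint_dir, tokenizer_path, model_parallel_size) and max_seq_len/max_batch_size are taken from the helper deleted in the next hunk; CheckpointConfig, PytorchCheckpoint, and the quantization field are guesses:

    # hypothetical usage sketch, not the actual API
    config = InlineImplConfig(
        checkpoint_config=CheckpointConfig(      # assumed wrapper type
            checkpoint=PytorchCheckpoint(        # assumed concrete type
                checkpoint_type=CheckpointType.pytorch.value,
                checkpoint_dir="/path/to/llama3_1/checkpoint",
                tokenizer_path="/path/to/tokenizer.model",
                model_parallel_size=1,
                quantization_format=CheckpointQuantizationFormat.fp8,  # assumed field
            ),
        ),
        # assumed field tying the new QuantizationConfig into the impl config
        quantization=QuantizationConfig(format=CheckpointQuantizationFormat.fp8),
        max_seq_len=4096,
        max_batch_size=1,
    )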
@@ -18,33 +23,13 @@ from .api.endpoints import (
 from .model_parallel import LlamaModelParallelGenerator


-def generator_args_from_config(config: InlineImplConfig) -> GeneratorArgs:
-    if (
-        config.checkpoint_config.checkpoint.checkpoint_type
-        == CheckpointType.pytorch.value
-    ):
-        pt_checkpoint = config.checkpoint_config.checkpoint
-        return GeneratorArgs(
-            ckpt_dir=pt_checkpoint.checkpoint_dir,
-            tokenizer_path=pt_checkpoint.tokenizer_path,
-            model_parallel_size=pt_checkpoint.model_parallel_size,
-            max_seq_len=config.max_seq_len,
-            max_batch_size=config.max_batch_size,
-        )
-    else:
-        raise NotImplementedError("HF Checkpoint not supported yet")
-
-
 class ModelInferenceImpl(ModelInference):

     def __init__(self, config: InlineImplConfig) -> None:
         self.config = config

     async def initialize(self) -> None:
-        generator_args = generator_args_from_config(self.config)
-        self.generator = LlamaModelParallelGenerator(
-            args=generator_args,
-        )
+        self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()

     async def shutdown(self) -> None:
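
With GeneratorArgs gone, LlamaModelParallelGenerator takes the config directly. A rough sketch of what its constructor might now do, mirroring the checks the deleted helper performed; everything beyond the one-argument call visible above is an assumption:

    class LlamaModelParallelGenerator:
        # hypothetical constructor; the diff only shows that it receives the config
        def __init__(self, config: InlineImplConfig) -> None:
            checkpoint = config.checkpoint_config.checkpoint
            if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
                raise NotImplementedError("HF Checkpoint not supported yet")
            self.config = config
            # values GeneratorArgs used to carry now come straight off the config
            self.ckpt_dir = checkpoint.checkpoint_dir
            self.tokenizer_path = checkpoint.tokenizer_path
            self.model_parallel_size = checkpoint.model_parallel_size
            self.max_seq_len = config.max_seq_len
            self.max_batch_size = config.max_batch_size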