Mirror of https://github.com/meta-llama/llama-stack.git
make inference server load checkpoints for fp8 inference

- introduce quantization-related args for the inference config
- also kill GeneratorArgs
parent 7d2c0b14b8
commit ad62e2e1f3

10 changed files with 249 additions and 155 deletions
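The hunks shown below cover only the config plumbing; the fp8 checkpoint loading itself lands in the other changed files. As a rough illustration of the idea only (the function and parameter names here are hypothetical, not code from this commit), branching on the checkpoint's declared quantization format at load time might look like:

```python
import torch

def load_checkpoint(ckpt_path: str, quant_format: str) -> dict:
    """Illustrative sketch: materialize weights according to the
    checkpoint's declared quantization format ("bf16"/"fp8" are
    assumed values, not confirmed by this diff)."""
    state_dict = torch.load(ckpt_path, map_location="cpu")
    if quant_format == "fp8":
        # fp8 checkpoints ship pre-quantized weights (and scales), so
        # return them untouched instead of casting everything to bf16
        return state_dict
    return {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
```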
```diff
@@ -2,10 +2,15 @@ from typing import AsyncGenerator
 from models.llama3_1.api.datatypes import StopReason
 
-from .api.config import CheckpointType, GeneratorArgs, InlineImplConfig
+from .api.config import (
+    CheckpointQuantizationFormat,
+    CheckpointType,
+    InlineImplConfig,
+)
 from .api.datatypes import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
+    QuantizationConfig,
     ToolCallDelta,
     ToolCallParseStatus,
 )
```
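The rewritten import block names two quantization types whose definitions are not part of this hunk. A minimal sketch of the shape they plausibly take — the member and field names here are assumptions (the real definitions live in .api.config and .api.datatypes, and may well be pydantic models rather than dataclasses):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

class CheckpointQuantizationFormat(Enum):
    # assumed members; only the class name appears in the diff
    bf16 = "bf16"
    fp8 = "fp8"

@dataclass
class QuantizationConfig:
    # hypothetical fields: which format the checkpoint is stored in,
    # plus an optional path to fp8 scales/calibration data
    format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
    scales_path: Optional[str] = None
```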
```diff
@@ -18,33 +23,13 @@ from .api.endpoints import (
 from .model_parallel import LlamaModelParallelGenerator
 
 
-def generator_args_from_config(config: InlineImplConfig) -> GeneratorArgs:
-    if (
-        config.checkpoint_config.checkpoint.checkpoint_type
-        == CheckpointType.pytorch.value
-    ):
-        pt_checkpoint = config.checkpoint_config.checkpoint
-        return GeneratorArgs(
-            ckpt_dir=pt_checkpoint.checkpoint_dir,
-            tokenizer_path=pt_checkpoint.tokenizer_path,
-            model_parallel_size=pt_checkpoint.model_parallel_size,
-            max_seq_len=config.max_seq_len,
-            max_batch_size=config.max_batch_size,
-        )
-    else:
-        raise NotImplementedError("HF Checkpoint not supported yet")
-
-
 class ModelInferenceImpl(ModelInference):
 
     def __init__(self, config: InlineImplConfig) -> None:
         self.config = config
 
     async def initialize(self) -> None:
-        generator_args = generator_args_from_config(self.config)
-        self.generator = LlamaModelParallelGenerator(
-            args=generator_args,
-        )
+        self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()
 
     async def shutdown(self) -> None:
```