make inference server load checkpoints for fp8 inference

- introduce quantization related args for inference config
- also kill GeneratorArgs
parent 7d2c0b14b8
commit ad62e2e1f3
10 changed files with 249 additions and 155 deletions
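The diff below shows the call sites moving from `GeneratorArgs` to `InlineImplConfig`, but not the config definition itself. As orientation, here is a minimal sketch of what the new config might look like, inferred solely from the attributes the diff reads (`checkpoint_config.checkpoint`, `tokenizer_path`, `model_parallel_size`) and the commit message's mention of quantization-related args. `ModelCheckpoint`, `CheckpointConfig`, `QuantizationConfig`, `checkpoint_dir`, `max_seq_len`, and `max_batch_size` are assumed names for illustration, not the repository's confirmed schema:

```python
from typing import Optional

from pydantic import BaseModel


class ModelCheckpoint(BaseModel):
    # Fields the diff actually reads via config.checkpoint_config.checkpoint:
    tokenizer_path: str
    model_parallel_size: int
    # Assumed: the checkpoint directory formerly passed as GeneratorArgs.ckpt_dir
    checkpoint_dir: str


class CheckpointConfig(BaseModel):
    checkpoint: ModelCheckpoint


class QuantizationConfig(BaseModel):
    # Hypothetical shape for the "quantization related args" named in the
    # commit message; "fp8" mirrors the commit title, not a confirmed schema.
    format: str = "fp8"


class InlineImplConfig(BaseModel):
    checkpoint_config: CheckpointConfig
    quantization: Optional[QuantizationConfig] = None
    # Assumed: generation limits that previously lived on GeneratorArgs
    max_seq_len: int = 2048
    max_batch_size: int = 1
```

Grouping the checkpoint fields under `checkpoint_config` keeps quantization settings and generation limits separate from where the weights live, which is presumably what lets `GeneratorArgs` be dropped entirely.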
```diff
@@ -6,7 +6,7 @@ from models.llama3_1.api.chat_format import ChatFormat
 from models.llama3_1.api.datatypes import Message
 from models.llama3_1.api.tokenizer import Tokenizer
 
-from .api.config import GeneratorArgs
+from .api.config import InlineImplConfig
 from .generation import Llama
 from .parallel_utils import ModelParallelProcessGroup
 
@@ -35,13 +35,8 @@ class ModelRunner:
         )
 
 
-def init_model_cb(args: GeneratorArgs):
-    llama = Llama.build(
-        args.ckpt_dir,
-        args.tokenizer_path,
-        args.max_seq_len,
-        args.max_batch_size,
-    )
+def init_model_cb(config: InlineImplConfig):
+    llama = Llama.build(config)
     return ModelRunner(llama)
 
 
@@ -56,12 +51,13 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """
 
-    def __init__(self, args: GeneratorArgs):
-        self.args = args
+    def __init__(self, config: InlineImplConfig):
+        self.config = config
 
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
-        self.formatter = ChatFormat(Tokenizer(self.args.tokenizer_path))
+        checkpoint = self.config.checkpoint_config.checkpoint
+        self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
 
     def start(self):
         self.__enter__()
@@ -70,9 +66,10 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)
 
     def __enter__(self):
+        checkpoint = self.config.checkpoint_config.checkpoint
         self.group = ModelParallelProcessGroup(
-            self.args.model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.args),
+            checkpoint.model_parallel_size,
+            init_model_cb=partial(init_model_cb, self.config),
        )
         self.group.start()
         return self
```
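The docstring fragment in the diff notes it should be "clear at the callsite why we need to use a context manager": `__enter__` spawns the model-parallel process group and loads the model, and `__exit__` tears it down. A hedged sketch of such a call site, using placeholder paths and the hypothetical config classes sketched above (only `LlamaModelParallelGenerator` and its `__enter__`/`__exit__` protocol come from the diff itself):

```python
# Illustrative only; paths and sizes are placeholders, and the config
# classes follow the assumed sketch above, not a confirmed schema.
config = InlineImplConfig(
    checkpoint_config=CheckpointConfig(
        checkpoint=ModelCheckpoint(
            tokenizer_path="/path/to/tokenizer.model",
            model_parallel_size=1,
            checkpoint_dir="/path/to/checkpoints",
        )
    ),
    quantization=QuantizationConfig(format="fp8"),
)

# Entering the context starts the worker processes via init_model_cb;
# exiting shuts them down, so model lifetime is explicit at the call site.
with LlamaModelParallelGenerator(config) as generator:
    ...  # run inference requests against generator
```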