mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
update inference config to take model and not model_dir
This commit is contained in:
parent
08c3802f45
commit
039861f1c7
9 changed files with 400 additions and 101 deletions
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
|
@ -12,9 +13,10 @@ from typing import Generator, List, Optional
|
|||
from llama_models.llama3_1.api.chat_format import ChatFormat
|
||||
from llama_models.llama3_1.api.datatypes import Message
|
||||
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
from .generation import Llama
|
||||
from .generation import Llama, model_checkpoint_dir
|
||||
from .parallel_utils import ModelParallelProcessGroup
|
||||
|
||||
|
||||
|
@ -60,11 +62,12 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def __init__(self, config: MetaReferenceImplConfig):
|
||||
self.config = config
|
||||
|
||||
self.model = resolve_model(self.config.model)
|
||||
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
|
||||
# while the tool-use loop is going
|
||||
checkpoint = self.config.checkpoint_config.checkpoint
|
||||
self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
|
||||
checkpoint_dir = model_checkpoint_dir(self.model)
|
||||
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
|
||||
self.formatter = ChatFormat(Tokenizer(tokenizer_path))
|
||||
|
||||
def start(self):
|
||||
self.__enter__()
|
||||
|
@ -73,9 +76,8 @@ class LlamaModelParallelGenerator:
|
|||
self.__exit__(None, None, None)
|
||||
|
||||
def __enter__(self):
|
||||
checkpoint = self.config.checkpoint_config.checkpoint
|
||||
self.group = ModelParallelProcessGroup(
|
||||
checkpoint.model_parallel_size,
|
||||
self.model.hardware_requirements.gpu_count,
|
||||
init_model_cb=partial(init_model_cb, self.config),
|
||||
)
|
||||
self.group.start()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue