mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
update inference config to take model and not model_dir
This commit is contained in:
parent
08c3802f45
commit
039861f1c7
9 changed files with 400 additions and 101 deletions
|
@ -75,11 +75,13 @@ safetensors files to avoid downloading duplicate weights.
|
|||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
|
||||
repo_id = model.huggingface_repo
|
||||
if repo_id is None:
|
||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||
|
||||
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
|
||||
output_dir = model_local_dir(model)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
try:
|
||||
true_output_dir = snapshot_download(
|
||||
|
@ -107,8 +109,9 @@ safetensors files to avoid downloading duplicate weights.
|
|||
|
||||
def _meta_download(self, model: "Model", meta_url: str):
|
||||
from llama_models.sku_list import llama_meta_net_info
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
|
||||
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
|
||||
output_dir = model_local_dir(model)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
info = llama_meta_net_info(model)
|
||||
|
|
8
llama_toolchain/common/model_utils.py
Normal file
8
llama_toolchain/common/model_utils.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
import os
|
||||
from llama_models.datatypes import Model
|
||||
|
||||
from .config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
|
||||
|
||||
def model_local_dir(model: Model) -> str:
|
||||
return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor())
|
|
@ -4,61 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from strong_typing.schema import json_schema_type
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_toolchain.inference.api import QuantizationConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CheckpointType(Enum):
|
||||
pytorch = "pytorch"
|
||||
huggingface = "huggingface"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PytorchCheckpoint(BaseModel):
|
||||
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
|
||||
CheckpointType.pytorch.value
|
||||
)
|
||||
checkpoint_dir: str
|
||||
tokenizer_path: str
|
||||
model_parallel_size: int
|
||||
quantization_format: CheckpointQuantizationFormat = (
|
||||
CheckpointQuantizationFormat.bf16
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class HuggingFaceCheckpoint(BaseModel):
|
||||
checkpoint_type: Literal[CheckpointType.huggingface.value] = (
|
||||
CheckpointType.huggingface.value
|
||||
)
|
||||
repo_id: str # or model_name ?
|
||||
model_parallel_size: int
|
||||
quantization_format: CheckpointQuantizationFormat = (
|
||||
CheckpointQuantizationFormat.bf16
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelCheckpointConfig(BaseModel):
|
||||
checkpoint: Annotated[
|
||||
Union[PytorchCheckpoint, HuggingFaceCheckpoint],
|
||||
Field(discriminator="checkpoint_type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetaReferenceImplConfig(BaseModel):
|
||||
model: str
|
||||
checkpoint_config: ModelCheckpointConfig
|
||||
quantization: Optional[QuantizationConfig] = None
|
||||
torch_seed: Optional[int] = None
|
||||
max_seq_len: int
|
||||
|
|
|
@ -27,11 +27,22 @@ from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
|
|||
from llama_models.llama3_1.api.datatypes import Message
|
||||
from llama_models.llama3_1.api.model import Transformer
|
||||
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
from llama_toolchain.inference.api import QuantizationType
|
||||
|
||||
from .config import CheckpointType, MetaReferenceImplConfig
|
||||
from .config import MetaReferenceImplConfig
|
||||
|
||||
|
||||
def model_checkpoint_dir(model) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model))
|
||||
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
assert checkpoint_dir.exists(), f"Could not find checkpoint dir: {checkpoint_dir}"
|
||||
return str(checkpoint_dir)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -51,9 +62,7 @@ class Llama:
|
|||
This method initializes the distributed process group, sets the device to CUDA,
|
||||
and loads the pre-trained model and tokenizer.
|
||||
"""
|
||||
checkpoint = config.checkpoint_config.checkpoint
|
||||
if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
|
||||
raise NotImplementedError("HuggingFace checkpoints not supported yet")
|
||||
model = resolve_model(config.model)
|
||||
|
||||
if (
|
||||
config.quantization
|
||||
|
@ -67,7 +76,7 @@ class Llama:
|
|||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
|
||||
model_parallel_size = checkpoint.model_parallel_size
|
||||
model_parallel_size = model.hardware_requirements.gpu_count
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(model_parallel_size)
|
||||
|
||||
|
@ -82,7 +91,8 @@ class Llama:
|
|||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
ckpt_dir = checkpoint.checkpoint_dir
|
||||
ckpt_dir = model_checkpoint_dir(model)
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert model_parallel_size == len(
|
||||
|
@ -103,7 +113,9 @@ class Llama:
|
|||
max_batch_size=config.max_batch_size,
|
||||
**params,
|
||||
)
|
||||
tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
|
||||
|
||||
tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
|
||||
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||
|
||||
assert (
|
||||
model_args.vocab_size == tokenizer.n_words
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
|
@ -12,9 +13,10 @@ from typing import Generator, List, Optional
|
|||
from llama_models.llama3_1.api.chat_format import ChatFormat
|
||||
from llama_models.llama3_1.api.datatypes import Message
|
||||
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
from .generation import Llama
|
||||
from .generation import Llama, model_checkpoint_dir
|
||||
from .parallel_utils import ModelParallelProcessGroup
|
||||
|
||||
|
||||
|
@ -60,11 +62,12 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def __init__(self, config: MetaReferenceImplConfig):
|
||||
self.config = config
|
||||
|
||||
self.model = resolve_model(self.config.model)
|
||||
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
|
||||
# while the tool-use loop is going
|
||||
checkpoint = self.config.checkpoint_config.checkpoint
|
||||
self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
|
||||
checkpoint_dir = model_checkpoint_dir(self.model)
|
||||
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
|
||||
self.formatter = ChatFormat(Tokenizer(tokenizer_path))
|
||||
|
||||
def start(self):
|
||||
self.__enter__()
|
||||
|
@ -73,9 +76,8 @@ class LlamaModelParallelGenerator:
|
|||
self.__exit__(None, None, None)
|
||||
|
||||
def __enter__(self):
|
||||
checkpoint = self.config.checkpoint_config.checkpoint
|
||||
self.group = ModelParallelProcessGroup(
|
||||
checkpoint.model_parallel_size,
|
||||
self.model.hardware_requirements.gpu_count,
|
||||
init_model_cb=partial(init_model_cb, self.config),
|
||||
)
|
||||
self.group.start()
|
||||
|
|
|
@ -44,11 +44,13 @@ OLLAMA_SUPPORTED_SKUS = {
|
|||
}
|
||||
|
||||
|
||||
def get_provider_impl(config: OllamaImplConfig) -> Inference:
|
||||
async def get_provider_impl(config: OllamaImplConfig) -> Inference:
|
||||
assert isinstance(
|
||||
config, OllamaImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
return OllamaInference(config)
|
||||
impl = OllamaInference(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class OllamaInference(Inference):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue