update inference config to take model and not model_dir

This commit is contained in:
Hardik Shah 2024-08-06 15:02:41 -07:00
parent 08c3802f45
commit 039861f1c7
9 changed files with 400 additions and 101 deletions

View file

@ -4,61 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Literal, Optional, Union
from typing import Optional
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
from pydantic import BaseModel, Field
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated
from llama_toolchain.inference.api import QuantizationConfig
@json_schema_type
class CheckpointType(Enum):
pytorch = "pytorch"
huggingface = "huggingface"
@json_schema_type
class PytorchCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
CheckpointType.pytorch.value
)
checkpoint_dir: str
tokenizer_path: str
model_parallel_size: int
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
@json_schema_type
class HuggingFaceCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.huggingface.value] = (
CheckpointType.huggingface.value
)
repo_id: str # or model_name ?
model_parallel_size: int
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
@json_schema_type
class ModelCheckpointConfig(BaseModel):
checkpoint: Annotated[
Union[PytorchCheckpoint, HuggingFaceCheckpoint],
Field(discriminator="checkpoint_type"),
]
@json_schema_type
class MetaReferenceImplConfig(BaseModel):
model: str
checkpoint_config: ModelCheckpointConfig
quantization: Optional[QuantizationConfig] = None
torch_seed: Optional[int] = None
max_seq_len: int