mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
update inference config to take model and not model_dir
This commit is contained in:
parent
08c3802f45
commit
039861f1c7
9 changed files with 400 additions and 101 deletions
|
@ -4,61 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from strong_typing.schema import json_schema_type
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_toolchain.inference.api import QuantizationConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CheckpointType(Enum):
|
||||
pytorch = "pytorch"
|
||||
huggingface = "huggingface"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PytorchCheckpoint(BaseModel):
|
||||
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
|
||||
CheckpointType.pytorch.value
|
||||
)
|
||||
checkpoint_dir: str
|
||||
tokenizer_path: str
|
||||
model_parallel_size: int
|
||||
quantization_format: CheckpointQuantizationFormat = (
|
||||
CheckpointQuantizationFormat.bf16
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class HuggingFaceCheckpoint(BaseModel):
|
||||
checkpoint_type: Literal[CheckpointType.huggingface.value] = (
|
||||
CheckpointType.huggingface.value
|
||||
)
|
||||
repo_id: str # or model_name ?
|
||||
model_parallel_size: int
|
||||
quantization_format: CheckpointQuantizationFormat = (
|
||||
CheckpointQuantizationFormat.bf16
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelCheckpointConfig(BaseModel):
|
||||
checkpoint: Annotated[
|
||||
Union[PytorchCheckpoint, HuggingFaceCheckpoint],
|
||||
Field(discriminator="checkpoint_type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetaReferenceImplConfig(BaseModel):
|
||||
model: str
|
||||
checkpoint_config: ModelCheckpointConfig
|
||||
quantization: Optional[QuantizationConfig] = None
|
||||
torch_seed: Optional[int] = None
|
||||
max_seq_len: int
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue