update import for quantization format from models

This commit is contained in:
Ashwin Bharambe 2024-07-21 23:56:04 -07:00
parent f9111652ef
commit 2e7978fa39
2 changed files with 10 additions and 11 deletions

View file

@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from .datatypes import QuantizationConfig
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
class ImplType(Enum):
@@ -20,17 +21,6 @@ class CheckpointType(Enum):
huggingface = "huggingface"
class CheckpointQuantizationFormat(Enum):
    """Format in which checkpoint weights are specified.

    This does not necessarily equal the quantization desired at runtime,
    since on-the-fly conversions can be performed.
    """

    bf16 = "bf16"  # default format
    fp8_mixed = "fp8_mixed"  # enables fp8_rowwise inference; some weights remain bf16
class PytorchCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
CheckpointType.pytorch.value

View file

@@ -0,0 +1,9 @@
from typing import Protocol
from pyopenapi import webmethod
from pydantic import BaseModel
class Models(Protocol):
    """Structural interface for the models API.

    Intentionally empty for now; endpoint methods (decorated with
    ``@webmethod``) are expected to be added here.
    """

    ...