mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

update import for quantization format from models

parent f9111652ef
commit 2e7978fa39

2 changed files with 10 additions and 11 deletions
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from .datatypes import QuantizationConfig
+from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
 
 
 class ImplType(Enum):
@@ -20,17 +21,6 @@ class CheckpointType(Enum):
     huggingface = "huggingface"
 
 
-# This enum represents the format in which weights are specified
-# This does not necessarily always equal what quantization is desired
-# at runtime since there can be on-the-fly conversions done
-class CheckpointQuantizationFormat(Enum):
-    # default format
-    bf16 = "bf16"
-
-    # used for enabling fp8_rowwise inference, some weights are bf16
-    fp8_mixed = "fp8_mixed"
-
-
 class PytorchCheckpoint(BaseModel):
     checkpoint_type: Literal[CheckpointType.pytorch.value] = (
         CheckpointType.pytorch.value
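Taken together, the two hunks move CheckpointQuantizationFormat out of llama_toolchain: the local enum definition is deleted and the same name is now imported from llama_models. A minimal sketch (not part of the commit) of what consuming code sees afterwards, using only the import path added above and the member names visible in the deleted block:

# Sketch only: the enum now lives in llama_models; per the deleted block
# above, its members are bf16 ("bf16") and fp8_mixed ("fp8_mixed").
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat

fmt = CheckpointQuantizationFormat.fp8_mixed
print(fmt.value)  # -> "fp8_mixed"; the weights-on-disk format, which may
                  # differ from the quantization applied at runtime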
llama_toolchain/models/api/endpoints.py (new file, 9 additions)

@@ -0,0 +1,9 @@
+from typing import Protocol
+
+from pyopenapi import webmethod
+
+from pydantic import BaseModel
+
+
+class Models(Protocol):
+    ...
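The new endpoints.py only stubs out an empty Models protocol; the webmethod and BaseModel imports are unused for now. As a hedged sketch of how such a protocol might later be filled in, where the route, the list_models method, and the ModelSpec response model are all illustrative assumptions and not part of this commit:

from typing import Protocol

from pydantic import BaseModel
from pyopenapi import webmethod


class ModelSpec(BaseModel):
    # Hypothetical response model, not defined anywhere in this commit.
    model_id: str
    quantization_format: str  # e.g. "bf16" or "fp8_mixed"


class Models(Protocol):
    # Hypothetical endpoint; assumes webmethod accepts a route keyword,
    # as API protocols in this codebase are decorated elsewhere.
    @webmethod(route="/models/list")
    def list_models(self) -> ModelSpec: ...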