diff --git a/llama_toolchain/inference/api/config.py b/llama_toolchain/inference/api/config.py index 5994e805b..8c2f160f5 100644 --- a/llama_toolchain/inference/api/config.py +++ b/llama_toolchain/inference/api/config.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from typing_extensions import Annotated from .datatypes import QuantizationConfig +from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat class ImplType(Enum): @@ -20,17 +21,6 @@ class CheckpointType(Enum): huggingface = "huggingface" -# This enum represents the format in which weights are specified -# This does not necessarily always equal what quantization is desired -# at runtime since there can be on-the-fly conversions done -class CheckpointQuantizationFormat(Enum): - # default format - bf16 = "bf16" - - # used for enabling fp8_rowwise inference, some weights are bf16 - fp8_mixed = "fp8_mixed" - - class PytorchCheckpoint(BaseModel): checkpoint_type: Literal[CheckpointType.pytorch.value] = ( CheckpointType.pytorch.value diff --git a/llama_toolchain/models/api/endpoints.py b/llama_toolchain/models/api/endpoints.py new file mode 100644 index 000000000..432dc391e --- /dev/null +++ b/llama_toolchain/models/api/endpoints.py @@ -0,0 +1,9 @@ +from typing import Protocol + +from pyopenapi import webmethod + +from pydantic import BaseModel + + +class Models(Protocol): + ...