diff --git a/llama_toolchain/evaluations/api/api.py b/llama_toolchain/evaluations/api/api.py index 3e03fe12e..b8f3fa825 100644 --- a/llama_toolchain/evaluations/api/api.py +++ b/llama_toolchain/evaluations/api/api.py @@ -12,7 +12,7 @@ from llama_models.schema_utils import webmethod from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.dataset.api.datatypes import * # noqa: F403 +from llama_toolchain.dataset.api import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403 diff --git a/llama_toolchain/inference/quantization/loader.py b/llama_toolchain/inference/quantization/loader.py index 3645344aa..54827dce9 100644 --- a/llama_toolchain/inference/quantization/loader.py +++ b/llama_toolchain/inference/quantization/loader.py @@ -14,12 +14,12 @@ import torch from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region from llama_models.llama3.api.model import Transformer, TransformerBlock +from llama_toolchain.inference.api import QuantizationType from llama_toolchain.inference.api.config import ( CheckpointQuantizationFormat, MetaReferenceImplConfig, ) -from llama_toolchain.inference.api.datatypes import QuantizationType from termcolor import cprint from torch import Tensor diff --git a/llama_toolchain/post_training/api/api.py b/llama_toolchain/post_training/api/api.py index ce7dcd65c..447a729fb 100644 --- a/llama_toolchain/post_training/api/api.py +++ b/llama_toolchain/post_training/api/api.py @@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.dataset.api.datatypes import * # noqa: F403 +from llama_toolchain.dataset.api import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403 diff --git a/llama_toolchain/synthetic_data_generation/api/api.py b/llama_toolchain/synthetic_data_generation/api/api.py index 4d82553a3..44b8327a9 100644 --- a/llama_toolchain/synthetic_data_generation/api/api.py +++ b/llama_toolchain/synthetic_data_generation/api/api.py @@ -13,7 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403 +from llama_toolchain.reward_scoring.api import * # noqa: F403 class FilteringFunction(Enum):