From fd1c7f0197b12ba7fe438bbd15b5ae2c17a5b256 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 26 Aug 2024 14:43:30 -0700 Subject: [PATCH] Fix api.datatypes imports --- llama_toolchain/evaluations/api/api.py | 2 +- llama_toolchain/inference/quantization/loader.py | 2 +- llama_toolchain/post_training/api/api.py | 2 +- llama_toolchain/synthetic_data_generation/api/api.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/llama_toolchain/evaluations/api/api.py b/llama_toolchain/evaluations/api/api.py index 3e03fe12e..b8f3fa825 100644 --- a/llama_toolchain/evaluations/api/api.py +++ b/llama_toolchain/evaluations/api/api.py @@ -12,7 +12,7 @@ from llama_models.schema_utils import webmethod from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.dataset.api.datatypes import * # noqa: F403 +from llama_toolchain.dataset.api import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403 diff --git a/llama_toolchain/inference/quantization/loader.py b/llama_toolchain/inference/quantization/loader.py index 3645344aa..54827dce9 100644 --- a/llama_toolchain/inference/quantization/loader.py +++ b/llama_toolchain/inference/quantization/loader.py @@ -14,12 +14,12 @@ import torch from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region from llama_models.llama3.api.model import Transformer, TransformerBlock +from llama_toolchain.inference.api import QuantizationType from llama_toolchain.inference.api.config import ( CheckpointQuantizationFormat, MetaReferenceImplConfig, ) -from llama_toolchain.inference.api.datatypes import QuantizationType from termcolor import cprint from torch import Tensor diff --git a/llama_toolchain/post_training/api/api.py b/llama_toolchain/post_training/api/api.py index ce7dcd65c..447a729fb 100644 --- a/llama_toolchain/post_training/api/api.py +++ b/llama_toolchain/post_training/api/api.py @@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.dataset.api.datatypes import * # noqa: F403 +from llama_toolchain.dataset.api import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403 diff --git a/llama_toolchain/synthetic_data_generation/api/api.py b/llama_toolchain/synthetic_data_generation/api/api.py index 4d82553a3..44b8327a9 100644 --- a/llama_toolchain/synthetic_data_generation/api/api.py +++ b/llama_toolchain/synthetic_data_generation/api/api.py @@ -13,7 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403 +from llama_toolchain.reward_scoring.api import * # noqa: F403 class FilteringFunction(Enum):