Fix api.datatypes imports

This commit is contained in:
Ashwin Bharambe 2024-08-26 14:43:30 -07:00
parent fb78bdc5a9
commit fd1c7f0197
4 changed files with 4 additions and 4 deletions

View file

@ -12,7 +12,7 @@ from llama_models.schema_utils import webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.dataset.api.datatypes import * # noqa: F403 from llama_toolchain.dataset.api import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403

View file

@ -14,12 +14,12 @@ import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3.api.model import Transformer, TransformerBlock from llama_models.llama3.api.model import Transformer, TransformerBlock
from llama_toolchain.inference.api import QuantizationType
from llama_toolchain.inference.api.config import ( from llama_toolchain.inference.api.config import (
CheckpointQuantizationFormat, CheckpointQuantizationFormat,
MetaReferenceImplConfig, MetaReferenceImplConfig,
) )
from llama_toolchain.inference.api.datatypes import QuantizationType
from termcolor import cprint from termcolor import cprint
from torch import Tensor from torch import Tensor

View file

@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.dataset.api.datatypes import * # noqa: F403 from llama_toolchain.dataset.api import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403

View file

@ -13,7 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403 from llama_toolchain.reward_scoring.api import * # noqa: F403
class FilteringFunction(Enum): class FilteringFunction(Enum):