fix fp8 imports

This commit is contained in:
Ashwin Bharambe 2024-10-03 14:40:21 -07:00
parent 8d41e6caa9
commit f913b57397

View file

@@ -13,15 +13,15 @@ from typing import Optional
 import torch
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.llama3.api.model import Transformer, TransformerBlock
+from llama_models.datatypes import CheckpointQuantizationFormat
+from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from termcolor import cprint
 from torch import Tensor
 from llama_stack.apis.inference import QuantizationType
-from llama_stack.apis.inference.config import (
-    CheckpointQuantizationFormat,
+from llama_stack.providers.impls.meta_reference.inference.config import (
     MetaReferenceImplConfig,
 )