diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py index 9c5182ead..1df86cb84 100644 --- a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py +++ b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py @@ -13,15 +13,15 @@ from typing import Optional import torch from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region -from llama_models.llama3.api.model import Transformer, TransformerBlock +from llama_models.datatypes import CheckpointQuantizationFormat +from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock from termcolor import cprint from torch import Tensor from llama_stack.apis.inference import QuantizationType -from llama_stack.apis.inference.config import ( - CheckpointQuantizationFormat, +from llama_stack.providers.impls.meta_reference.inference.config import ( MetaReferenceImplConfig, )