diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index 9020a48fe..f76b5a448 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -39,12 +39,7 @@ from llama_stack.models.llama.datatypes import (
     SamplingParams,
     TopPSamplingStrategy,
 )
-from llama_stack.models.llama.llama3.args import ModelArgs
 from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
-from llama_stack.models.llama.llama3.model import Transformer
-from llama_stack.models.llama.llama3.multimodal.model import (
-    CrossAttentionTransformer,
-)
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -53,6 +48,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )

 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .llama3.args import ModelArgs
+from .llama3.model import Transformer
+from .llama3.multimodal.model import CrossAttentionTransformer

 log = logging.getLogger(__name__)

diff --git a/llama_stack/models/llama/llama3/args.py b/llama_stack/providers/inline/inference/meta_reference/llama3/args.py
similarity index 100%
rename from llama_stack/models/llama/llama3/args.py
rename to llama_stack/providers/inline/inference/meta_reference/llama3/args.py
diff --git a/llama_stack/models/llama/llama3/model.py b/llama_stack/providers/inline/inference/meta_reference/llama3/model.py
similarity index 100%
rename from llama_stack/models/llama/llama3/model.py
rename to llama_stack/providers/inline/inference/meta_reference/llama3/model.py
diff --git a/llama_stack/models/llama/llama3/multimodal/__init__.py b/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/__init__.py
similarity index 100%
rename from llama_stack/models/llama/llama3/multimodal/__init__.py
rename to llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/__init__.py
diff --git a/llama_stack/models/llama/llama3/multimodal/encoder_utils.py b/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/encoder_utils.py
similarity index 100%
rename from llama_stack/models/llama/llama3/multimodal/encoder_utils.py
rename to llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/encoder_utils.py
diff --git a/llama_stack/models/llama/llama3/multimodal/image_transform.py b/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/image_transform.py
similarity index 100%
rename from llama_stack/models/llama/llama3/multimodal/image_transform.py
rename to llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/image_transform.py
diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model.py
similarity index 100%
rename from llama_stack/models/llama/llama3/multimodal/model.py
rename to llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model.py
diff --git a/llama_stack/models/llama/llama3/multimodal/utils.py b/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/utils.py
similarity index 100%
rename from llama_stack/models/llama/llama3/multimodal/utils.py
rename to llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/utils.py
diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
index 70f7670aa..8a15f688a 100644
--- a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
+++ b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
@@ -20,10 +20,10 @@ from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

 from llama_stack.apis.inference import QuantizationType
 from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
-from llama_stack.models.llama.llama3.args import ModelArgs
-from llama_stack.models.llama.llama3.model import Transformer, TransformerBlock
 from llama_stack.models.llama.sku_list import resolve_model

+from ...llama3.args import ModelArgs
+from ...llama3.model import Transformer, TransformerBlock
 from ..config import MetaReferenceQuantizedInferenceConfig

 log = logging.getLogger(__name__)
diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py
index cf112f19e..bb2a66682 100644
--- a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py
+++ b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py
@@ -24,9 +24,9 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from torch.nn.parameter import Parameter

-from llama_stack.models.llama.llama3.args import ModelArgs
-from llama_stack.models.llama.llama3.model import Transformer, TransformerBlock
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
+from llama_stack.providers.inline.inference.meta_reference.llama3.args import ModelArgs
+from llama_stack.providers.inline.inference.meta_reference.llama3.model import Transformer, TransformerBlock
 from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
     quantize_fp8,
 )