From 725423c95cefa7da5326c927688f7ee6582fc7d7 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 3 Mar 2025 13:22:57 -0800 Subject: [PATCH] refactor: move llama3 impl to meta_reference provider (#1364) Just moving bits to a better place ## Test Plan ```bash torchrun $CONDA_PREFIX/bin/pytest -s -v test_text_inference.py ``` --- .../inline/inference/meta_reference/generation.py | 8 +++----- .../inline/inference/meta_reference}/llama3/args.py | 0 .../inline/inference/meta_reference}/llama3/model.py | 0 .../meta_reference}/llama3/multimodal/__init__.py | 0 .../meta_reference}/llama3/multimodal/encoder_utils.py | 0 .../meta_reference}/llama3/multimodal/image_transform.py | 0 .../inference/meta_reference}/llama3/multimodal/model.py | 0 .../inference/meta_reference}/llama3/multimodal/utils.py | 0 .../inference/meta_reference/quantization/loader.py | 4 ++-- .../quantization/scripts/quantize_checkpoint.py | 4 ++-- 10 files changed, 7 insertions(+), 9 deletions(-) rename llama_stack/{models/llama => providers/inline/inference/meta_reference}/llama3/args.py (100%) rename llama_stack/{models/llama => providers/inline/inference/meta_reference}/llama3/model.py (100%) rename llama_stack/{models/llama => providers/inline/inference/meta_reference}/llama3/multimodal/__init__.py (100%) rename llama_stack/{models/llama => providers/inline/inference/meta_reference}/llama3/multimodal/encoder_utils.py (100%) rename llama_stack/{models/llama => providers/inline/inference/meta_reference}/llama3/multimodal/image_transform.py (100%) rename llama_stack/{models/llama => providers/inline/inference/meta_reference}/llama3/multimodal/model.py (100%) rename llama_stack/{models/llama => providers/inline/inference/meta_reference}/llama3/multimodal/utils.py (100%) diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index 9020a48fe..f76b5a448 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -39,12 +39,7 @@ from llama_stack.models.llama.datatypes import ( SamplingParams, TopPSamplingStrategy, ) -from llama_stack.models.llama.llama3.args import ModelArgs from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput -from llama_stack.models.llama.llama3.model import Transformer -from llama_stack.models.llama.llama3.multimodal.model import ( - CrossAttentionTransformer, -) from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -53,6 +48,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig +from .llama3.args import ModelArgs +from .llama3.model import Transformer +from .llama3.multimodal.model import CrossAttentionTransformer log = logging.getLogger(__name__) diff --git a/llama_stack/models/llama/llama3/args.py b/llama_stack/providers/inline/inference/meta_reference/llama3/args.py similarity index 100% rename from llama_stack/models/llama/llama3/args.py rename to llama_stack/providers/inline/inference/meta_reference/llama3/args.py diff --git a/llama_stack/models/llama/llama3/model.py b/llama_stack/providers/inline/inference/meta_reference/llama3/model.py similarity index 100% rename from llama_stack/models/llama/llama3/model.py rename to llama_stack/providers/inline/inference/meta_reference/llama3/model.py diff --git a/llama_stack/models/llama/llama3/multimodal/__init__.py b/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/__init__.py similarity index 100% rename from llama_stack/models/llama/llama3/multimodal/__init__.py rename to llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/__init__.py diff --git a/llama_stack/models/llama/llama3/multimodal/encoder_utils.py b/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/encoder_utils.py similarity index 100% rename from llama_stack/models/llama/llama3/multimodal/encoder_utils.py rename to llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/encoder_utils.py diff --git a/llama_stack/models/llama/llama3/multimodal/image_transform.py b/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/image_transform.py similarity index 100% rename from llama_stack/models/llama/llama3/multimodal/image_transform.py rename to llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/image_transform.py diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model.py similarity index 100% rename from llama_stack/models/llama/llama3/multimodal/model.py rename to llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model.py diff --git a/llama_stack/models/llama/llama3/multimodal/utils.py b/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/utils.py similarity index 100% rename from llama_stack/models/llama/llama3/multimodal/utils.py rename to llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/utils.py diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py index 70f7670aa..8a15f688a 100644 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py @@ -20,10 +20,10 @@ from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear from llama_stack.apis.inference import QuantizationType from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat -from llama_stack.models.llama.llama3.args import ModelArgs -from llama_stack.models.llama.llama3.model import Transformer, TransformerBlock from llama_stack.models.llama.sku_list import resolve_model +from ...llama3.args import ModelArgs +from ...llama3.model import Transformer, TransformerBlock from ..config import MetaReferenceQuantizedInferenceConfig log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py index cf112f19e..bb2a66682 100644 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py @@ -24,9 +24,9 @@ from fairscale.nn.model_parallel.initialize import ( ) from torch.nn.parameter import Parameter -from llama_stack.models.llama.llama3.args import ModelArgs -from llama_stack.models.llama.llama3.model import Transformer, TransformerBlock from llama_stack.models.llama.llama3.tokenizer import Tokenizer +from llama_stack.providers.inline.inference.meta_reference.llama3.args import ModelArgs +from llama_stack.providers.inline.inference.meta_reference.llama3.model import Transformer, TransformerBlock from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import ( quantize_fp8, )