refactor: move llama3 impl to meta_reference provider (#1364)

Just moving bits to a better place: the llama3 model implementation (`ModelArgs`, `Transformer`, and the multimodal `CrossAttentionTransformer`) moves from `llama_stack/models/llama/llama3/` into the meta-reference inference provider. The shared tokenizer and chat format stay where they are.

## Test Plan

```bash
torchrun $CONDA_PREFIX/bin/pytest -s -v test_text_inference.py
```
Ashwin Bharambe 2025-03-03 13:22:57 -08:00 committed by GitHub
parent af396e3809
commit 725423c95c
10 changed files with 7 additions and 9 deletions

```diff
@@ -39,12 +39,7 @@ from llama_stack.models.llama.datatypes import (
     SamplingParams,
     TopPSamplingStrategy,
 )
-from llama_stack.models.llama.llama3.args import ModelArgs
 from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
-from llama_stack.models.llama.llama3.model import Transformer
-from llama_stack.models.llama.llama3.multimodal.model import (
-    CrossAttentionTransformer,
-)
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -53,6 +48,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .llama3.args import ModelArgs
+from .llama3.model import Transformer
+from .llama3.multimodal.model import CrossAttentionTransformer
 log = logging.getLogger(__name__)
```
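
The upshot for code inside the provider: model internals become package-local imports, while the tokenizer and chat format remain in the shared `llama_stack.models.llama.llama3` package. A minimal sketch of the resulting split; only the import paths come from the diff, the `build_model` helper is hypothetical:

```python
# Shared pieces stay in the common package:
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer

# Model internals are now local to the meta-reference provider
# (relative imports resolve against the provider's own package):
from .llama3.args import ModelArgs
from .llama3.model import Transformer


def build_model(params: dict) -> Transformer:
    # Hypothetical helper: ModelArgs is the usual llama3 hyperparameter
    # dataclass (dim, n_layers, n_heads, ...); Transformer consumes it
    # directly.
    return Transformer(ModelArgs(**params))
```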

```diff
@@ -20,10 +20,10 @@ from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 from llama_stack.apis.inference import QuantizationType
 from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
-from llama_stack.models.llama.llama3.args import ModelArgs
-from llama_stack.models.llama.llama3.model import Transformer, TransformerBlock
 from llama_stack.models.llama.sku_list import resolve_model
+from ...llama3.args import ModelArgs
+from ...llama3.model import Transformer, TransformerBlock
 from ..config import MetaReferenceQuantizedInferenceConfig
 log = logging.getLogger(__name__)
```
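
The quantization loader pulls in both `Transformer` and `TransformerBlock`, which suggests quantization is applied block by block. A hypothetical sketch of that pattern (the helper is illustrative, not from the commit; the absolute path is used so the snippet stands alone outside the package):

```python
from llama_stack.providers.inline.inference.meta_reference.llama3.model import (
    Transformer,
    TransformerBlock,
)


def transformer_blocks(model: Transformer):
    # Walk the module tree and yield each TransformerBlock: the granularity
    # at which a quantizing loader would swap out linear layers.
    for module in model.modules():
        if isinstance(module, TransformerBlock):
            yield module
```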

```diff
@@ -24,9 +24,9 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from torch.nn.parameter import Parameter
-from llama_stack.models.llama.llama3.args import ModelArgs
-from llama_stack.models.llama.llama3.model import Transformer, TransformerBlock
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
+from llama_stack.providers.inline.inference.meta_reference.llama3.args import ModelArgs
+from llama_stack.providers.inline.inference.meta_reference.llama3.model import Transformer, TransformerBlock
 from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
     quantize_fp8,
 )
```
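
For code outside the provider, like this FP8 script, the migration is a one-line path change; only the model internals moved. A before/after sketch, taken directly from the paths in the diff:

```python
# Before this commit, the model internals lived under models/:
#   from llama_stack.models.llama.llama3.model import Transformer, TransformerBlock

# After: they live with the meta-reference inference provider...
from llama_stack.providers.inline.inference.meta_reference.llama3.model import (
    Transformer,
    TransformerBlock,
)

# ...while the tokenizer stays in the shared package, exactly as before:
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
```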