refactor: move all llama code to models/llama out of meta reference (#1887)

# What does this PR do?

Move around bits. This makes the copies from llama-models _much_ easier
to maintain and ensures we don't entangle meta-reference specific
tidbits into llama-models code even by accident.

Also, kills the meta-reference-quantized-gpu distro and rolls
quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
   --env INFERENCE_CHECKPOINT_DIR=<DIR> \
   --env MODEL_PARALLEL_SIZE=4 \
   --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point integration tests to
it using:

```
pytest -s -v  tests/integration/inference/test_text_inference.py \
   --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
This commit is contained in:
Ashwin Bharambe 2025-04-07 15:03:58 -07:00 committed by GitHub
parent c52ccc4bbd
commit 530d4bdfe1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
85 changed files with 1267 additions and 1683 deletions

View file

@ -4,24 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional
from .datatypes import (
from .sku_types import (
CheckpointQuantizationFormat,
CoreModelId,
Model,
ModelFamily,
SamplingParams,
TopPSamplingStrategy,
)
LLAMA2_VOCAB_SIZE = 32000
@ -47,15 +38,6 @@ def all_registered_models() -> List[Model]:
)
def recommended_sampling_params() -> SamplingParams:
return SamplingParams(
strategy=TopPSamplingStrategy(
temperature=1.0,
top_p=0.9,
)
)
def llama2_family() -> List[Model]:
return [
*llama2_base_models(),
@ -150,7 +132,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b,
description="Llama 2 7b model",
huggingface_repo="meta-llama/Llama-2-7b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -169,7 +150,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b,
description="Llama 2 13b model",
huggingface_repo="meta-llama/Llama-2-13b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 5120,
"n_layers": 40,
@ -188,7 +168,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b,
description="Llama 2 70b model",
huggingface_repo="meta-llama/Llama-2-70b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -230,7 +209,6 @@ def llama3_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b,
description="Llama 3 70b model",
huggingface_repo="meta-llama/Llama-3-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -254,7 +232,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b,
description="Llama 3.1 8b model",
huggingface_repo="meta-llama/Llama-3.1-8B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -273,7 +250,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b,
description="Llama 3.1 70b model",
huggingface_repo="meta-llama/Llama-3.1-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -293,7 +269,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp8",
description="Llama 3.1 405b model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -313,7 +288,6 @@ def llama3_1_base_models() -> List[Model]:
description="Llama 3.1 405b model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -333,7 +307,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp16",
description="Llama 3.1 405b model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -357,7 +330,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b,
description="Llama 3.2 1b model",
huggingface_repo="meta-llama/Llama-3.2-1B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 16,
@ -376,7 +348,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b,
description="Llama 3.2 3b model",
huggingface_repo="meta-llama/Llama-3.2-3B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 3072,
"n_layers": 28,
@ -395,7 +366,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision,
description="Llama 3.2 11b vision model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -417,7 +387,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision,
description="Llama 3.2 90b vision model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -444,7 +413,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b_chat,
description="Llama 2 7b chat model",
huggingface_repo="meta-llama/Llama-2-7b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -463,7 +431,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b_chat,
description="Llama 2 13b chat model",
huggingface_repo="meta-llama/Llama-2-13b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 5120,
"n_layers": 40,
@ -482,7 +449,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b_chat,
description="Llama 2 70b chat model",
huggingface_repo="meta-llama/Llama-2-70b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -506,7 +472,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_8b_instruct,
description="Llama 3 8b instruct model",
huggingface_repo="meta-llama/Llama-3-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -525,7 +490,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b_instruct,
description="Llama 3 70b instruct model",
huggingface_repo="meta-llama/Llama-3-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -549,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b_instruct,
description="Llama 3.1 8b instruct model",
huggingface_repo="meta-llama/Llama-3.1-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -568,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b_instruct,
description="Llama 3.1 70b instruct model",
huggingface_repo="meta-llama/Llama-3.1-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -588,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp8",
description="Llama 3.1 405b instruct model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -608,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]:
description="Llama 3.1 405b instruct model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -628,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp16",
description="Llama 3.1 405b instruct model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -684,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_1b(),
"quantization_args": {
@ -703,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_1b(),
"quantization_args": {
@ -718,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_3b(),
"quantization_args": {
@ -737,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_3b(),
"quantization_args": {
@ -755,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b_instruct,
description="Llama 3.2 1b instruct model",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_1b(),
pth_file_count=1,
),
@ -763,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b_instruct,
description="Llama 3.2 3b instruct model",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_3b(),
pth_file_count=1,
),
@ -772,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision_instruct,
description="Llama 3.2 11b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -794,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision_instruct,
description="Llama 3.2 90b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -821,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_3_70b_instruct,
description="Llama 3.3 70b instruct",
huggingface_repo="meta-llama/Llama-3.3-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -846,7 +796,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_11b_vision,
description="Llama Guard v3 11b vision system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -870,7 +819,6 @@ def safety_models() -> List[Model]:
description="Llama Guard v3 1b 'int4' quantized system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4",
quantization_format=CheckpointQuantizationFormat.int4,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 12,
@ -888,7 +836,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_1b,
description="Llama Guard v3 1b system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 16,