forked from phoenix-oss/llama-stack-mirror
refactor: move all llama code to models/llama out of meta reference (#1887)
# What does this PR do?

Move around bits. This makes the copies from llama-models _much_ easier to maintain and ensures we don't entangle meta-reference-specific tidbits into llama-models code even by accident.

Also, kills the meta-reference-quantized-gpu distro and rolls the quantization deps into meta-reference-gpu.

## Test Plan

Start a server with and without quantization:

```
LLAMA_MODELS_DEBUG=1 \
with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
  --env INFERENCE_CHECKPOINT_DIR=<DIR> \
  --env MODEL_PARALLEL_SIZE=4 \
  --env QUANTIZATION_TYPE=fp8_mixed
```

Point integration tests at it using:

```
pytest -s -v tests/integration/inference/test_text_inference.py \
  --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
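In practice the move changes the canonical import path for the Llama generation code. A minimal before/after sketch, using only the module paths that appear in the diff below (nothing else about the package layout is implied):

```python
# Old imports (removed by this PR): the generators lived inside the
# meta-reference inference provider.
# from llama_stack.providers.inline.inference.meta_reference.llama3.generation import Llama3
# from llama_stack.providers.inline.inference.meta_reference.llama4.generation import Llama4

# New imports: the same classes now live under models/llama.
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.sku_list import resolve_model
```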
Parent: c52ccc4bbd
Commit: 530d4bdfe1
85 changed files with 1267 additions and 1683 deletions
The diff below is an excerpt from one of the changed files; judging by the torchrun comment, it is scripts/generate_prompt_format.py.

```diff
@@ -5,13 +5,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
 # Run this script:
 # torchrun --nproc_per_node=8 scripts/generate_prompt_format.py meta-llama/Llama-4-17B-Omni-Instruct-BF16-16E ~/.llama/checkpoints/Llama-4-17B-Omni-Instruct-BF16-16E/ llama_stack.models.llama.llama4.prompts llama_stack/models/llama/llama4/prompt_format.md
 
```
```diff
@@ -22,16 +15,9 @@ from pathlib import Path
 
 import fire
 
+from llama_stack.models.llama.llama3.generation import Llama3
+from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.providers.inline.inference.meta_reference.config import (
-    MetaReferenceInferenceConfig,
-)
-from llama_stack.providers.inline.inference.meta_reference.llama3.generation import (
-    Llama3,
-)
-from llama_stack.providers.inline.inference.meta_reference.llama4.generation import (
-    Llama4,
-)
 
 THIS_DIR = Path(__file__).parent.resolve()
 
```
```diff
@@ -50,24 +36,12 @@ def run_main(
     if not llama_model:
         raise ValueError(f"Model {model_id} not found")
 
-    if not llama4:
-        config = MetaReferenceInferenceConfig(
-            model=model_id,
-            max_seq_len=4096,
-            max_batch_size=1,
-            checkpoint_dir=checkpoint_dir,
-        )
-        generator = Llama3.build(
-            config=config,
-            model_id=model_id,
-            llama_model=llama_model,
-        )
-    else:
-        generator = Llama4.build(
-            ckpt_dir=checkpoint_dir,
-            max_seq_len=4096,
-            max_batch_size=1,
-        )
+    cls = Llama4 if llama4 else Llama3
+    generator = cls.build(
+        ckpt_dir=checkpoint_dir,
+        max_seq_len=4096,
+        max_batch_size=1,
+    )
 
     use_cases = module.usecases()
     text = ""
```
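Both generation classes now accept the same `build()` keywords, which is what lets the two branches above collapse into one. A minimal sketch of the unified path, assuming only the signature visible in the diff (the `build_generator` helper is hypothetical, not part of this PR):

```python
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.sku_list import resolve_model


def build_generator(model_id: str, checkpoint_dir: str, llama4: bool):
    """Hypothetical wrapper around the unified build path shown above."""
    # resolve_model() returns a falsy value for unknown model IDs; mirror the
    # guard used in the script.
    llama_model = resolve_model(model_id)
    if not llama_model:
        raise ValueError(f"Model {model_id} not found")

    # Pick the class, then call the shared build() keywords; no
    # MetaReferenceInferenceConfig is needed anymore.
    cls = Llama4 if llama4 else Llama3
    return cls.build(
        ckpt_dir=checkpoint_dir,
        max_seq_len=4096,
        max_batch_size=1,
    )
```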