feat: introduce llama4 support (#1877)
As the title says. Details are in the README and elsewhere.
commit b8f1561956 (parent 23a99a4b22)
61 changed files with 205,222 additions and 6,439 deletions
scripts/generate_prompt_format.py

@@ -12,14 +12,26 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.
+
+# Run this script:
+# torchrun --nproc_per_node=8 scripts/generate_prompt_format.py meta-llama/Llama-4-17B-Omni-Instruct-BF16-16E ~/.llama/checkpoints/Llama-4-17B-Omni-Instruct-BF16-16E/ llama_stack.models.llama.llama4.prompts llama_stack/models/llama/llama4/prompt_format.md
+
+
 import importlib
 import os
 from pathlib import Path
 
 import fire
 
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.providers.inline.inference.meta_reference.config import MetaReferenceInferenceConfig
-from llama_stack.providers.inline.inference.meta_reference.llama3.generation import Llama3
+from llama_stack.providers.inline.inference.meta_reference.config import (
+    MetaReferenceInferenceConfig,
+)
+from llama_stack.providers.inline.inference.meta_reference.llama3.generation import (
+    Llama3,
+)
+from llama_stack.providers.inline.inference.meta_reference.llama4.generation import (
+    Llama4,
+)
 
 THIS_DIR = Path(__file__).parent.resolve()
 
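The new header comment documents the torchrun invocation. The script's actual entry point sits outside this hunk, so the following is only a sketch of the assumed dispatch: python-fire maps the four positional arguments onto run_main's parameters in order (model_id, checkpoint_dir, module_name, output_path), and the new boolean can be overridden on the command line as --llama4=False.

import fire


def run_main(
    model_id: str,
    checkpoint_dir: str,
    module_name: str,
    output_path: str,
    llama4: bool = True,  # flag added by this commit; --llama4=False selects the Llama 3 path
) -> None:
    ...  # body as shown in the hunks below


if __name__ == "__main__":
    # Assumed entry point (not part of this diff): fire turns run_main's
    # signature into the CLI that the torchrun command above exercises.
    fire.Fire(run_main)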
@@ -29,24 +41,33 @@ def run_main(
     checkpoint_dir: str,
     module_name: str,
     output_path: str,
+    llama4: bool = True,
 ):
     module = importlib.import_module(module_name)
     assert hasattr(module, "usecases"), f"Module {module_name} missing usecases function"
 
-    config = MetaReferenceInferenceConfig(
-        model=model_id,
-        max_seq_len=512,
-        max_batch_size=1,
-        checkpoint_dir=checkpoint_dir,
-    )
     llama_model = resolve_model(model_id)
     if not llama_model:
         raise ValueError(f"Model {model_id} not found")
-    generator = Llama3.build(
-        config=config,
-        model_id=model_id,
-        llama_model=llama_model,
-    )
+
+    if not llama4:
+        config = MetaReferenceInferenceConfig(
+            model=model_id,
+            max_seq_len=4096,
+            max_batch_size=1,
+            checkpoint_dir=checkpoint_dir,
+        )
+        generator = Llama3.build(
+            config=config,
+            model_id=model_id,
+            llama_model=llama_model,
+        )
+    else:
+        generator = Llama4.build(
+            ckpt_dir=checkpoint_dir,
+            max_seq_len=4096,
+            max_batch_size=1,
+        )
 
     use_cases = module.usecases()
     text = ""
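The assert near the top of run_main is the whole contract on module_name: the imported prompts module must expose a usecases() callable. Below is a hypothetical stand-in module, purely to illustrate that contract; the real definitions live in llama_stack.models.llama.llama4.prompts, and the use-case objects plus their rendering loop are elided from this diff.

# my_prompts.py -- hypothetical example; any module importable by dotted name
# that defines a usecases() function satisfies run_main's assert.


def usecases():
    # The real prompts modules return use-case objects that the (elided)
    # loop in run_main renders to markdown; plain strings stand in here.
    return [
        "# Llama 4 - Prompt Formats\n",
    ]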
@@ -59,8 +80,7 @@ def run_main(
         text += use_case_text
         print(use_case_text)
 
-    text += "Thank You!\n"
 
     os.makedirs(os.path.dirname(output_path), exist_ok=True)
     with open(output_path, "w") as f:
         f.write(text)
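For reference, an illustrative programmatic call mirroring the torchrun command from the header comment. The arguments are the ones documented there; the distributed environment that torchrun provides is still assumed, since both Llama3.build and Llama4.build initialize multi-GPU inference.

# Illustrative only: torchrun remains the supported way to launch this script,
# and run_main is assumed to be in scope (from scripts/generate_prompt_format.py).
run_main(
    model_id="meta-llama/Llama-4-17B-Omni-Instruct-BF16-16E",
    checkpoint_dir="~/.llama/checkpoints/Llama-4-17B-Omni-Instruct-BF16-16E/",
    module_name="llama_stack.models.llama.llama4.prompts",
    output_path="llama_stack/models/llama/llama4/prompt_format.md",
    llama4=True,  # False routes through MetaReferenceInferenceConfig + Llama3.build
)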