update generate prompt format

This commit is contained in:
Ashwin Bharambe 2025-04-07 14:54:43 -07:00
parent 76004eacb4
commit 35aac86997
2 changed files with 36 additions and 62 deletions

View file

@ -5,13 +5,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Run this script:
# torchrun --nproc_per_node=8 scripts/generate_prompt_format.py meta-llama/Llama-4-17B-Omni-Instruct-BF16-16E ~/.llama/checkpoints/Llama-4-17B-Omni-Instruct-BF16-16E/ llama_stack.models.llama.llama4.prompts llama_stack/models/llama/llama4/prompt_format.md
@ -22,16 +15,9 @@ from pathlib import Path
import fire
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.inference.meta_reference.config import (
MetaReferenceInferenceConfig,
)
from llama_stack.providers.inline.inference.meta_reference.llama3.generation import (
Llama3,
)
from llama_stack.providers.inline.inference.meta_reference.llama4.generation import (
Llama4,
)
THIS_DIR = Path(__file__).parent.resolve()
@ -50,24 +36,12 @@ def run_main(
if not llama_model:
raise ValueError(f"Model {model_id} not found")
if not llama4:
config = MetaReferenceInferenceConfig(
model=model_id,
max_seq_len=4096,
max_batch_size=1,
checkpoint_dir=checkpoint_dir,
)
generator = Llama3.build(
config=config,
model_id=model_id,
llama_model=llama_model,
)
else:
generator = Llama4.build(
ckpt_dir=checkpoint_dir,
max_seq_len=4096,
max_batch_size=1,
)
cls = Llama4 if llama4 else Llama3
generator = cls.build(
ckpt_dir=checkpoint_dir,
max_seq_len=4096,
max_batch_size=1,
)
use_cases = module.usecases()
text = ""