mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
Refactor test run to support shorthand model names
This commit is contained in:
parent
9c07e0189a
commit
7f5e0dd3db
1 changed files with 74 additions and 27 deletions
101
.github/workflows/gha_workflow_llama_stack_tests.yml
vendored
101
.github/workflows/gha_workflow_llama_stack_tests.yml
vendored
|
@ -31,28 +31,45 @@ on:
|
|||
required: true
|
||||
default: "0"
|
||||
|
||||
model_ids:
|
||||
description: 'Comma separated list of models to test'
|
||||
required: true
|
||||
default: "Llama3.2-3B-Instruct"
|
||||
|
||||
provider_id:
|
||||
description: 'ID of your provider'
|
||||
required: true
|
||||
default: "meta_reference"
|
||||
|
||||
model_id:
|
||||
description: 'Shorthand name for model ID (llama_3b or llama_8b)'
|
||||
required: true
|
||||
default: "llama_3b"
|
||||
|
||||
model_override_3b:
|
||||
description: 'Specify manual override for the <llama_3b> shorthand model'
|
||||
required: false
|
||||
default: "Llama-3.2-3B-Instruct"
|
||||
|
||||
model_override_8b:
|
||||
description: 'Specify manual override for the <llama_8b> shorthand model'
|
||||
required: false
|
||||
default: "Llama-3.1-8B-Instruct"
|
||||
|
||||
env:
|
||||
# ID used for each test's provider config
|
||||
PROVIDER_ID: "${{ inputs.provider_id || 'meta_reference' }}"
|
||||
|
||||
# Path to model checkpoints within EFS volume
|
||||
MODEL_CHECKPOINT_DIR: "/data/llama/Llama3.2-3B-Instruct"
|
||||
MODEL_CHECKPOINT_DIR: "/data/llama/"
|
||||
|
||||
# Path to directory to run tests from
|
||||
TESTS_PATH: "${{ github.workspace }}/llama_stack/providers/tests"
|
||||
|
||||
# List of models that are to be tested
|
||||
MODEL_IDS: "${{ inputs.model_ids || 'Llama3.2-3B-Instruct' }}"
|
||||
# Keep track of a list of model IDs that are valid to use within pytest fixture marks
|
||||
AVAILABLE_MODEL_IDs: "llama_3b llama_8b"
|
||||
|
||||
# ID used for each test's provider config
|
||||
PROVIDER_ID: "${{ inputs.provider_id || 'meta_reference' }}"
|
||||
# Shorthand name for model ID, used in pytest fixture marks
|
||||
MODEL_ID: "${{ inputs.model_id || 'llama_3b' }}"
|
||||
|
||||
# Set the llama 3b / 8b override for models if desired, else use the default.
|
||||
LLAMA_3B_OVERRIDE: "${{ inputs.model_override_3b || 'Llama-3.2-3B-Instruct' }}"
|
||||
LLAMA_8B_OVERRIDE: "${{ inputs.model_override_8b || 'Llama-3.1-8B-Instruct' }}"
|
||||
|
||||
# Defines which directories in TESTS_PATH to exclude from the test loop
|
||||
EXCLUDED_DIRS: "__pycache__"
|
||||
|
@ -83,26 +100,41 @@ jobs:
|
|||
echo "========= Content of the EFS mount ============="
|
||||
ls -la ${{ env.MODEL_CHECKPOINT_DIR }}
|
||||
|
||||
- name: "Check if models exist in EFS volume"
|
||||
id: check_if_models_exist
|
||||
run: |
|
||||
#for model_id in ${MODEL_IDS//,/ }; do
|
||||
# model_path="${MODEL_CHECKPOINT_DIR}/${model_id}"
|
||||
# if [ ! -d "${model_path}" ]; then
|
||||
# echo "Model '${model_id}' does not exist in mounted EFS volume, Terminating workflow."
|
||||
# exit 1
|
||||
# else
|
||||
# echo "Content of '${model_id}' model"
|
||||
# ls -la "${model_path}"
|
||||
# fi
|
||||
#done
|
||||
|
||||
- name: "[DEBUG] Get runner container OS information"
|
||||
id: debug_os_info
|
||||
if: ${{ inputs.debug == 'true' }}
|
||||
run: |
|
||||
cat /etc/os-release
|
||||
|
||||
############################
|
||||
#### MODEL INPUT CHECKS ####
|
||||
############################
|
||||
|
||||
- name: "Check if env.model_id is valid"
|
||||
id: check_model_id
|
||||
run: |
|
||||
# Validate MODEL_ID against the allow-list held in AVAILABLE_MODEL_IDs.
#
# AVAILABLE_MODEL_IDs is a plain whitespace-separated string (not a bash
# array), so it is split into words here and every entry is compared for exact
# equality. This rejects an empty MODEL_ID as well as values that only match
# as a substring of the list (for example "llama_3b llama_8b").
read -r -a allowed_model_ids <<< "${AVAILABLE_MODEL_IDs}"

model_id_is_valid=false
for allowed_model_id in "${allowed_model_ids[@]}"; do
  if [[ "${allowed_model_id}" == "${MODEL_ID}" ]]; then
    model_id_is_valid=true
    break
  fi
done

if [[ "${model_id_is_valid}" == true ]]; then
  echo "Model ID ${MODEL_ID} is valid"
else
  echo "Model ID ${MODEL_ID} is invalid, Terminating workflow."
  exit 1
fi
|
||||
|
||||
- name: "Check if models exist in EFS volume"
|
||||
id: check_if_models_exist
|
||||
run: |
|
||||
# Verify that the checkpoint folder of every configured model override is
# present under MODEL_CHECKPOINT_DIR before any test is started.
# The two override names are joined with a comma and split again on commas
# (and whitespace) by the unquoted expansion in the loop header.
MODEL_IDS="${LLAMA_3B_OVERRIDE},${LLAMA_8B_OVERRIDE}"
for checkpoint_name in ${MODEL_IDS//,/ }; do
  checkpoint_path="${MODEL_CHECKPOINT_DIR}/${checkpoint_name}"
  if [[ -d "${checkpoint_path}" ]]; then
    # Folder is present: list it so the job log shows what will be used.
    echo "Content of '${checkpoint_name}' model"
    ls -la "${checkpoint_path}"
  else
    # A missing folder aborts the step with a non-zero exit status.
    echo "Model '${checkpoint_name}' does not exist in mounted EFS volume, Terminating workflow."
    exit 1
  fi
done
|
||||
|
||||
#######################
|
||||
#### CODE CHECKOUT ####
|
||||
#######################
|
||||
|
@ -158,6 +190,21 @@ jobs:
|
|||
echo "PATH=$PATH"
|
||||
echo "GITHUB_PATH=$GITHUB_PATH"
|
||||
|
||||
#####################################
|
||||
#### UPDATE CHECKPOINT DIRECTORY ####
|
||||
#####################################
|
||||
- name: "Update checkpoint directory"
|
||||
id: checkpoint_update
|
||||
run: |
|
||||
# Select the checkpoint folder that belongs to the shorthand MODEL_ID and
# append the resulting MODEL_CHECKPOINT_DIR assignment to the file named by
# GITHUB_ENV.
#
# The previous form `if ${MODEL_ID} == "llama_3b"` executed the value of
# MODEL_ID as a command instead of comparing it, so neither branch could ever
# be taken; a `case` on the quoted value performs the intended string match.
case "${MODEL_ID}" in
  llama_3b)
    checkpoint_name="${LLAMA_3B_OVERRIDE}"
    ;;
  llama_8b)
    checkpoint_name="${LLAMA_8B_OVERRIDE}"
    ;;
  *)
    echo "MODEL_ID is not valid, Terminating workflow."
    exit 1
    ;;
esac

# Drop one trailing slash from the base directory so the joined path does not
# contain a doubled separator.
echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR%/}/${checkpoint_name}" >> "$GITHUB_ENV"
|
||||
|
||||
##################################
|
||||
#### DEPENDENCY INSTALLATIONS ####
|
||||
##################################
|
||||
|
@ -225,11 +272,11 @@ jobs:
|
|||
for file in "$dir"/test_*.py; do
|
||||
test_name=$(basename "$file")
|
||||
new_file="result-${dir_name}-${test_name}.xml"
|
||||
if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "${PROVIDER_ID} and llama_3b" \
|
||||
if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "${PROVIDER_ID} and ${MODEL_ID}" \
|
||||
--junitxml="${{ github.workspace }}/${new_file}"; then
|
||||
echo "Test passed"
|
||||
echo "Ran test: ${test_name}"
|
||||
else
|
||||
echo "Test failed"
|
||||
echo "Did NOT run test: ${test_name}"
|
||||
fi
|
||||
pattern+="${new_file} "
|
||||
done
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue