Refactor test run to support shorthand model names

This commit is contained in:
Connor Hack 2024-11-22 12:05:05 -08:00
parent 9c07e0189a
commit 7f5e0dd3db

View file

@ -31,28 +31,45 @@ on:
required: true
default: "0"
model_ids:
description: 'Comma separated list of models to test'
required: true
default: "Llama3.2-3B-Instruct"
provider_id:
description: 'ID of your provider'
required: true
default: "meta_reference"
model_id:
description: 'Shorthand name for model ID (llama_3b or llama_8b)'
required: true
default: "llama_3b"
model_override_3b:
description: 'Specify manual override for the <llama_3b> shorthand model'
required: false
default: "Llama-3.2-3B-Instruct"
model_override_8b:
description: 'Specify manual override for the <llama_8b> shorthand model'
required: false
default: "Llama-3.1-8B-Instruct"
env:
# ID used for each test's provider config
PROVIDER_ID: "${{ inputs.provider_id || 'meta_reference' }}"
# Path to model checkpoints within EFS volume
MODEL_CHECKPOINT_DIR: "/data/llama/Llama3.2-3B-Instruct"
MODEL_CHECKPOINT_DIR: "/data/llama/"
# Path to directory to run tests from
TESTS_PATH: "${{ github.workspace }}/llama_stack/providers/tests"
# List of models that are to be tested
MODEL_IDS: "${{ inputs.model_ids || 'Llama3.2-3B-Instruct' }}"
# Keep track of a list of model IDs that are valid to use within pytest fixture marks
AVAILABLE_MODEL_IDs: "llama_3b llama_8b"
# ID used for each test's provider config
PROVIDER_ID: "${{ inputs.provider_id || 'meta_reference' }}"
# Shorthand name for model ID, used in pytest fixture marks
MODEL_ID: "${{ inputs.model_id || 'llama_3b' }}"
# Set the llama 3b / 8b override for models if desired, else use the default.
LLAMA_3B_OVERRIDE: "${{ inputs.model_override_3b || 'Llama-3.2-3B-Instruct' }}"
LLAMA_8B_OVERRIDE: "${{ inputs.model_override_8b || 'Llama-3.1-8B-Instruct' }}"
# Defines which directories in TESTS_PATH to exclude from the test loop
EXCLUDED_DIRS: "__pycache__"
@ -83,26 +100,41 @@ jobs:
echo "========= Content of the EFS mount ============="
ls -la ${{ env.MODEL_CHECKPOINT_DIR }}
- name: "Check if models exist in EFS volume"
id: check_if_models_exist
run: |
#for model_id in ${MODEL_IDS//,/ }; do
# model_path="${MODEL_CHECKPOINT_DIR}/${model_id}"
# if [ ! -d "${model_path}" ]; then
# echo "Model '${model_id}' does not exist in mounted EFS volume, Terminating workflow."
# exit 1
# else
# echo "Content of '${model_id}' model"
# ls -la "${model_path}"
# fi
#done
# Dump the runner container's OS release details; only runs when the
# workflow was dispatched with the `debug` input set to "true".
- name: "[DEBUG] Get runner container OS information"
  id: debug_os_info
  if: ${{ inputs.debug == 'true' }}
  run: |
    cat /etc/os-release
############################
#### MODEL INPUT CHECKS ####
############################
# Fail fast if the requested shorthand model ID is not one of the
# space-separated IDs listed in env.AVAILABLE_MODEL_IDs.
- name: "Check if env.model_id is valid"
  id: check_model_id
  run: |
    # AVAILABLE_MODEL_IDs is a plain space-separated string, not a bash
    # array, so expand it directly (the previous `[@]` subscript on a
    # scalar only worked by accident). The surrounding spaces force a
    # whole-word match so e.g. "llama_3" cannot match "llama_3b".
    if [[ " ${AVAILABLE_MODEL_IDs} " =~ " ${MODEL_ID} " ]]; then
      echo "Model ID ${MODEL_ID} is valid"
    else
      echo "Model ID ${MODEL_ID} is invalid, Terminating workflow."
      exit 1
    fi
# Verify that both checkpoint directories exist in the mounted EFS volume
# before running anything. Both the 3b and 8b models are checked up front,
# regardless of which shorthand MODEL_ID was selected for this run.
- name: "Check if models exist in EFS volume"
  id: check_if_models_exist
  run: |
    MODEL_IDS="${LLAMA_3B_OVERRIDE},${LLAMA_8B_OVERRIDE}"
    for model_id in ${MODEL_IDS//,/ }; do
      # Strip any trailing slash from MODEL_CHECKPOINT_DIR so the joined
      # path has no double slash (the env default is "/data/llama/").
      model_path="${MODEL_CHECKPOINT_DIR%/}/${model_id}"
      if [ ! -d "${model_path}" ]; then
        echo "Model '${model_id}' does not exist in mounted EFS volume, Terminating workflow."
        exit 1
      else
        echo "Content of '${model_id}' model"
        ls -la "${model_path}"
      fi
    done
#######################
#### CODE CHECKOUT ####
#######################
@ -158,6 +190,21 @@ jobs:
echo "PATH=$PATH"
echo "GITHUB_PATH=$GITHUB_PATH"
#####################################
#### UPDATE CHECKPOINT DIRECTORY ####
#####################################
# Resolve the shorthand MODEL_ID to its concrete checkpoint directory and
# export the full path to later steps via GITHUB_ENV.
- name: "Update checkpoint directory"
  id: checkpoint_update
  run: |
    # BUG FIX: the previous `if ${MODEL_ID} == "llama_3b" ; then` executed
    # the *value* of MODEL_ID as a command, so both branches always failed
    # and the else branch always terminated the workflow. A proper
    # `[ ... ]` string comparison is required.
    if [ "${MODEL_ID}" = "llama_3b" ]; then
      echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/$LLAMA_3B_OVERRIDE" >> "$GITHUB_ENV"
    elif [ "${MODEL_ID}" = "llama_8b" ]; then
      echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/$LLAMA_8B_OVERRIDE" >> "$GITHUB_ENV"
    else
      # Defensive: check_model_id should have already rejected other values.
      echo "MODEL_ID is not valid, Terminating workflow."
      exit 1
    fi
##################################
#### DEPENDENCY INSTALLATIONS ####
##################################
@ -225,11 +272,11 @@ jobs:
for file in "$dir"/test_*.py; do
test_name=$(basename "$file")
new_file="result-${dir_name}-${test_name}.xml"
if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "${PROVIDER_ID} and llama_3b" \
if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "${PROVIDER_ID} and ${MODEL_ID}" \
--junitxml="${{ github.workspace }}/${new_file}"; then
echo "Test passed"
echo "Ran test: ${test_name}"
else
echo "Test failed"
echo "Did NOT run test: ${test_name}"
fi
pattern+="${new_file} "
done