Test new provider name

This commit is contained in:
Connor Hack 2024-11-22 10:22:12 -08:00
parent 377896a4c5
commit 1481a67365
2 changed files with 32 additions and 6 deletions

View file

@ -41,6 +41,11 @@ on:
required: true
default: "meta-reference"
api_key:
description: 'Provider API key'
required: false
default: "---"
env:
# Path to model checkpoints within EFS volume
MODEL_CHECKPOINT_DIR: "/data/llama/Llama3.2-3B-Instruct"
@ -52,7 +57,13 @@ env:
MODEL_IDS: "${{ inputs.model_ids || 'Llama3.2-3B-Instruct' }}"
# ID used for each test's provider config
PROVIDER_ID: "${{ inputs.provider_id || 'meta-reference' }}"
#PROVIDER_ID: "${{ inputs.provider_id || 'meta-reference' }}"
# Defined dynamically when each test is run below
#PROVIDER_CONFIG: ""
# (Unused) API key that can be manually defined for workflow dispatch
API_KEY: "${{ inputs.api_key || '' }}"
# Defines which directories in TESTS_PATH to exclude from the test loop
EXCLUDED_DIRS: "__pycache__"
@ -67,7 +78,7 @@ jobs:
pull-requests: write
defaults:
run:
shell: bash
shell: bash # default shell to run all steps for a given job.
runs-on: ${{ inputs.runner != '' && inputs.runner || 'llama-stack-gha-runner-gpu' }}
if: always()
steps:
@ -134,6 +145,14 @@ jobs:
############################
#### UPDATE SYSTEM PATH ####
############################
- name: "[DEBUG] Update path: before"
id: path_update_before
if: ${{ inputs.debug == 'true' }}
run: |
echo "System path before update:"
echo "PATH=$PATH"
echo "GITHUB_PATH=$GITHUB_PATH"
- name: "Update path: execute"
id: path_update_exec
run: |
@ -142,6 +161,14 @@ jobs:
mkdir -p ${HOME}/.local/bin
echo "${HOME}/.local/bin" >> "$GITHUB_PATH"
- name: "[DEBUG] Update path: after"
id: path_update_after
if: ${{ inputs.debug == 'true' }}
run: |
echo "System path after update:"
echo "PATH=$PATH"
echo "GITHUB_PATH=$GITHUB_PATH"
##################################
#### DEPENDENCY INSTALLATIONS ####
##################################
@ -202,7 +229,6 @@ jobs:
working-directory: "${{ github.workspace }}"
run: |
pattern=""
echo "PROVIDER_ID = ${PROVIDER_ID}"
for dir in llama_stack/providers/tests/*; do
if [ -d "$dir" ]; then
dir_name=$(basename "$dir")
@ -210,7 +236,7 @@ jobs:
for file in "$dir"/test_*.py; do
test_name=$(basename "$file")
new_file="result-${dir_name}-${test_name}.xml"
if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "meta-reference and llama_3b" \
if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "meta_reference and llama_3b" \
--junitxml="${{ github.workspace }}/${new_file}"; then
echo "Test passed: $test_name"
else

View file

@ -36,6 +36,8 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
model = resolve_model(config.model)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
ModelRegistryHelper.__init__(
self,
[
@ -45,8 +47,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
)
],
)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model
# verify that the checkpoint actually is for this model lol