Script for runnning required autotests

This commit is contained in:
Sixian Yi 2025-01-05 22:33:26 -08:00
parent 8e5b336792
commit ab47d4a71e

View file

@ -0,0 +1,79 @@
import os
import re
import signal
import subprocess
import time
import pytest
# Inference provider and the required environment arg for running integration tests
INFERENCE_PROVIDER_ENV_KEY = {
"ollama": None,
"fireworks": "FIREWORKS_API_KEY",
"together": "TOGETHER_API_KEY",
}
# Model category and the keywords of the corresponding functionality tests
CATEGORY_FUNCTIONALITY_TESTS = {
"text": ["streaming", "tool_calling", "structured_output"],
"vision": [
"streaming",
],
}
def generate_pytest_args(category, provider, env_key):
test_path = (
"./llama_stack/providers/tests/inference/test_{model_type}_inference.py".format(
model_type=category
)
)
pytest_args = [
test_path,
"-v",
"-s",
"-k",
provider,
]
if env_key is not None:
pytest_args.extend(
[
"--env",
"{key_name}={key_value}".format(
key_name=env_key, key_value=os.getenv(env_key)
),
]
)
return pytest_args
def main():
test_result = []
for model_category, functionality_tests in CATEGORY_FUNCTIONALITY_TESTS.items():
for provider, env_key in INFERENCE_PROVIDER_ENV_KEY.items():
if provider == "ollama":
proc = subprocess.Popen(
[
"ollama",
"run",
(
"llama3.1:8b-instruct-fp16"
if model_category == "text"
else "llama3.2-vision"
),
]
)
retcode = pytest.main(
generate_pytest_args(model_category, provider, env_key)
)
proc.terminate()
else:
retcode = pytest.main(
generate_pytest_args(model_category, provider, env_key)
)
if __name__ == "__main__":
main()