From ab47d4a71edd7539983f1529f4f84dd48668f8b5 Mon Sep 17 00:00:00 2001 From: Sixian Yi Date: Sun, 5 Jan 2025 22:33:26 -0800 Subject: [PATCH] Script for runnning required autotests --- .../providers/tests/inference/ci_test.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 llama_stack/providers/tests/inference/ci_test.py diff --git a/llama_stack/providers/tests/inference/ci_test.py b/llama_stack/providers/tests/inference/ci_test.py new file mode 100644 index 000000000..a9bbba3a1 --- /dev/null +++ b/llama_stack/providers/tests/inference/ci_test.py @@ -0,0 +1,79 @@ +import os +import re +import signal +import subprocess +import time + +import pytest + + +# Inference provider and the required environment arg for running integration tests +INFERENCE_PROVIDER_ENV_KEY = { + "ollama": None, + "fireworks": "FIREWORKS_API_KEY", + "together": "TOGETHER_API_KEY", +} + +# Model category and the keywords of the corresponding functionality tests +CATEGORY_FUNCTIONALITY_TESTS = { + "text": ["streaming", "tool_calling", "structured_output"], + "vision": [ + "streaming", + ], +} + + +def generate_pytest_args(category, provider, env_key): + test_path = ( + "./llama_stack/providers/tests/inference/test_{model_type}_inference.py".format( + model_type=category + ) + ) + pytest_args = [ + test_path, + "-v", + "-s", + "-k", + provider, + ] + if env_key is not None: + pytest_args.extend( + [ + "--env", + "{key_name}={key_value}".format( + key_name=env_key, key_value=os.getenv(env_key) + ), + ] + ) + return pytest_args + + +def main(): + test_result = [] + + for model_category, functionality_tests in CATEGORY_FUNCTIONALITY_TESTS.items(): + for provider, env_key in INFERENCE_PROVIDER_ENV_KEY.items(): + if provider == "ollama": + proc = subprocess.Popen( + [ + "ollama", + "run", + ( + "llama3.1:8b-instruct-fp16" + if model_category == "text" + else "llama3.2-vision" + ), + ] + ) + retcode = pytest.main( + generate_pytest_args(model_category, provider, env_key) + ) + proc.terminate() + else: + retcode = pytest.main( + generate_pytest_args(model_category, provider, env_key) + ) + + +if __name__ == "__main__": + main()