This commit is contained in:
Sixian Yi 2025-01-14 16:33:29 -08:00
parent b22cc3e1fa
commit 9b709005fa
5 changed files with 46 additions and 92 deletions

View file

@ -6,6 +6,7 @@
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional
@ -13,15 +14,13 @@ import pytest
import yaml
from dotenv import load_dotenv
from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.datatypes import RemoteProviderConfig
from pydantic import BaseModel, Field
from termcolor import colored
from .env import get_env_or_fail
from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.datatypes import RemoteProviderConfig
from .test_config_helper import try_load_config_file_cached
from .env import get_env_or_fail
from .report import Report
@ -142,8 +141,8 @@ def pytest_configure(config):
key, value = env_var.split("=", 1)
os.environ[key] = value
if config.getoption("--config") is not None:
config.pluginmanager.register(Report(config))
if config.getoption("--output") is not None:
config.pluginmanager.register(Report(config.getoption("--output")))
def pytest_addoption(parser):
@ -160,6 +159,11 @@ def pytest_addoption(parser):
action="store",
help="Set test config file (supported format: YAML), e.g. --config=test_config.yml",
)
parser.addoption(
"--output",
action="store",
help="Set output file for test report, e.g. --output=pytest_report.md",
)
"""Add custom command line options"""
parser.addoption(
"--env", action="append", help="Set environment variables, e.g. --env KEY=value"
@ -269,9 +273,14 @@ def pytest_collection_modifyitems(session, config, items):
return
required_tests = defaultdict(set)
test_configs = [test_config.inference, test_config.memory, test_config.agent]
for test_config in test_configs:
for test in test_config.tests:
for api_test_config in [
test_config.inference,
test_config.memory,
test_config.agents,
]:
if api_test_config is None:
continue
for test in api_test_config.tests:
arr = test.split("::")
if len(arr) != 2:
raise ValueError(f"Invalid format for test name {test}")