diff --git a/tests/verifications/conftest.py b/tests/verifications/conftest.py
index 0b4a6feb7..030efcde9 100644
--- a/tests/verifications/conftest.py
+++ b/tests/verifications/conftest.py
@@ -25,6 +25,11 @@ def pytest_addoption(parser):
         action="store",
         help="Provider to use for testing",
     )
+    parser.addoption(
+        "--model",
+        action="store",
+        help="Model to use for testing",
+    )
 
 
 pytest_plugins = [
diff --git a/tests/verifications/openai_api/conftest.py b/tests/verifications/openai_api/conftest.py
index e4f7f27a0..b55a5d11a 100644
--- a/tests/verifications/openai_api/conftest.py
+++ b/tests/verifications/openai_api/conftest.py
@@ -16,6 +16,11 @@ def pytest_generate_tests(metafunc):
             metafunc.parametrize("model", [])
             return
 
+        model = metafunc.config.getoption("model")
+        if model:
+            metafunc.parametrize("model", [model])
+            return
+
         try:
             config_data = _load_all_verification_configs()
         except (OSError, FileNotFoundError) as e:
diff --git a/tests/verifications/openai_api/fixtures/fixtures.py b/tests/verifications/openai_api/fixtures/fixtures.py
index a7328e5f6..a3be7e402 100644
--- a/tests/verifications/openai_api/fixtures/fixtures.py
+++ b/tests/verifications/openai_api/fixtures/fixtures.py
@@ -12,6 +12,8 @@ import pytest
 import yaml
 from openai import OpenAI
 
+from llama_stack import LlamaStackAsLibraryClient
+
 
 # --- Helper Functions ---
 
@@ -81,7 +83,7 @@ def verification_config():
         pytest.fail(str(e))  # Fail test collection if config loading fails
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def provider(request, verification_config):
     provider = request.config.getoption("--provider")
     base_url = request.config.getoption("--base-url")
@@ -100,12 +102,14 @@ def provider(request, verification_config):
     return provider
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def base_url(request, provider, verification_config):
-    return request.config.getoption("--base-url") or verification_config["providers"][provider]["base_url"]
+    return request.config.getoption("--base-url") or verification_config.get("providers", {}).get(provider, {}).get(
+        "base_url"
+    )
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def api_key(request, provider, verification_config):
     provider_conf = verification_config.get("providers", {}).get(provider, {})
     api_key_env_var = provider_conf.get("api_key_var")
@@ -122,11 +126,21 @@ def model_mapping(provider, providers_model_mapping):
    return providers_model_mapping[provider]
 
 
-@pytest.fixture
-def openai_client(base_url, api_key):
+@pytest.fixture(scope="session")
+def openai_client(base_url, api_key, provider):
     # Simplify running against a local Llama Stack
-    if "localhost" in base_url and not api_key:
+    if base_url and "localhost" in base_url and not api_key:
         api_key = "empty"
+    if provider.startswith("stack:"):
+        parts = provider.split(":")
+        if len(parts) != 2:
+            raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
+        config = parts[1]
+        client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
+        if not client.initialize():
+            raise RuntimeError("Initialization failed")
+        return client
+
     return OpenAI(
         base_url=base_url,
         api_key=api_key,
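
A usage sketch, not part of the patch: with these changes the verification suite can be pinned to a single model via the new `--model` option and pointed at Llama Stack as a library client via a `stack:<config>` provider string. The config name `ollama` and the model ID below are illustrative placeholders, not values taken from this diff.

```python
# Hypothetical invocation of the verification suite with the new options.
import pytest

# "--provider=stack:<config>" makes the openai_client fixture construct a
# LlamaStackAsLibraryClient instead of an OpenAI client; "--model" bypasses
# the YAML-config-driven parametrization and runs against one model only.
exit_code = pytest.main(
    [
        "tests/verifications/openai_api",
        "--provider=stack:ollama",
        "--model=meta-llama/Llama-3.3-70B-Instruct",
    ]
)
```

The equivalent shell form would be `pytest tests/verifications/openai_api --provider=stack:ollama --model=...`; the session-scoped fixtures mean the library client is initialized once per run rather than once per test.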