# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import pytest
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel
from termcolor import colored

from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.datatypes import RemoteProviderConfig

from .env import get_env_or_fail


class ProviderFixture(BaseModel):
    """Providers (plus optional provider data) that make up a test stack."""

    providers: List[Provider]
    provider_data: Optional[Dict[str, Any]] = None


def remote_stack_fixture() -> ProviderFixture:
    """Build a fixture pointing at an already-running remote stack.

    The endpoint is taken from ``REMOTE_STACK_URL`` when it is set and
    non-empty; otherwise ``REMOTE_STACK_HOST`` and ``REMOTE_STACK_PORT`` are
    read via ``get_env_or_fail``.
    """
    if url := os.getenv("REMOTE_STACK_URL", None):
        config = RemoteProviderConfig.from_url(url)
    else:
        config = RemoteProviderConfig(
            host=get_env_or_fail("REMOTE_STACK_HOST"),
            port=int(get_env_or_fail("REMOTE_STACK_PORT")),
        )
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="test::remote",
                provider_type="test::remote",
                config=config.model_dump(),
            )
        ],
    )


def pytest_configure(config):
    """Set output options and load environment variables at start of test run.

    Variables are loaded from a ``.env`` file next to this conftest (if one
    exists) and then from each ``--env KEY=value`` option. The ``--env``
    values are written to ``os.environ`` afterwards, so they take precedence.

    Raises:
        ValueError: if an ``--env`` value is not of the form ``KEY=value``.
    """
    config.option.tbstyle = "short"
    config.option.disable_warnings = True

    # Load from .env file if it exists
    env_file = Path(__file__).parent / ".env"
    if env_file.exists():
        load_dotenv(env_file)

    # Load any environment variables passed via --env
    for env_var in config.getoption("--env") or []:
        key, sep, value = env_var.partition("=")
        if not sep or not key:
            raise ValueError(
                f"Invalid --env value: {env_var!r}. Expected format: KEY=value"
            )
        os.environ[key] = value


def pytest_addoption(parser):
    """Add custom command line options"""
    parser.addoption(
        "--providers",
        default="",
        help=(
            "Provider configuration in format: api1=provider1,api2=provider2. "
            "Example: --providers inference=ollama,safety=meta-reference"
        ),
    )
    parser.addoption(
        "--config",
        action="store",
        help="Set test config file (supported format: YAML), e.g. --config=test_config.yml",
    )
    parser.addoption(
        "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
    )
    parser.addoption(
        "--inference-model",
        action="store",
        default="meta-llama/Llama-3.2-3B-Instruct",
        help="Specify the inference model to use for testing",
    )
    parser.addoption(
        "--safety-shield",
        action="store",
        default="meta-llama/Llama-Guard-3-1B",
        help="Specify the safety shield to use for testing",
    )
    parser.addoption(
        "--embedding-model",
        action="store",
        default=None,
        help="Specify the embedding model to use for testing",
    )
    parser.addoption(
        "--judge-model",
        action="store",
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="Specify the judge model to use for testing",
    )


def make_provider_id(providers: Dict[str, str]) -> str:
    """Return a stable test id such as ``inference=ollama:safety=llama_guard``.

    Entries are sorted by API name so the id does not depend on dict order.
    """
    return ":".join(f"{api}={provider}" for api, provider in sorted(providers.items()))


def get_provider_marks(providers: Dict[str, str]) -> List[Any]:
    """Return one ``pytest.mark.<provider>`` marker per provider name."""
    return [getattr(pytest.mark, provider) for provider in providers.values()]


def get_provider_fixture_overrides(
    config, available_fixtures: Dict[str, List[str]]
) -> Optional[List[pytest.param]]:
    """Build a single parametrization from ``--providers``, if one was given.

    Returns:
        ``None`` when ``--providers`` is empty. Otherwise a one-element list
        holding a ``pytest.param`` of the parsed ``{api: provider}`` mapping,
        with a readable id and per-provider marks.

    Raises:
        ValueError: propagated from ``parse_fixture_string`` on bad input.
    """
    provider_str = config.getoption("--providers")
    if not provider_str:
        return None

    fixture_dict = parse_fixture_string(provider_str, available_fixtures)
    return [
        pytest.param(
            fixture_dict,
            id=make_provider_id(fixture_dict),
            marks=get_provider_marks(fixture_dict),
        )
    ]


def parse_fixture_string(
    provider_str: str, available_fixtures: Dict[str, List[str]]
) -> Dict[str, str]:
    """Parse provider string of format 'api1=provider1,api2=provider2'.

    Whitespace around entries, API names and provider names is ignored.
    Empty entries, e.g. from a trailing comma, are skipped.

    Returns:
        Mapping of API name to the chosen provider fixture name. The mapping
        is empty when ``provider_str`` is empty.

    Raises:
        ValueError: if an entry lacks ``=``, names an unknown API or provider,
            or if some API in ``available_fixtures`` is not assigned a provider.
    """
    if not provider_str:
        return {}

    fixtures = {}
    for pair in provider_str.split(","):
        pair = pair.strip()
        if not pair:
            continue
        if "=" not in pair:
            raise ValueError(
                f"Invalid provider specification: {pair}. Expected format: api=provider"
            )
        # Split only on the first "=" so stray extra "=" characters surface as
        # an "Unknown provider" error instead of an opaque unpacking error.
        api, fixture = (part.strip() for part in pair.split("=", 1))
        if api not in available_fixtures:
            raise ValueError(
                f"Unknown API: {api}. Available APIs: {list(available_fixtures.keys())}"
            )
        if fixture not in available_fixtures[api]:
            raise ValueError(
                f"Unknown provider '{fixture}' for API '{api}'. "
                f"Available providers: {list(available_fixtures[api])}"
            )
        fixtures[api] = fixture

    # Every available API must have been assigned a provider
    for api in available_fixtures.keys():
        if api not in fixtures:
            raise ValueError(
                f"Missing provider fixture for API '{api}'. Available providers: "
                f"{list(available_fixtures[api])}"
            )
    return fixtures


def pytest_itemcollected(item):
    """Append the item's non-asyncio/parametrize marker names to its display name."""
    filtered = ("asyncio", "parametrize")
    marks = [mark.name for mark in item.iter_markers() if mark.name not in filtered]
    if marks:
        marks = colored(",".join(marks), "yellow")
        item.name = f"{item.name}[{marks}]"


def pytest_collection_modifyitems(session, config, items):
    """Restrict collected tests to those listed in the ``--config`` YAML file.

    The file is looked up relative to this directory. Its ``tests`` entries
    each give a ``path`` (also relative to this directory) and a list of
    ``functions``. Its ``inference_fixtures`` list gives the allowed values of
    the ``inference_stack`` parameter.

    An item is kept only if its file and function are listed and its
    ``inference_stack`` parameter is allowed. All other items are reported
    through ``pytest_deselected``.

    Raises:
        ValueError: if ``--config`` names a file that does not exist.
    """
    file_name = config.getoption("--config")
    if file_name is None:
        return

    tests_dir = Path(__file__).parent
    config_file_path = tests_dir / file_name
    if not config_file_path.exists():
        raise ValueError(
            f"Test config {file_name} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
        )

    with open(config_file_path, "r", encoding="utf-8") as config_file:
        test_config = yaml.safe_load(config_file)

    # Keys are resolved pathlib.Path objects so they compare (and hash) equal
    # to the normalised item paths computed below. If a path is listed more
    # than once, its function lists are merged rather than overwritten.
    required_tests: Dict[Path, set] = {}
    for test in test_config["tests"]:
        test_path = (tests_dir / test["path"]).resolve()
        required_tests.setdefault(test_path, set()).update(test["functions"])
    inference_providers = set(test_config["inference_fixtures"])

    new_items, deselected_items = [], []
    for item in items:
        # item.fspath is a py.path object hashed by its string form. On Python
        # < 3.12, pathlib.Path hashes differently, so a direct dict lookup
        # never matches. Convert to a resolved Path first.
        item_path = Path(str(item.fspath)).resolve()
        func_name = getattr(item, "originalname", item.name)
        if func_name in required_tests.get(item_path, ()):
            # Non-parametrized test functions have no ``callspec`` attribute.
            callspec = getattr(item, "callspec", None)
            inference = (
                callspec.params.get("inference_stack") if callspec is not None else None
            )
            if inference in inference_providers:
                new_items.append(item)
                continue
        deselected_items.append(item)

    items[:] = new_items
    config.hook.pytest_deselected(items=deselected_items)


pytest_plugins = [
    "llama_stack.providers.tests.inference.fixtures",
    "llama_stack.providers.tests.safety.fixtures",
    "llama_stack.providers.tests.memory.fixtures",
    "llama_stack.providers.tests.agents.fixtures",
    "llama_stack.providers.tests.datasetio.fixtures",
    "llama_stack.providers.tests.scoring.fixtures",
    "llama_stack.providers.tests.eval.fixtures",
    "llama_stack.providers.tests.post_training.fixtures",
    "llama_stack.providers.tests.tools.fixtures",
]