forked from phoenix-oss/llama-stack-mirror
Significantly simpler and malleable test setup (#360)
* Significantly simpler and malleable test setup * convert memory tests * refactor fixtures and add support for composable fixtures * Fix memory to use the newer fixture organization * Get agents tests working * Safety tests work * yet another refactor to make this more general now it accepts --inference-model, --safety-model options also * get multiple providers working for meta-reference (for inference + safety) * Add README.md --------- Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
parent
663883cc29
commit
ffedb81c11
25 changed files with 1491 additions and 790 deletions
134
llama_stack/providers/tests/conftest.py
Normal file
134
llama_stack/providers/tests/conftest.py
Normal file
|
@ -0,0 +1,134 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.distribution.datatypes import Provider
|
||||
|
||||
|
||||
class ProviderFixture(BaseModel):
|
||||
providers: List[Provider]
|
||||
provider_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.option.tbstyle = "short"
|
||||
config.option.disable_warnings = True
|
||||
|
||||
"""Load environment variables at start of test run"""
|
||||
# Load from .env file if it exists
|
||||
env_file = Path(__file__).parent / ".env"
|
||||
if env_file.exists():
|
||||
load_dotenv(env_file)
|
||||
|
||||
# Load any environment variables passed via --env
|
||||
env_vars = config.getoption("--env") or []
|
||||
for env_var in env_vars:
|
||||
key, value = env_var.split("=", 1)
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--providers",
|
||||
default="",
|
||||
help=(
|
||||
"Provider configuration in format: api1=provider1,api2=provider2. "
|
||||
"Example: --providers inference=ollama,safety=meta-reference"
|
||||
),
|
||||
)
|
||||
"""Add custom command line options"""
|
||||
parser.addoption(
|
||||
"--env", action="append", help="Set environment variables, e.g. --env KEY=value"
|
||||
)
|
||||
|
||||
|
||||
def make_provider_id(providers: Dict[str, str]) -> str:
|
||||
return ":".join(f"{api}={provider}" for api, provider in sorted(providers.items()))
|
||||
|
||||
|
||||
def get_provider_marks(providers: Dict[str, str]) -> List[Any]:
|
||||
marks = []
|
||||
for provider in providers.values():
|
||||
marks.append(getattr(pytest.mark, provider))
|
||||
return marks
|
||||
|
||||
|
||||
def get_provider_fixture_overrides(
|
||||
config, available_fixtures: Dict[str, List[str]]
|
||||
) -> Optional[List[pytest.param]]:
|
||||
provider_str = config.getoption("--providers")
|
||||
if not provider_str:
|
||||
return None
|
||||
|
||||
fixture_dict = parse_fixture_string(provider_str, available_fixtures)
|
||||
return [
|
||||
pytest.param(
|
||||
fixture_dict,
|
||||
id=make_provider_id(fixture_dict),
|
||||
marks=get_provider_marks(fixture_dict),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def parse_fixture_string(
|
||||
provider_str: str, available_fixtures: Dict[str, List[str]]
|
||||
) -> Dict[str, str]:
|
||||
"""Parse provider string of format 'api1=provider1,api2=provider2'"""
|
||||
if not provider_str:
|
||||
return {}
|
||||
|
||||
fixtures = {}
|
||||
pairs = provider_str.split(",")
|
||||
for pair in pairs:
|
||||
if "=" not in pair:
|
||||
raise ValueError(
|
||||
f"Invalid provider specification: {pair}. Expected format: api=provider"
|
||||
)
|
||||
api, fixture = pair.split("=")
|
||||
if api not in available_fixtures:
|
||||
raise ValueError(
|
||||
f"Unknown API: {api}. Available APIs: {list(available_fixtures.keys())}"
|
||||
)
|
||||
if fixture not in available_fixtures[api]:
|
||||
raise ValueError(
|
||||
f"Unknown provider '{fixture}' for API '{api}'. "
|
||||
f"Available providers: {list(available_fixtures[api])}"
|
||||
)
|
||||
fixtures[api] = fixture
|
||||
|
||||
# Check that all provided APIs are supported
|
||||
for api in available_fixtures.keys():
|
||||
if api not in fixtures:
|
||||
raise ValueError(
|
||||
f"Missing provider fixture for API '{api}'. Available providers: "
|
||||
f"{list(available_fixtures[api])}"
|
||||
)
|
||||
return fixtures
|
||||
|
||||
|
||||
def pytest_itemcollected(item):
|
||||
# Get all markers as a list
|
||||
filtered = ("asyncio", "parametrize")
|
||||
marks = [mark.name for mark in item.iter_markers() if mark.name not in filtered]
|
||||
if marks:
|
||||
marks = colored(",".join(marks), "yellow")
|
||||
item.name = f"{item.name}[{marks}]"
|
||||
|
||||
|
||||
pytest_plugins = [
|
||||
"llama_stack.providers.tests.inference.fixtures",
|
||||
"llama_stack.providers.tests.safety.fixtures",
|
||||
"llama_stack.providers.tests.memory.fixtures",
|
||||
"llama_stack.providers.tests.agents.fixtures",
|
||||
]
|
Loading…
Add table
Add a link
Reference in a new issue