llama-stack-mirror/llama_stack/providers/tests/conftest.py
Sixian Yi bf610adf9b Temporary Commit at 1/6/2025, 8:13:09 PM
Summary:
### THIS DIFF

### PLAN

### CONTEXT

### DESIGN

Test Plan:
# Test Setup

**Type checker and check that the build compiles**

**Unit Tests**

**E2E Tests**

// Screenshots and videos
| Before | After |
|--|
| … | … |

# Monitoring Plan.
2025-01-06 20:16:00 -08:00

221 lines
7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
import yaml
import json
import pytest
from dotenv import load_dotenv
from pydantic import BaseModel
from termcolor import colored
from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.datatypes import RemoteProviderConfig
from .env import get_env_or_fail
class ProviderFixture(BaseModel):
providers: List[Provider]
provider_data: Optional[Dict[str, Any]] = None
def remote_stack_fixture() -> ProviderFixture:
if url := os.getenv("REMOTE_STACK_URL", None):
config = RemoteProviderConfig.from_url(url)
else:
config = RemoteProviderConfig(
host=get_env_or_fail("REMOTE_STACK_HOST"),
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
)
return ProviderFixture(
providers=[
Provider(
provider_id="test::remote",
provider_type="test::remote",
config=config.model_dump(),
)
],
)
def pytest_configure(config):
config.option.tbstyle = "short"
config.option.disable_warnings = True
"""Load environment variables at start of test run"""
# Load from .env file if it exists
env_file = Path(__file__).parent / ".env"
if env_file.exists():
load_dotenv(env_file)
# Load any environment variables passed via --env
env_vars = config.getoption("--env") or []
for env_var in env_vars:
key, value = env_var.split("=", 1)
os.environ[key] = value
def pytest_addoption(parser):
parser.addoption(
"--providers",
default="",
help=(
"Provider configuration in format: api1=provider1,api2=provider2. "
"Example: --providers inference=ollama,safety=meta-reference"
),
)
"""Add custom command line options"""
parser.addoption(
"--env", action="append", help="Set environment variables, e.g. --env KEY=value"
)
parser.addoption(
"--config", action="append", help="Set environment variables, e.g. --env KEY=value"
)
def make_provider_id(providers: Dict[str, str]) -> str:
return ":".join(f"{api}={provider}" for api, provider in sorted(providers.items()))
def get_provider_marks(providers: Dict[str, str]) -> List[Any]:
marks = []
for provider in providers.values():
marks.append(getattr(pytest.mark, provider))
return marks
def get_provider_fixture_overrides(
config, available_fixtures: Dict[str, List[str]]
) -> Optional[List[pytest.param]]:
provider_str = config.getoption("--providers")
if not provider_str:
return None
fixture_dict = parse_fixture_string(provider_str, available_fixtures)
return [
pytest.param(
fixture_dict,
id=make_provider_id(fixture_dict),
marks=get_provider_marks(fixture_dict),
)
]
def parse_fixture_string(
provider_str: str, available_fixtures: Dict[str, List[str]]
) -> Dict[str, str]:
"""Parse provider string of format 'api1=provider1,api2=provider2'"""
if not provider_str:
return {}
fixtures = {}
pairs = provider_str.split(",")
for pair in pairs:
if "=" not in pair:
raise ValueError(
f"Invalid provider specification: {pair}. Expected format: api=provider"
)
api, fixture = pair.split("=")
if api not in available_fixtures:
raise ValueError(
f"Unknown API: {api}. Available APIs: {list(available_fixtures.keys())}"
)
if fixture not in available_fixtures[api]:
raise ValueError(
f"Unknown provider '{fixture}' for API '{api}'. "
f"Available providers: {list(available_fixtures[api])}"
)
fixtures[api] = fixture
# Check that all provided APIs are supported
for api in available_fixtures.keys():
if api not in fixtures:
raise ValueError(
f"Missing provider fixture for API '{api}'. Available providers: "
f"{list(available_fixtures[api])}"
)
return fixtures
def pytest_itemcollected(item):
# Get all markers as a list
filtered = ("asyncio", "parametrize")
marks = [mark.name for mark in item.iter_markers() if mark.name not in filtered]
if marks:
marks = colored(",".join(marks), "yellow")
item.name = f"{item.name}[{marks}]"
# def pytest_ignore_collect(path, config):
# # print("SIXIAN ignore", path)
# # if config.getoption("--config") is None:
# # return False
# manifest_path = Path(__file__).parent / "manifest.json"
# files = json.loads(manifest_path.read_text(encoding="utf-8"))["files"]
# file_path = set((Path(__file__).parent / f["path"]) for f in files)
# # print("SIXIAN ", path.__class__.__name__)
# # print("SIXIAN files", file_path)
# if Path(path) not in file_path:
# # print("SIXIAN ignoring", path, file_path)
# return True
# return False
def pytest_collection_modifyitems(config, items):
manifest_path = Path(__file__).parent / "manifest.json"
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
mapping = dict()
for f in manifest["files"]:
mapping[f["basepath"]] = f["functions"]
ret_items = []
deselected = []
print("SIXIAN test_file", items)
for item in items:
test_file = item.fspath.basename
# print("SIXIAN test_file", test_file, "mapping", mapping)
if test_file in mapping.keys():
print("SIXIAN appending", item)
ret_items.append(item)
else:
deselected.append(item)
print("SIXIAN ret_items", ret_items)
config.hook.pytest_deselected(items=[deselected])
# class ManifestDirectory(pytest.Directory):
# def collect(self):
# manifest_path = self.path / "manifest.json"
# manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
# ihook = self.ihook
# for file in manifest["files"]:
# yield from ihook.pytest_collect_file(
# file_path=self.path / file, parent=self
# )
# def pytest_collect_directory(path, parent):
# if config.getoption("--env") is None:
# return None
# if path.joinpath("manifest.json").is_file():
# return ManifestDirectory.from_parent(parent=parent, path=path)
# return None
pytest_plugins = [
"llama_stack.providers.tests.inference.fixtures",
"llama_stack.providers.tests.safety.fixtures",
"llama_stack.providers.tests.memory.fixtures",
"llama_stack.providers.tests.agents.fixtures",
"llama_stack.providers.tests.datasetio.fixtures",
"llama_stack.providers.tests.scoring.fixtures",
"llama_stack.providers.tests.eval.fixtures",
"llama_stack.providers.tests.post_training.fixtures",
]