mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 20:32:38 +00:00
Summary: Part of * https://github.com/meta-llama/llama-stack/issues/436 This change adds a minimalistic support to creating memory provider test fake. For more details about the approach, follow the issue link from above. Test Plan: Run tests using the "test_fake" mark: ``` pytest llama_stack/providers/tests/memory/test_memory.py -m "test_fake" /opt/homebrew/Caskroom/miniconda/base/envs/llama-stack/lib/python3.11/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ====================================================================================================== test session starts ====================================================================================================== platform darwin -- Python 3.11.10, pytest-8.3.3, pluggy-1.5.0 rootdir: /llama-stack configfile: pyproject.toml plugins: asyncio-0.24.0, anyio-4.6.2.post1 asyncio: mode=Mode.STRICT, default_loop_scope=None collected 18 items / 15 deselected / 3 selected llama_stack/providers/tests/memory/test_memory.py ... [100%] ========================================================================================= 3 passed, 15 deselected, 6 warnings in 0.03s ========================================================================================== ```
173 lines
5.3 KiB
Python
173 lines
5.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import pytest
|
|
from dotenv import load_dotenv
|
|
from pydantic import BaseModel
|
|
from termcolor import colored
|
|
|
|
from llama_stack.distribution.datatypes import Provider
|
|
from llama_stack.providers.datatypes import RemoteProviderConfig, TestFakeProviderConfig
|
|
|
|
from .env import get_env_or_fail
|
|
|
|
|
|
class ProviderFixture(BaseModel):
    """Bundle describing the provider(s) a test stack should be built from."""

    # Provider entries to register for the API under test.
    providers: List[Provider]
    # Optional extra provider data forwarded alongside the providers; None when unused.
    provider_data: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
def test_fake_stack_fixture(config: TestFakeProviderConfig) -> ProviderFixture:
    """Build the generic test-fake fixture.

    The given TestFakeProviderConfig selects the test fakes that get mapped
    to their corresponding APIs.
    """
    fake_provider = Provider(
        provider_id="test::test-fake",
        provider_type="test::test-fake",
        config=config.model_dump(),
    )
    return ProviderFixture(providers=[fake_provider])
|
|
|
|
|
|
def remote_stack_fixture() -> ProviderFixture:
    """Build a fixture pointing at a remote stack, configured via environment.

    Prefers REMOTE_STACK_URL when set; otherwise requires both
    REMOTE_STACK_HOST and REMOTE_STACK_PORT.
    """
    url = os.getenv("REMOTE_STACK_URL", None)
    if url:
        remote_config = RemoteProviderConfig.from_url(url)
    else:
        remote_config = RemoteProviderConfig(
            host=get_env_or_fail("REMOTE_STACK_HOST"),
            port=int(get_env_or_fail("REMOTE_STACK_PORT")),
        )

    remote_provider = Provider(
        provider_id="test::remote",
        provider_type="test::remote",
        config=remote_config.model_dump(),
    )
    return ProviderFixture(providers=[remote_provider])
|
|
|
|
|
|
def pytest_configure(config):
    """Configure test output style and load environment variables at start of run.

    Environment variables come from an optional `.env` file next to this
    conftest and from any `--env KEY=value` options on the command line
    (command-line values win, since they are applied last).
    """
    # NOTE: the original placed the docstring mid-function, where it was a
    # no-op string expression rather than a docstring; it now sits at the top.
    config.option.tbstyle = "short"
    config.option.disable_warnings = True

    # Load from .env file if it exists
    env_file = Path(__file__).parent / ".env"
    if env_file.exists():
        load_dotenv(env_file)

    # Load any environment variables passed via --env
    env_vars = config.getoption("--env") or []
    for env_var in env_vars:
        # Split on the first '=' only so values may themselves contain '='.
        key, value = env_var.split("=", 1)
        os.environ[key] = value
|
|
|
|
|
|
def pytest_addoption(parser):
    """Add custom command line options.

    Registers:
      --providers  provider configuration as api1=provider1,api2=provider2
      --env        repeatable KEY=value environment-variable overrides
    """
    # NOTE: the original placed the docstring between the two addoption
    # calls, where it was a no-op string expression; it now sits at the top.
    parser.addoption(
        "--providers",
        default="",
        help=(
            "Provider configuration in format: api1=provider1,api2=provider2. "
            "Example: --providers inference=ollama,safety=meta-reference"
        ),
    )
    parser.addoption(
        "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
    )
|
|
|
|
|
|
def make_provider_id(providers: Dict[str, str]) -> str:
    """Build a stable id like 'api1=p1:api2=p2', sorted by API name."""
    parts = [f"{api}={provider}" for api, provider in sorted(providers.items())]
    return ":".join(parts)
|
|
|
|
|
|
def get_provider_marks(providers: Dict[str, str]) -> List[Any]:
    """Return one dynamically-created pytest mark per provider name."""
    return [getattr(pytest.mark, name) for name in providers.values()]
|
|
|
|
|
|
def get_provider_fixture_overrides(
    config, available_fixtures: Dict[str, List[str]]
) -> Optional[List[pytest.param]]:
    """Turn the --providers CLI option into a single parametrize override.

    Returns None when --providers was not given, so callers fall back to
    their default fixture combinations.
    """
    provider_str = config.getoption("--providers")
    if not provider_str:
        return None

    fixture_dict = parse_fixture_string(provider_str, available_fixtures)
    override = pytest.param(
        fixture_dict,
        id=make_provider_id(fixture_dict),
        marks=get_provider_marks(fixture_dict),
    )
    return [override]
|
|
|
|
|
|
def parse_fixture_string(
    provider_str: str, available_fixtures: Dict[str, List[str]]
) -> Dict[str, str]:
    """Parse provider string of format 'api1=provider1,api2=provider2'.

    Args:
        provider_str: comma-separated api=provider pairs; may be empty.
        available_fixtures: map of API name -> list of known provider fixtures.

    Returns:
        Mapping of API name to its selected provider fixture. An empty
        provider_str yields an empty dict (the missing-API check is skipped).

    Raises:
        ValueError: on a malformed pair, an unknown API or provider, or a
            missing entry for any API listed in available_fixtures.
    """
    if not provider_str:
        return {}

    fixtures = {}
    pairs = provider_str.split(",")
    for pair in pairs:
        if "=" not in pair:
            raise ValueError(
                f"Invalid provider specification: {pair}. Expected format: api=provider"
            )
        # Split on the first '=' only; a provider value containing '=' then
        # produces a clear "unknown provider" error instead of an opaque
        # tuple-unpacking ValueError.
        api, fixture = pair.split("=", 1)
        if api not in available_fixtures:
            raise ValueError(
                f"Unknown API: {api}. Available APIs: {list(available_fixtures.keys())}"
            )
        if fixture not in available_fixtures[api]:
            raise ValueError(
                f"Unknown provider '{fixture}' for API '{api}'. "
                f"Available providers: {list(available_fixtures[api])}"
            )
        fixtures[api] = fixture

    # Check that all provided APIs are supported
    for api in available_fixtures.keys():
        if api not in fixtures:
            raise ValueError(
                f"Missing provider fixture for API '{api}'. Available providers: "
                f"{list(available_fixtures[api])}"
            )
    return fixtures
|
|
|
|
|
|
def pytest_itemcollected(item):
    """Append the item's non-housekeeping marks, colored yellow, to its display name."""
    skip = ("asyncio", "parametrize")
    mark_names = [m.name for m in item.iter_markers() if m.name not in skip]
    if mark_names:
        item.name = f"{item.name}[{colored(','.join(mark_names), 'yellow')}]"
|
|
|
|
|
|
# Register the per-API fixture modules as pytest plugins so their fixtures
# are available to tests collected under this conftest.
pytest_plugins = [
    "llama_stack.providers.tests.inference.fixtures",
    "llama_stack.providers.tests.safety.fixtures",
    "llama_stack.providers.tests.memory.fixtures",
    "llama_stack.providers.tests.agents.fixtures",
    "llama_stack.providers.tests.datasetio.fixtures",
    "llama_stack.providers.tests.scoring.fixtures",
    "llama_stack.providers.tests.eval.fixtures",
]
|