add support for remote providers in tests

This commit is contained in:
Ashwin Bharambe 2024-11-04 19:57:40 -08:00
parent 0763a0b85f
commit 7cf4c905f3
11 changed files with 79 additions and 15 deletions

View file

@ -37,11 +37,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="together",
marks=pytest.mark.together,
),
pytest.param(
{
"inference": "remote",
"safety": "remote",
},
id="remote",
marks=pytest.mark.remote,
),
]
def pytest_configure(config):
for mark in ["meta_reference", "ollama", "together"]:
for mark in ["meta_reference", "ollama", "together", "remote"]:
config.addinivalue_line(
"markers",
f"{mark}: marks tests as {mark} specific",

View file

@ -16,10 +16,15 @@ from llama_stack.providers.impls.meta_reference.safety import (
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def safety_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def safety_model(request):
if hasattr(request, "param"):
@ -60,7 +65,7 @@ def safety_together() -> ProviderFixture:
)
SAFETY_FIXTURES = ["meta_reference", "together"]
SAFETY_FIXTURES = ["meta_reference", "together", "remote"]
@pytest_asyncio.fixture(scope="session")

View file

@ -27,7 +27,7 @@ class TestSafety:
for shield in response:
assert isinstance(shield, ShieldDefWithProvider)
assert shield.type in [v.value for v in ShieldType]
assert shield.shield_type in [v.value for v in ShieldType]
@pytest.mark.asyncio
async def test_run_shield(self, safety_stack):