yet another refactor to make this more general

now it accepts --inference-model, --safety-model options also
This commit is contained in:
Ashwin Bharambe 2024-11-04 14:16:35 -08:00 committed by Ashwin Bharambe
parent 2ed0267fbb
commit 60800bc09b
13 changed files with 127 additions and 61 deletions

View file

@ -126,8 +126,13 @@ class CommonRoutingTableImpl(RoutingTable):
if routing_key not in self.registry:
apiname, objname = apiname_object()
provider_ids = list(self.impls_by_provider_id.keys())
if len(provider_ids) > 1:
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
else:
provider_ids_str = f"provider: `{provider_ids[0]}`"
raise ValueError(
f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}."
f"`{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
)
objs = self.registry[routing_key]

View file

@ -57,7 +57,34 @@ def pytest_configure(config):
)
def pytest_addoption(parser):
parser.addoption(
"--inference-model",
action="store",
default="Llama3.1-8B-Instruct",
help="Specify the inference model to use for testing",
)
parser.addoption(
"--safety-model",
action="store",
default="Llama-Guard-3-8B",
help="Specify the safety model to use for testing",
)
def pytest_generate_tests(metafunc):
if "inference_model" in metafunc.fixturenames:
metafunc.parametrize(
"inference_model",
[pytest.param(metafunc.config.getoption("--inference-model"), id="")],
indirect=True,
)
if "safety_model" in metafunc.fixturenames:
metafunc.parametrize(
"safety_model",
[pytest.param(metafunc.config.getoption("--safety-model"), id="")],
indirect=True,
)
if "agents_stack" in metafunc.fixturenames:
available_fixtures = {
"inference": INFERENCE_FIXTURES,

View file

@ -22,7 +22,7 @@ from ..conftest import ProviderFixture
@pytest.fixture(scope="session")
def agents_meta_reference(inference_model, safety_model) -> ProviderFixture:
def agents_meta_reference() -> ProviderFixture:
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
return ProviderFixture(
provider=Provider(
@ -42,12 +42,12 @@ AGENTS_FIXTURES = ["meta_reference"]
@pytest_asyncio.fixture(scope="session")
async def agents_stack(inference_model, safety_model, request):
async def agents_stack(request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["agents", "inference", "safety", "memory"]:
for key in ["inference", "safety", "memory", "agents"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = [fixture.provider]
if fixture.provider_data:

View file

@ -75,12 +75,6 @@ async def create_agent_session(agents_impl, agent_config):
return agent_id, session_id
@pytest.mark.parametrize(
"inference_model", [pytest.param("Llama3.1-8B-Instruct", id="")], indirect=True
)
@pytest.mark.parametrize(
"safety_model", [pytest.param("Llama-Guard-3-8B", id="")], indirect=True
)
class TestAgents:
@pytest.mark.asyncio
async def test_agent_turns_with_safety(self, agents_stack, common_params):

View file

@ -4,9 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from .fixtures import INFERENCE_FIXTURES
def pytest_addoption(parser):
parser.addoption(
"--inference-model",
action="store",
default=None,
help="Specify the inference model to use for testing",
)
def pytest_configure(config):
config.addinivalue_line(
"markers", "llama_8b: mark test to run only with the given model"
@ -19,3 +30,33 @@ def pytest_configure(config):
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
MODEL_PARAMS = [
pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
]
def pytest_generate_tests(metafunc):
if "inference_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--inference-model")
if model:
params = [pytest.param(model, id="")]
else:
params = MODEL_PARAMS
metafunc.parametrize(
"inference_model",
params,
indirect=True,
)
if "inference_stack" in metafunc.fixturenames:
metafunc.parametrize(
"inference_stack",
[
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
for fixture_name in INFERENCE_FIXTURES
],
indirect=True,
)

View file

@ -22,19 +22,14 @@ from ..conftest import ProviderFixture
from ..env import get_env_or_fail
MODEL_PARAMS = [
pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
]
@pytest.fixture(scope="session", params=MODEL_PARAMS)
@pytest.fixture(scope="session")
def inference_model(request):
return request.param
return request.config.getoption("--inference-model", None) or request.param
@pytest.fixture(scope="session")
def inference_meta_reference(inference_model) -> ProviderFixture:
# TODO: make this work with multiple models
return ProviderFixture(
provider=Provider(
provider_id="meta-reference",
@ -66,7 +61,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
@pytest.fixture(scope="session")
def inference_fireworks(inference_model) -> ProviderFixture:
def inference_fireworks() -> ProviderFixture:
return ProviderFixture(
provider=Provider(
provider_id="fireworks",
@ -79,7 +74,7 @@ def inference_fireworks(inference_model) -> ProviderFixture:
@pytest.fixture(scope="session")
def inference_together(inference_model) -> ProviderFixture:
def inference_together() -> ProviderFixture:
return ProviderFixture(
provider=Provider(
provider_id="together",
@ -95,7 +90,7 @@ def inference_together(inference_model) -> ProviderFixture:
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"]
@pytest_asyncio.fixture(scope="session", params=INFERENCE_FIXTURES)
@pytest_asyncio.fixture(scope="session")
async def inference_stack(request):
fixture_name = request.param
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")

View file

@ -14,7 +14,6 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from .fixtures import INFERENCE_FIXTURES, MODEL_PARAMS
# How to run this test:
#
@ -70,15 +69,6 @@ def sample_tool_definition():
)
@pytest.mark.parametrize("inference_model", MODEL_PARAMS, indirect=True)
@pytest.mark.parametrize(
"inference_stack",
[
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
for fixture_name in INFERENCE_FIXTURES
],
indirect=True,
)
class TestInference:
@pytest.mark.asyncio
async def test_model_list(self, inference_model, inference_stack):

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from .fixtures import MEMORY_FIXTURES
@ -13,3 +15,15 @@ def pytest_configure(config):
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_generate_tests(metafunc):
if "memory_stack" in metafunc.fixturenames:
metafunc.parametrize(
"memory_stack",
[
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
for fixture_name in MEMORY_FIXTURES
],
indirect=True,
)

View file

@ -64,16 +64,8 @@ def memory_weaviate() -> ProviderFixture:
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"]
PROVIDER_PARAMS = [
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
for fixture_name in MEMORY_FIXTURES
]
@pytest_asyncio.fixture(
scope="session",
params=PROVIDER_PARAMS,
)
@pytest_asyncio.fixture(scope="session")
async def memory_stack(request):
fixture_name = request.param
fixture = request.getfixturevalue(f"memory_{fixture_name}")

View file

@ -8,7 +8,6 @@ import pytest
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from .fixtures import PROVIDER_PARAMS
# How to run this test:
#
@ -54,11 +53,6 @@ async def register_memory_bank(banks_impl: MemoryBanks):
await banks_impl.register_memory_bank(bank)
@pytest.mark.parametrize(
"memory_stack",
PROVIDER_PARAMS,
indirect=True,
)
class TestMemory:
@pytest.mark.asyncio
async def test_banks_list(self, memory_stack):

View file

@ -48,11 +48,38 @@ def pytest_configure(config):
)
def pytest_addoption(parser):
parser.addoption(
"--safety-model",
action="store",
default=None,
help="Specify the safety model to use for testing",
)
SAFETY_MODEL_PARAMS = [
pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
]
def pytest_generate_tests(metafunc):
# We use this method to make sure we have built-in simple combos for safety tests
# But a user can also pass in a custom combination via the CLI by doing
# `--providers inference=together,safety=meta_reference`
if "safety_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--safety-model")
if model:
params = [pytest.param(model, id="")]
else:
params = SAFETY_MODEL_PARAMS
for fixture in ["inference_model", "safety_model"]:
metafunc.parametrize(
fixture,
params,
indirect=True,
)
if "safety_stack" in metafunc.fixturenames:
# print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}")
available_fixtures = {

View file

@ -20,14 +20,9 @@ from ..conftest import ProviderFixture
from ..env import get_env_or_fail
SAFETY_MODEL_PARAMS = [
pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
]
@pytest.fixture(scope="session", params=SAFETY_MODEL_PARAMS)
@pytest.fixture(scope="session")
def safety_model(request):
return request.param
return request.config.getoption("--safety-model", None) or request.param
@pytest.fixture(scope="session")

View file

@ -17,14 +17,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403
# -m "ollama"
@pytest.mark.parametrize(
"inference_model", [pytest.param("Llama-Guard-3-1B", id="")], indirect=True
)
@pytest.mark.parametrize(
"safety_model",
[pytest.param("Llama-Guard-3-1B", id="guard_3_1b")],
indirect=True,
)
class TestSafety:
@pytest.mark.asyncio
async def test_shield_list(self, safety_stack):