mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
yet another refactor to make this more general
now it accepts --inference-model, --safety-model options also
This commit is contained in:
parent
2ed0267fbb
commit
60800bc09b
13 changed files with 127 additions and 61 deletions
|
@ -126,8 +126,13 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
if routing_key not in self.registry:
|
||||
apiname, objname = apiname_object()
|
||||
provider_ids = list(self.impls_by_provider_id.keys())
|
||||
if len(provider_ids) > 1:
|
||||
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
|
||||
else:
|
||||
provider_ids_str = f"provider: `{provider_ids[0]}`"
|
||||
raise ValueError(
|
||||
f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}."
|
||||
f"`{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
|
||||
)
|
||||
|
||||
objs = self.registry[routing_key]
|
||||
|
|
|
@ -57,7 +57,34 @@ def pytest_configure(config):
|
|||
)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--inference-model",
|
||||
action="store",
|
||||
default="Llama3.1-8B-Instruct",
|
||||
help="Specify the inference model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--safety-model",
|
||||
action="store",
|
||||
default="Llama-Guard-3-8B",
|
||||
help="Specify the safety model to use for testing",
|
||||
)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
if "inference_model" in metafunc.fixturenames:
|
||||
metafunc.parametrize(
|
||||
"inference_model",
|
||||
[pytest.param(metafunc.config.getoption("--inference-model"), id="")],
|
||||
indirect=True,
|
||||
)
|
||||
if "safety_model" in metafunc.fixturenames:
|
||||
metafunc.parametrize(
|
||||
"safety_model",
|
||||
[pytest.param(metafunc.config.getoption("--safety-model"), id="")],
|
||||
indirect=True,
|
||||
)
|
||||
if "agents_stack" in metafunc.fixturenames:
|
||||
available_fixtures = {
|
||||
"inference": INFERENCE_FIXTURES,
|
||||
|
|
|
@ -22,7 +22,7 @@ from ..conftest import ProviderFixture
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def agents_meta_reference(inference_model, safety_model) -> ProviderFixture:
|
||||
def agents_meta_reference() -> ProviderFixture:
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
|
@ -42,12 +42,12 @@ AGENTS_FIXTURES = ["meta_reference"]
|
|||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def agents_stack(inference_model, safety_model, request):
|
||||
async def agents_stack(request):
|
||||
fixture_dict = request.param
|
||||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
for key in ["agents", "inference", "safety", "memory"]:
|
||||
for key in ["inference", "safety", "memory", "agents"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = [fixture.provider]
|
||||
if fixture.provider_data:
|
||||
|
|
|
@ -75,12 +75,6 @@ async def create_agent_session(agents_impl, agent_config):
|
|||
return agent_id, session_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inference_model", [pytest.param("Llama3.1-8B-Instruct", id="")], indirect=True
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"safety_model", [pytest.param("Llama-Guard-3-8B", id="")], indirect=True
|
||||
)
|
||||
class TestAgents:
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_turns_with_safety(self, agents_stack, common_params):
|
||||
|
|
|
@ -4,9 +4,20 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from .fixtures import INFERENCE_FIXTURES
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--inference-model",
|
||||
action="store",
|
||||
default=None,
|
||||
help="Specify the inference model to use for testing",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line(
|
||||
"markers", "llama_8b: mark test to run only with the given model"
|
||||
|
@ -19,3 +30,33 @@ def pytest_configure(config):
|
|||
"markers",
|
||||
f"{fixture_name}: marks tests as {fixture_name} specific",
|
||||
)
|
||||
|
||||
|
||||
MODEL_PARAMS = [
|
||||
pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
|
||||
pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
|
||||
]
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
if "inference_model" in metafunc.fixturenames:
|
||||
model = metafunc.config.getoption("--inference-model")
|
||||
if model:
|
||||
params = [pytest.param(model, id="")]
|
||||
else:
|
||||
params = MODEL_PARAMS
|
||||
|
||||
metafunc.parametrize(
|
||||
"inference_model",
|
||||
params,
|
||||
indirect=True,
|
||||
)
|
||||
if "inference_stack" in metafunc.fixturenames:
|
||||
metafunc.parametrize(
|
||||
"inference_stack",
|
||||
[
|
||||
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
|
||||
for fixture_name in INFERENCE_FIXTURES
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
|
|
|
@ -22,19 +22,14 @@ from ..conftest import ProviderFixture
|
|||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
MODEL_PARAMS = [
|
||||
pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
|
||||
pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", params=MODEL_PARAMS)
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_model(request):
|
||||
return request.param
|
||||
return request.config.getoption("--inference-model", None) or request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_meta_reference(inference_model) -> ProviderFixture:
|
||||
# TODO: make this work with multiple models
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="meta-reference",
|
||||
|
@ -66,7 +61,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_fireworks(inference_model) -> ProviderFixture:
|
||||
def inference_fireworks() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="fireworks",
|
||||
|
@ -79,7 +74,7 @@ def inference_fireworks(inference_model) -> ProviderFixture:
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_together(inference_model) -> ProviderFixture:
|
||||
def inference_together() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="together",
|
||||
|
@ -95,7 +90,7 @@ def inference_together(inference_model) -> ProviderFixture:
|
|||
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", params=INFERENCE_FIXTURES)
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def inference_stack(request):
|
||||
fixture_name = request.param
|
||||
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
|
||||
|
|
|
@ -14,7 +14,6 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from .fixtures import INFERENCE_FIXTURES, MODEL_PARAMS
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
|
@ -70,15 +69,6 @@ def sample_tool_definition():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("inference_model", MODEL_PARAMS, indirect=True)
|
||||
@pytest.mark.parametrize(
|
||||
"inference_stack",
|
||||
[
|
||||
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
|
||||
for fixture_name in INFERENCE_FIXTURES
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestInference:
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_list(self, inference_model, inference_stack):
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from .fixtures import MEMORY_FIXTURES
|
||||
|
||||
|
||||
|
@ -13,3 +15,15 @@ def pytest_configure(config):
|
|||
"markers",
|
||||
f"{fixture_name}: marks tests as {fixture_name} specific",
|
||||
)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
if "memory_stack" in metafunc.fixturenames:
|
||||
metafunc.parametrize(
|
||||
"memory_stack",
|
||||
[
|
||||
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
|
||||
for fixture_name in MEMORY_FIXTURES
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
|
|
|
@ -64,16 +64,8 @@ def memory_weaviate() -> ProviderFixture:
|
|||
|
||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"]
|
||||
|
||||
PROVIDER_PARAMS = [
|
||||
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
|
||||
for fixture_name in MEMORY_FIXTURES
|
||||
]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(
|
||||
scope="session",
|
||||
params=PROVIDER_PARAMS,
|
||||
)
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def memory_stack(request):
|
||||
fixture_name = request.param
|
||||
fixture = request.getfixturevalue(f"memory_{fixture_name}")
|
||||
|
|
|
@ -8,7 +8,6 @@ import pytest
|
|||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from .fixtures import PROVIDER_PARAMS
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
|
@ -54,11 +53,6 @@ async def register_memory_bank(banks_impl: MemoryBanks):
|
|||
await banks_impl.register_memory_bank(bank)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"memory_stack",
|
||||
PROVIDER_PARAMS,
|
||||
indirect=True,
|
||||
)
|
||||
class TestMemory:
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_list(self, memory_stack):
|
||||
|
|
|
@ -48,11 +48,38 @@ def pytest_configure(config):
|
|||
)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--safety-model",
|
||||
action="store",
|
||||
default=None,
|
||||
help="Specify the safety model to use for testing",
|
||||
)
|
||||
|
||||
|
||||
SAFETY_MODEL_PARAMS = [
|
||||
pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
|
||||
]
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
# We use this method to make sure we have built-in simple combos for safety tests
|
||||
# But a user can also pass in a custom combination via the CLI by doing
|
||||
# `--providers inference=together,safety=meta_reference`
|
||||
|
||||
if "safety_model" in metafunc.fixturenames:
|
||||
model = metafunc.config.getoption("--safety-model")
|
||||
if model:
|
||||
params = [pytest.param(model, id="")]
|
||||
else:
|
||||
params = SAFETY_MODEL_PARAMS
|
||||
for fixture in ["inference_model", "safety_model"]:
|
||||
metafunc.parametrize(
|
||||
fixture,
|
||||
params,
|
||||
indirect=True,
|
||||
)
|
||||
|
||||
if "safety_stack" in metafunc.fixturenames:
|
||||
# print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}")
|
||||
available_fixtures = {
|
||||
|
|
|
@ -20,14 +20,9 @@ from ..conftest import ProviderFixture
|
|||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
SAFETY_MODEL_PARAMS = [
|
||||
pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", params=SAFETY_MODEL_PARAMS)
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_model(request):
|
||||
return request.param
|
||||
return request.config.getoption("--safety-model", None) or request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
|
@ -17,14 +17,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403
|
|||
# -m "ollama"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inference_model", [pytest.param("Llama-Guard-3-1B", id="")], indirect=True
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"safety_model",
|
||||
[pytest.param("Llama-Guard-3-1B", id="guard_3_1b")],
|
||||
indirect=True,
|
||||
)
|
||||
class TestSafety:
|
||||
@pytest.mark.asyncio
|
||||
async def test_shield_list(self, safety_stack):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue