Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30 23:51:00 +00:00)
Yet another refactor to make this more general; it now also accepts --inference-model and --safety-model options.
parent 2ed0267fbb
commit 60800bc09b
13 changed files with 127 additions and 61 deletions
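As a rough usage sketch of what the new options enable (the test path below is an assumption for illustration; the model names are the defaults this commit adds to the agents conftest), the options can be passed on the pytest command line or programmatically:

# Hedged sketch: drive pytest programmatically with the new options.
# The path is assumed, not taken from this commit.
import pytest

if __name__ == "__main__":
    pytest.main(
        [
            "llama_stack/providers/tests/agents/",  # assumed location of the agents tests
            "--inference-model", "Llama3.1-8B-Instruct",
            "--safety-model", "Llama-Guard-3-8B",
            "-v",
        ]
    )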
@@ -126,8 +126,13 @@ class CommonRoutingTableImpl(RoutingTable):
         if routing_key not in self.registry:
             apiname, objname = apiname_object()
+            provider_ids = list(self.impls_by_provider_id.keys())
+            if len(provider_ids) > 1:
+                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
+            else:
+                provider_ids_str = f"provider: `{provider_ids[0]}`"
             raise ValueError(
-                f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}."
+                f"`{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
             )

         objs = self.registry[routing_key]
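To make the improved error path easier to read in isolation, here is a minimal standalone sketch of the new message construction; the helper name is invented for illustration, the routing table builds the string inline as shown in the hunk above.

# Standalone sketch of the new "not served by" message (helper name is illustrative).
def not_served_message(routing_key, provider_ids, apiname, objname):
    if len(provider_ids) > 1:
        provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
    else:
        provider_ids_str = f"provider: `{provider_ids[0]}`"
    return (
        f"`{routing_key}` not served by {provider_ids_str}. "
        f"Make sure there is an {apiname} provider serving this {objname}."
    )

# A single registered provider yields "... not served by provider: `ollama` ...",
# while several yield "... not served by any of the providers: ollama, fireworks ...".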
@@ -57,7 +57,34 @@ def pytest_configure(config):
         )


+def pytest_addoption(parser):
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default="Llama3.1-8B-Instruct",
+        help="Specify the inference model to use for testing",
+    )
+    parser.addoption(
+        "--safety-model",
+        action="store",
+        default="Llama-Guard-3-8B",
+        help="Specify the safety model to use for testing",
+    )
+
+
 def pytest_generate_tests(metafunc):
+    if "inference_model" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "inference_model",
+            [pytest.param(metafunc.config.getoption("--inference-model"), id="")],
+            indirect=True,
+        )
+    if "safety_model" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "safety_model",
+            [pytest.param(metafunc.config.getoption("--safety-model"), id="")],
+            indirect=True,
+        )
     if "agents_stack" in metafunc.fixturenames:
         available_fixtures = {
             "inference": INFERENCE_FIXTURES,
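This hunk (apparently the agents tests' conftest) leans on pytest's indirect parametrization: pytest_generate_tests pushes the CLI value into a session-scoped fixture via request.param. A self-contained sketch of just that mechanism, with illustrative names; the conftest hooks and the test would live in separate files in a real suite.

# Minimal sketch of the indirect parametrization pattern used throughout this commit.
import pytest


# (conftest.py)
def pytest_addoption(parser):
    parser.addoption("--inference-model", action="store", default="Llama3.1-8B-Instruct")


def pytest_generate_tests(metafunc):
    if "inference_model" in metafunc.fixturenames:
        metafunc.parametrize(
            "inference_model",
            [pytest.param(metafunc.config.getoption("--inference-model"), id="")],
            indirect=True,  # deliver the value through the fixture, not the test argument
        )


@pytest.fixture(scope="session")
def inference_model(request):
    # With indirect=True, the parametrized value arrives here as request.param.
    return request.param


# (a test module)
def test_model_name_is_a_string(inference_model):
    assert isinstance(inference_model, str)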
@@ -22,7 +22,7 @@ from ..conftest import ProviderFixture


 @pytest.fixture(scope="session")
-def agents_meta_reference(inference_model, safety_model) -> ProviderFixture:
+def agents_meta_reference() -> ProviderFixture:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     return ProviderFixture(
         provider=Provider(
@@ -42,12 +42,12 @@ AGENTS_FIXTURES = ["meta_reference"]


 @pytest_asyncio.fixture(scope="session")
-async def agents_stack(inference_model, safety_model, request):
+async def agents_stack(request):
     fixture_dict = request.param

     providers = {}
     provider_data = {}
-    for key in ["agents", "inference", "safety", "memory"]:
+    for key in ["inference", "safety", "memory", "agents"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = [fixture.provider]
         if fixture.provider_data:
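agents_stack now takes only request; the inference and safety models come in through the separately parametrized fixtures above. For orientation, request.param here is a dict selecting one fixture per API. The shape below is a hedged illustration: the keys come from the hunk, the values are assumed to be the meta_reference defaults.

# Illustrative shape of fixture_dict, the value delivered via request.param (values assumed).
fixture_dict = {
    "inference": "meta_reference",
    "safety": "meta_reference",
    "memory": "meta_reference",
    "agents": "meta_reference",
}
# The loop above then resolves request.getfixturevalue("inference_meta_reference"),
# "safety_meta_reference", and so on, collecting each fixture.provider per API.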
@@ -75,12 +75,6 @@ async def create_agent_session(agents_impl, agent_config):
     return agent_id, session_id


-@pytest.mark.parametrize(
-    "inference_model", [pytest.param("Llama3.1-8B-Instruct", id="")], indirect=True
-)
-@pytest.mark.parametrize(
-    "safety_model", [pytest.param("Llama-Guard-3-8B", id="")], indirect=True
-)
 class TestAgents:
     @pytest.mark.asyncio
     async def test_agent_turns_with_safety(self, agents_stack, common_params):
@@ -4,9 +4,20 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import pytest
+
 from .fixtures import INFERENCE_FIXTURES


+def pytest_addoption(parser):
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default=None,
+        help="Specify the inference model to use for testing",
+    )
+
+
 def pytest_configure(config):
     config.addinivalue_line(
         "markers", "llama_8b: mark test to run only with the given model"
@@ -19,3 +30,33 @@ def pytest_configure(config):
             "markers",
             f"{fixture_name}: marks tests as {fixture_name} specific",
         )
+
+
+MODEL_PARAMS = [
+    pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
+    pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
+]
+
+
+def pytest_generate_tests(metafunc):
+    if "inference_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--inference-model")
+        if model:
+            params = [pytest.param(model, id="")]
+        else:
+            params = MODEL_PARAMS
+
+        metafunc.parametrize(
+            "inference_model",
+            params,
+            indirect=True,
+        )
+    if "inference_stack" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "inference_stack",
+            [
+                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
+                for fixture_name in INFERENCE_FIXTURES
+            ],
+            indirect=True,
+        )
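In this conftest (apparently the inference tests'), the --inference-model default is None, so an explicit CLI model replaces the built-in marked defaults; otherwise both defaults are parametrized and one can be selected with -m llama_8b or -m llama_3b. A standalone sketch of just that decision, with an invented function name:

# Sketch of the override-or-defaults choice made in pytest_generate_tests above.
import pytest

DEFAULT_MODEL_PARAMS = [
    pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
    pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
]


def resolve_inference_params(cli_model):
    # A CLI-supplied model becomes a single unmarked param; otherwise the marked
    # defaults are used, which is what makes `-m llama_8b` / `-m llama_3b` filtering work.
    if cli_model:
        return [pytest.param(cli_model, id="")]
    return DEFAULT_MODEL_PARAMS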
@@ -22,19 +22,14 @@ from ..conftest import ProviderFixture
 from ..env import get_env_or_fail


-MODEL_PARAMS = [
-    pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
-    pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
-]
-
-
-@pytest.fixture(scope="session", params=MODEL_PARAMS)
+@pytest.fixture(scope="session")
 def inference_model(request):
-    return request.param
+    return request.config.getoption("--inference-model", None) or request.param


 @pytest.fixture(scope="session")
 def inference_meta_reference(inference_model) -> ProviderFixture:
+    # TODO: make this work with multiple models
     return ProviderFixture(
         provider=Provider(
             provider_id="meta-reference",
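The reworked inference_model fixture gives the CLI flag precedence and only falls back to the parametrized value: getoption is called with an explicit None default so a suite that never registers --inference-model does not raise, and `or request.param` then picks up whatever pytest_generate_tests parametrized. A commented restatement of that one line as a standalone sketch:

# Precedence sketch for the reworked fixture: CLI option first, parametrized value second.
import pytest


@pytest.fixture(scope="session")
def inference_model(request):
    # getoption("--inference-model", None) returns None instead of raising when the
    # option was never registered; `or request.param` then uses the parametrized value.
    return request.config.getoption("--inference-model", None) or request.param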
@@ -66,7 +61,7 @@ def inference_ollama(inference_model) -> ProviderFixture:


 @pytest.fixture(scope="session")
-def inference_fireworks(inference_model) -> ProviderFixture:
+def inference_fireworks() -> ProviderFixture:
     return ProviderFixture(
         provider=Provider(
             provider_id="fireworks",
@@ -79,7 +74,7 @@ def inference_fireworks(inference_model) -> ProviderFixture:


 @pytest.fixture(scope="session")
-def inference_together(inference_model) -> ProviderFixture:
+def inference_together() -> ProviderFixture:
     return ProviderFixture(
         provider=Provider(
             provider_id="together",
|
@ -95,7 +90,7 @@ def inference_together(inference_model) -> ProviderFixture:
|
||||||
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"]
|
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session", params=INFERENCE_FIXTURES)
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def inference_stack(request):
|
async def inference_stack(request):
|
||||||
fixture_name = request.param
|
fixture_name = request.param
|
||||||
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
|
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
|
||||||
|
|
|
@@ -14,7 +14,6 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.apis.inference import * # noqa: F403

 from llama_stack.distribution.datatypes import * # noqa: F403
-from .fixtures import INFERENCE_FIXTURES, MODEL_PARAMS

 # How to run this test:
 #
@@ -70,15 +69,6 @@ def sample_tool_definition():
     )


-@pytest.mark.parametrize("inference_model", MODEL_PARAMS, indirect=True)
-@pytest.mark.parametrize(
-    "inference_stack",
-    [
-        pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-        for fixture_name in INFERENCE_FIXTURES
-    ],
-    indirect=True,
-)
 class TestInference:
     @pytest.mark.asyncio
     async def test_model_list(self, inference_model, inference_stack):
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import pytest
+
 from .fixtures import MEMORY_FIXTURES

@@ -13,3 +15,15 @@ def pytest_configure(config):
             "markers",
             f"{fixture_name}: marks tests as {fixture_name} specific",
         )
+
+
+def pytest_generate_tests(metafunc):
+    if "memory_stack" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "memory_stack",
+            [
+                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
+                for fixture_name in MEMORY_FIXTURES
+            ],
+            indirect=True,
+        )
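Because each memory_stack param carries a provider marker registered in pytest_configure, a single backend can be selected with -m. A hedged invocation sketch; the test path is an assumption:

# Run only the pgvector-backed memory tests (path assumed for illustration).
import pytest

if __name__ == "__main__":
    pytest.main(
        [
            "llama_stack/providers/tests/memory/",  # assumed location of the memory tests
            "-m", "pgvector",
            "-v",
        ]
    )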
@@ -64,16 +64,8 @@ def memory_weaviate() -> ProviderFixture:
 MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"]


-PROVIDER_PARAMS = [
-    pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-    for fixture_name in MEMORY_FIXTURES
-]
-
-
-@pytest_asyncio.fixture(
-    scope="session",
-    params=PROVIDER_PARAMS,
-)
+@pytest_asyncio.fixture(scope="session")
 async def memory_stack(request):
     fixture_name = request.param
     fixture = request.getfixturevalue(f"memory_{fixture_name}")
@@ -8,7 +8,6 @@ import pytest

 from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.distribution.datatypes import * # noqa: F403
-from .fixtures import PROVIDER_PARAMS

 # How to run this test:
 #
@@ -54,11 +53,6 @@ async def register_memory_bank(banks_impl: MemoryBanks):
     await banks_impl.register_memory_bank(bank)


-@pytest.mark.parametrize(
-    "memory_stack",
-    PROVIDER_PARAMS,
-    indirect=True,
-)
 class TestMemory:
     @pytest.mark.asyncio
     async def test_banks_list(self, memory_stack):
@@ -48,11 +48,38 @@ def pytest_configure(config):
         )


+def pytest_addoption(parser):
+    parser.addoption(
+        "--safety-model",
+        action="store",
+        default=None,
+        help="Specify the safety model to use for testing",
+    )
+
+
+SAFETY_MODEL_PARAMS = [
+    pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
+]
+
+
 def pytest_generate_tests(metafunc):
     # We use this method to make sure we have built-in simple combos for safety tests
     # But a user can also pass in a custom combination via the CLI by doing
     # `--providers inference=together,safety=meta_reference`

+    if "safety_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--safety-model")
+        if model:
+            params = [pytest.param(model, id="")]
+        else:
+            params = SAFETY_MODEL_PARAMS
+        for fixture in ["inference_model", "safety_model"]:
+            metafunc.parametrize(
+                fixture,
+                params,
+                indirect=True,
+            )
+
     if "safety_stack" in metafunc.fixturenames:
         # print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}")
         available_fixtures = {
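Note the loop above parametrizes both inference_model and safety_model with the same params, so for safety tests the guard model also serves as the inference model; this matches the decorators removed from the safety test module below, which pinned both fixtures to Llama Guard. A condensed, self-contained sketch of that wiring, where the fixture bodies are illustrative and the option and default values come from this hunk:

# Condensed sketch: the safety conftest applies one set of params to both fixtures.
import pytest

SAFETY_MODEL_PARAMS = [
    pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
]


def pytest_addoption(parser):
    parser.addoption("--safety-model", action="store", default=None)


def pytest_generate_tests(metafunc):
    if "safety_model" in metafunc.fixturenames:
        model = metafunc.config.getoption("--safety-model")
        params = [pytest.param(model, id="")] if model else SAFETY_MODEL_PARAMS
        # The same model parametrizes both fixtures: the guard model doubles
        # as the inference model for these tests.
        for fixture in ["inference_model", "safety_model"]:
            metafunc.parametrize(fixture, params, indirect=True)


@pytest.fixture(scope="session")
def inference_model(request):
    return request.param


@pytest.fixture(scope="session")
def safety_model(request):
    return request.param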
@@ -20,14 +20,9 @@ from ..conftest import ProviderFixture
 from ..env import get_env_or_fail


-SAFETY_MODEL_PARAMS = [
-    pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
-]
-
-
-@pytest.fixture(scope="session", params=SAFETY_MODEL_PARAMS)
+@pytest.fixture(scope="session")
 def safety_model(request):
-    return request.param
+    return request.config.getoption("--safety-model", None) or request.param


 @pytest.fixture(scope="session")
@@ -17,14 +17,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403
 # -m "ollama"


-@pytest.mark.parametrize(
-    "inference_model", [pytest.param("Llama-Guard-3-1B", id="")], indirect=True
-)
-@pytest.mark.parametrize(
-    "safety_model",
-    [pytest.param("Llama-Guard-3-1B", id="guard_3_1b")],
-    indirect=True,
-)
 class TestSafety:
     @pytest.mark.asyncio
     async def test_shield_list(self, safety_stack):