mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 12:50:00 +00:00
restructure config
This commit is contained in:
parent
702cf2d563
commit
26d9804efd
8 changed files with 218 additions and 62 deletions
|
|
@ -7,7 +7,7 @@
|
|||
import pytest
|
||||
|
||||
from ..conftest import get_provider_fixture_overrides
|
||||
|
||||
from ..test_config_helper import try_load_config_file_cached
|
||||
from .fixtures import INFERENCE_FIXTURES
|
||||
|
||||
|
||||
|
|
@ -43,29 +43,43 @@ VISION_MODEL_PARAMS = [
|
|||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
|
||||
if "inference_model" in metafunc.fixturenames:
|
||||
model = metafunc.config.getoption("--inference-model")
|
||||
if model:
|
||||
params = [pytest.param(model, id="")]
|
||||
cls_name = metafunc.cls.__name__
|
||||
if test_config is not None:
|
||||
params = []
|
||||
for model in test_config.inference.fixtures.inference_models:
|
||||
if ("Vision" in cls_name and "Vision" in model) or (
|
||||
"Vision" not in cls_name and "Vision" not in model
|
||||
):
|
||||
params.append(pytest.param(model, id=model))
|
||||
else:
|
||||
cls_name = metafunc.cls.__name__
|
||||
if "Vision" in cls_name:
|
||||
params = VISION_MODEL_PARAMS
|
||||
model = metafunc.config.getoption("--inference-model")
|
||||
if model:
|
||||
params = [pytest.param(model, id="")]
|
||||
else:
|
||||
params = MODEL_PARAMS
|
||||
|
||||
if "Vision" in cls_name:
|
||||
params = VISION_MODEL_PARAMS
|
||||
else:
|
||||
params = MODEL_PARAMS
|
||||
metafunc.parametrize(
|
||||
"inference_model",
|
||||
params,
|
||||
indirect=True,
|
||||
)
|
||||
if "inference_stack" in metafunc.fixturenames:
|
||||
fixtures = INFERENCE_FIXTURES
|
||||
if filtered_stacks := get_provider_fixture_overrides(
|
||||
if test_config is not None:
|
||||
fixtures = [
|
||||
(f.get("inference") or f.get("default_fixture_param_id"))
|
||||
for f in test_config.inference.fixtures.provider_fixtures
|
||||
]
|
||||
elif filtered_stacks := get_provider_fixture_overrides(
|
||||
metafunc.config,
|
||||
{
|
||||
"inference": INFERENCE_FIXTURES,
|
||||
},
|
||||
):
|
||||
fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
|
||||
else:
|
||||
fixtures = INFERENCE_FIXTURES
|
||||
metafunc.parametrize("inference_stack", fixtures, indirect=True)
|
||||
|
|
|
|||
|
|
@ -301,6 +301,7 @@ async def inference_stack(request, inference_model):
|
|||
inference_fixture.provider_data,
|
||||
models=[
|
||||
ModelInput(
|
||||
provider_id=inference_fixture.providers[0].provider_id,
|
||||
model_id=inference_model,
|
||||
model_type=model_type,
|
||||
metadata=metadata,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue