restructure config

This commit is contained in:
Sixian Yi 2025-01-14 16:08:20 -08:00
parent 702cf2d563
commit 26d9804efd
8 changed files with 218 additions and 62 deletions

View file

@ -7,7 +7,7 @@
import pytest
from ..conftest import get_provider_fixture_overrides
from ..test_config_helper import try_load_config_file_cached
from .fixtures import INFERENCE_FIXTURES
@ -43,29 +43,43 @@ VISION_MODEL_PARAMS = [
def pytest_generate_tests(metafunc):
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
if "inference_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--inference-model")
if model:
params = [pytest.param(model, id="")]
cls_name = metafunc.cls.__name__
if test_config is not None:
params = []
for model in test_config.inference.fixtures.inference_models:
if ("Vision" in cls_name and "Vision" in model) or (
"Vision" not in cls_name and "Vision" not in model
):
params.append(pytest.param(model, id=model))
else:
cls_name = metafunc.cls.__name__
if "Vision" in cls_name:
params = VISION_MODEL_PARAMS
model = metafunc.config.getoption("--inference-model")
if model:
params = [pytest.param(model, id="")]
else:
params = MODEL_PARAMS
if "Vision" in cls_name:
params = VISION_MODEL_PARAMS
else:
params = MODEL_PARAMS
metafunc.parametrize(
"inference_model",
params,
indirect=True,
)
if "inference_stack" in metafunc.fixturenames:
fixtures = INFERENCE_FIXTURES
if filtered_stacks := get_provider_fixture_overrides(
if test_config is not None:
fixtures = [
(f.get("inference") or f.get("default_fixture_param_id"))
for f in test_config.inference.fixtures.provider_fixtures
]
elif filtered_stacks := get_provider_fixture_overrides(
metafunc.config,
{
"inference": INFERENCE_FIXTURES,
},
):
fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
else:
fixtures = INFERENCE_FIXTURES
metafunc.parametrize("inference_stack", fixtures, indirect=True)

View file

@ -301,6 +301,7 @@ async def inference_stack(request, inference_model):
inference_fixture.provider_data,
models=[
ModelInput(
provider_id=inference_fixture.providers[0].provider_id,
model_id=inference_model,
model_type=model_type,
metadata=metadata,