From 11b2fdd3d8bb3b20eb67db348e473286356d0f87 Mon Sep 17 00:00:00 2001 From: Sixian Yi Date: Wed, 15 Jan 2025 13:45:55 -0800 Subject: [PATCH] support non value for test config --- llama_stack/providers/tests/agents/conftest.py | 2 +- llama_stack/providers/tests/conftest.py | 2 ++ llama_stack/providers/tests/memory/conftest.py | 10 +++++++--- llama_stack/providers/tests/test_config_helper.py | 10 +++++----- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index 347c317dc..dcb6e0f3a 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -92,7 +92,7 @@ def pytest_generate_tests(metafunc): config_override_safety_shield, custom_provider_fixtures, ) = (None, None, None) - if test_config is not None: + if test_config is not None and test_config.agent is not None: config_override_inference_models = test_config.agent.fixtures.inference_models config_override_safety_shield = test_config.agent.fixtures.safety_shield custom_provider_fixtures = get_provider_fixtures_from_config( diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 14c34d14e..c75dc67a5 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -189,6 +189,8 @@ def pytest_collection_modifyitems(session, config, items): required_tests = defaultdict(set) test_configs = [test_config.inference, test_config.memory, test_config.agent] for test_config in test_configs: + if test_config is None: + continue for test in test_config.tests: arr = test.split("::") if len(arr) != 2: diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py index fef09319a..72f9bfc55 100644 --- a/llama_stack/providers/tests/memory/conftest.py +++ b/llama_stack/providers/tests/memory/conftest.py @@ -72,11 +72,15 @@ def pytest_generate_tests(metafunc): test_config = try_load_config_file_cached(metafunc.config.getoption("config")) provider_fixtures_config = ( test_config.memory.fixtures.provider_fixtures - if test_config is not None + if test_config is not None and test_config.memory is not None else None ) - custom_fixtures = get_provider_fixtures_from_config( - provider_fixtures_config, DEFAULT_PROVIDER_COMBINATIONS + custom_fixtures = ( + get_provider_fixtures_from_config( + provider_fixtures_config, DEFAULT_PROVIDER_COMBINATIONS + ) + if provider_fixtures_config is not None + else None ) if "embedding_model" in metafunc.fixturenames: model = metafunc.config.getoption("--embedding-model") diff --git a/llama_stack/providers/tests/test_config_helper.py b/llama_stack/providers/tests/test_config_helper.py index fd51953ba..c86822f0e 100644 --- a/llama_stack/providers/tests/test_config_helper.py +++ b/llama_stack/providers/tests/test_config_helper.py @@ -18,10 +18,10 @@ class APITestConfig(BaseModel): class Fixtures(BaseModel): # provider fixtures can be either a mark or a dictionary of api -> providers - provider_fixtures: List[Dict[str, str]] + provider_fixtures: List[Dict[str, str]] = Field(default_factory=list) inference_models: List[str] = Field(default_factory=list) - safety_shield: Optional[str] - embedding_model: Optional[str] + safety_shield: Optional[str] = Field(default_factory=None) + embedding_model: Optional[str] = Field(default_factory=None) fixtures: Fixtures tests: List[str] = Field(default_factory=list) @@ -32,8 +32,8 @@ class APITestConfig(BaseModel): class TestConfig(BaseModel): inference: APITestConfig - agent: APITestConfig - memory: APITestConfig + agent: Optional[APITestConfig] = Field(default=None) + memory: Optional[APITestConfig] = Field(default=None) CONFIG_CACHE = None