mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
support non value for test config
This commit is contained in:
parent
26d9804efd
commit
11b2fdd3d8
4 changed files with 15 additions and 9 deletions
|
@ -92,7 +92,7 @@ def pytest_generate_tests(metafunc):
|
|||
config_override_safety_shield,
|
||||
custom_provider_fixtures,
|
||||
) = (None, None, None)
|
||||
if test_config is not None:
|
||||
if test_config is not None and test_config.agent is not None:
|
||||
config_override_inference_models = test_config.agent.fixtures.inference_models
|
||||
config_override_safety_shield = test_config.agent.fixtures.safety_shield
|
||||
custom_provider_fixtures = get_provider_fixtures_from_config(
|
||||
|
|
|
@ -189,6 +189,8 @@ def pytest_collection_modifyitems(session, config, items):
|
|||
required_tests = defaultdict(set)
|
||||
test_configs = [test_config.inference, test_config.memory, test_config.agent]
|
||||
for test_config in test_configs:
|
||||
if test_config is None:
|
||||
continue
|
||||
for test in test_config.tests:
|
||||
arr = test.split("::")
|
||||
if len(arr) != 2:
|
||||
|
|
|
@ -72,11 +72,15 @@ def pytest_generate_tests(metafunc):
|
|||
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
|
||||
provider_fixtures_config = (
|
||||
test_config.memory.fixtures.provider_fixtures
|
||||
if test_config is not None
|
||||
if test_config is not None and test_config.memory is not None
|
||||
else None
|
||||
)
|
||||
custom_fixtures = get_provider_fixtures_from_config(
|
||||
provider_fixtures_config, DEFAULT_PROVIDER_COMBINATIONS
|
||||
custom_fixtures = (
|
||||
get_provider_fixtures_from_config(
|
||||
provider_fixtures_config, DEFAULT_PROVIDER_COMBINATIONS
|
||||
)
|
||||
if provider_fixtures_config is not None
|
||||
else None
|
||||
)
|
||||
if "embedding_model" in metafunc.fixturenames:
|
||||
model = metafunc.config.getoption("--embedding-model")
|
||||
|
|
|
@ -18,10 +18,10 @@ class APITestConfig(BaseModel):
|
|||
|
||||
class Fixtures(BaseModel):
|
||||
# provider fixtures can be either a mark or a dictionary of api -> providers
|
||||
provider_fixtures: List[Dict[str, str]]
|
||||
provider_fixtures: List[Dict[str, str]] = Field(default_factory=list)
|
||||
inference_models: List[str] = Field(default_factory=list)
|
||||
safety_shield: Optional[str]
|
||||
embedding_model: Optional[str]
|
||||
safety_shield: Optional[str] = Field(default_factory=None)
|
||||
embedding_model: Optional[str] = Field(default_factory=None)
|
||||
|
||||
fixtures: Fixtures
|
||||
tests: List[str] = Field(default_factory=list)
|
||||
|
@ -32,8 +32,8 @@ class APITestConfig(BaseModel):
|
|||
class TestConfig(BaseModel):
|
||||
|
||||
inference: APITestConfig
|
||||
agent: APITestConfig
|
||||
memory: APITestConfig
|
||||
agent: Optional[APITestConfig] = Field(default=None)
|
||||
memory: Optional[APITestConfig] = Field(default=None)
|
||||
|
||||
|
||||
CONFIG_CACHE = None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue