mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-05 08:32:15 +00:00
address comments
This commit is contained in:
parent
c03f7fe9be
commit
d19526ecd7
7 changed files with 127 additions and 124 deletions
|
|
@ -28,28 +28,35 @@ class ProviderFixture(BaseModel):
|
|||
provider_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class Fixtures(BaseModel):
|
||||
class TestScenario(BaseModel):
|
||||
# provider fixtures can be either a mark or a dictionary of api -> providers
|
||||
provider_fixtures: List[Dict[str, str]] = Field(default_factory=list)
|
||||
inference_models: List[str] = Field(default_factory=list)
|
||||
safety_shield: Optional[str] = Field(default_factory=None)
|
||||
embedding_model: Optional[str] = Field(default_factory=None)
|
||||
provider_fixtures: Dict[str, str] = Field(default_factory=dict)
|
||||
fixture_combo_id: Optional[str] = None
|
||||
|
||||
|
||||
class APITestConfig(BaseModel):
|
||||
fixtures: Fixtures
|
||||
scenarios: List[TestScenario] = Field(default_factory=list)
|
||||
inference_models: List[str] = Field(default_factory=list)
|
||||
|
||||
# test name format should be <relative_path.py>::<test_name>
|
||||
tests: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class MemoryApiTestConfig(APITestConfig):
|
||||
embedding_model: Optional[str] = Field(default_factory=None)
|
||||
|
||||
|
||||
class AgentsApiTestConfig(APITestConfig):
|
||||
safety_shield: Optional[str] = Field(default_factory=None)
|
||||
|
||||
|
||||
class TestConfig(BaseModel):
|
||||
inference: APITestConfig
|
||||
agent: Optional[APITestConfig] = Field(default=None)
|
||||
memory: Optional[APITestConfig] = Field(default=None)
|
||||
inference: APITestConfig = Field(default=None)
|
||||
agents: AgentsApiTestConfig = Field(default=None)
|
||||
memory: MemoryApiTestConfig = Field(default=None)
|
||||
|
||||
|
||||
def try_load_config_file_cached(config):
|
||||
def get_test_config_from_config_file(config):
|
||||
config_file = config.getoption("--config")
|
||||
if config_file is None:
|
||||
return None
|
||||
|
|
@ -64,25 +71,38 @@ def try_load_config_file_cached(config):
|
|||
return TestConfig(**config)
|
||||
|
||||
|
||||
def get_provider_fixtures_from_config(
|
||||
provider_fixtures_config, default_fixture_combination
|
||||
def get_test_config_for_api(config, api):
|
||||
test_config = get_test_config_from_config_file(config)
|
||||
if test_config is None:
|
||||
return None
|
||||
return getattr(test_config, api)
|
||||
|
||||
|
||||
def get_provider_fixture_overrides_from_test_config(
|
||||
config, api, default_provider_fixture_combination
|
||||
):
|
||||
custom_fixtures = []
|
||||
selected_default_param_id = set()
|
||||
for fixture_config in provider_fixtures_config:
|
||||
if "default_fixture_param_id" in fixture_config:
|
||||
selected_default_param_id.add(fixture_config["default_fixture_param_id"])
|
||||
api_config = get_test_config_for_api(config, api)
|
||||
if api_config is None:
|
||||
return None
|
||||
|
||||
fixture_combo_ids = set()
|
||||
custom_provider_fixture_combos = []
|
||||
for scenario in api_config.scenarios:
|
||||
if scenario.fixture_combo_id:
|
||||
fixture_combo_ids.add(scenario.fixture_combo_id)
|
||||
else:
|
||||
custom_fixtures.append(
|
||||
pytest.param(fixture_config, id=fixture_config.get("inference") or "")
|
||||
custom_provider_fixture_combos.append(
|
||||
pytest.param(
|
||||
scenario.provider_fixtures,
|
||||
id=scenario.provider_fixtures.get("inference") or "",
|
||||
)
|
||||
)
|
||||
|
||||
if len(selected_default_param_id) > 0:
|
||||
for default_fixture in default_fixture_combination:
|
||||
if default_fixture.id in selected_default_param_id:
|
||||
custom_fixtures.append(default_fixture)
|
||||
|
||||
return custom_fixtures
|
||||
if len(fixture_combo_ids) > 0:
|
||||
for default_fixture in default_provider_fixture_combination:
|
||||
if default_fixture.id in fixture_combo_ids:
|
||||
custom_provider_fixture_combos.append(default_fixture)
|
||||
return custom_provider_fixture_combos
|
||||
|
||||
|
||||
def remote_stack_fixture() -> ProviderFixture:
|
||||
|
|
@ -239,16 +259,19 @@ def pytest_itemcollected(item):
|
|||
|
||||
|
||||
def pytest_collection_modifyitems(session, config, items):
|
||||
test_config = try_load_config_file_cached(config)
|
||||
test_config = get_test_config_from_config_file(config)
|
||||
if test_config is None:
|
||||
return
|
||||
|
||||
required_tests = defaultdict(set)
|
||||
test_configs = [test_config.inference, test_config.memory, test_config.agent]
|
||||
for test_config in test_configs:
|
||||
if test_config is None:
|
||||
for api_test_config in [
|
||||
test_config.inference,
|
||||
test_config.memory,
|
||||
test_config.agents,
|
||||
]:
|
||||
if api_test_config is None:
|
||||
continue
|
||||
for test in test_config.tests:
|
||||
for test in api_test_config.tests:
|
||||
arr = test.split("::")
|
||||
if len(arr) != 2:
|
||||
raise ValueError(f"Invalid format for test name {test}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue