fix report to at least not barf

This commit is contained in:
Ashwin Bharambe 2025-03-05 15:55:25 -08:00
parent 4f82d361a8
commit f1e4588b0a
2 changed files with 17 additions and 17 deletions

View file

@ -19,21 +19,13 @@ def pytest_configure(config):
load_dotenv()
# Load any environment variables passed via --env
env_vars = config.getoption("--env") or []
for env_var in env_vars:
key, value = env_var.split("=", 1)
os.environ[key] = value
# Note:
# if report_path is not provided (aka no option --report in the pytest command),
# it will be set to False
# if --report will give None ( in this case we infer report_path)
# if --report /a/b is provided, it will be set to the path provided
# We want to handle all these cases and hence explicitly check for False
report_path = config.getoption("--report")
if report_path is not False:
config.pluginmanager.register(Report(report_path))
if config.getoption("--report"):
config.pluginmanager.register(Report(config))
def pytest_addoption(parser):

View file

@ -6,7 +6,6 @@
from collections import defaultdict
from typing import Optional
import pytest
from pytest import CollectReport
@ -63,7 +62,17 @@ SUPPORTED_MODELS = {
class Report:
def __init__(self, report_path: Optional[str] = None):
def __init__(self, config):
self.distro_name = None
self.config = config
stack_config = self.config.getoption("--stack-config")
if stack_config:
is_url = stack_config.startswith("http") or "//" in stack_config
is_yaml = stack_config.endswith(".yaml")
if not is_url and not is_yaml:
self.distro_name = stack_config
self.report_data = defaultdict(dict)
# test function -> test nodeid
self.test_data = dict()
@ -83,8 +92,8 @@ class Report:
self.test_data[report.nodeid] = outcome
def pytest_sessionfinish(self, session):
# disabled
return
if not self.client:
return
report = []
report.append(f"# Report for {self.distro_name} distribution")
@ -181,9 +190,8 @@ class Report:
vision_model = model_id.split("/")[1]
self.vision_model_id = self.vision_model_id or vision_model
if self.client is None and "llama_stack_client" in item.funcargs:
self.client = item.funcargs["llama_stack_client"]
self.distro_name = self.distro_name or self.client.async_client.config.image_name
if not self.client:
self.client = item.funcargs.get("llama_stack_client")
def _print_result_icon(self, result):
if result == "Passed":