mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-05 18:22:41 +00:00)
commit 9b709005fa (parent b22cc3e1fa)

    update

5 changed files with 46 additions and 92 deletions
@@ -9,7 +9,9 @@ import pytest
 from llama_stack.apis.agents import AgentConfig, Turn
 from llama_stack.apis.inference import SamplingParams, UserMessage
 from llama_stack.providers.datatypes import Api
-from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+
 from .fixtures import pick_inference_model

 from .utils import create_agent_session
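For orientation, a minimal sketch of the relocated imports in use (not part of the diff; the db_path field, the async call, and the get/set signatures are assumptions about the kvstore API):

    import asyncio

    from llama_stack.providers.utils.kvstore import kvstore_impl
    from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


    async def main() -> None:
        # db_path is a made-up location for this sketch
        config = SqliteKVStoreConfig(db_path="/tmp/agent_persistence.db")
        kvstore = await kvstore_impl(config)
        await kvstore.set("agent:session:123", "{}")
        print(await kvstore.get("agent:session:123"))


    asyncio.run(main())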
@@ -6,6 +6,7 @@
 import os
 from collections import defaultdict

+from pathlib import Path
 from typing import Any, Dict, List, Optional

@@ -13,15 +14,13 @@ import pytest
 import yaml

 from dotenv import load_dotenv

-from llama_stack.distribution.datatypes import Provider
-from llama_stack.providers.datatypes import RemoteProviderConfig
 from pydantic import BaseModel, Field
 from termcolor import colored

-from .env import get_env_or_fail
-from .test_config_helper import try_load_config_file_cached
+from llama_stack.distribution.datatypes import Provider
+from llama_stack.providers.datatypes import RemoteProviderConfig
+
+from .env import get_env_or_fail
 from .report import Report

@@ -142,8 +141,8 @@ def pytest_configure(config):
         key, value = env_var.split("=", 1)
         os.environ[key] = value

-    if config.getoption("--config") is not None:
-        config.pluginmanager.register(Report(config))
+    if config.getoption("--output") is not None:
+        config.pluginmanager.register(Report(config.getoption("--output")))


 def pytest_addoption(parser):
@@ -160,6 +159,11 @@ def pytest_addoption(parser):
         action="store",
         help="Set test config file (supported format: YAML), e.g. --config=test_config.yml",
     )
+    parser.addoption(
+        "--output",
+        action="store",
+        help="Set output file for test report, e.g. --output=pytest_report.md",
+    )
     """Add custom command line options"""
     parser.addoption(
         "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
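With the option registered, a run can request a report file directly, e.g. pytest <test module> --output=pytest_report.md. A small sketch (not in the diff) of reading the flag back elsewhere through pytest's built-in pytestconfig fixture; the fixture name report_path is made up:

    import pytest


    @pytest.fixture(scope="session")
    def report_path(pytestconfig):
        # None unless the session was started with --output=<file>
        return pytestconfig.getoption("--output")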
@@ -269,9 +273,14 @@ def pytest_collection_modifyitems(session, config, items):
         return

     required_tests = defaultdict(set)
-    test_configs = [test_config.inference, test_config.memory, test_config.agent]
-    for test_config in test_configs:
-        for test in test_config.tests:
+    for api_test_config in [
+        test_config.inference,
+        test_config.memory,
+        test_config.agents,
+    ]:
+        if api_test_config is None:
+            continue
+        for test in api_test_config.tests:
             arr = test.split("::")
             if len(arr) != 2:
                 raise ValueError(f"Invalid format for test name {test}")
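The loop assumes test ids of the form <relative_path.py>::<test_name> (the convention noted in the removed test_config_helper module); a minimal illustration of the split it performs, with a made-up id:

    test = "inference/test_text_inference.py::test_completion"  # hypothetical id
    arr = test.split("::")
    assert len(arr) == 2
    assert arr == ["inference/test_text_inference.py", "test_completion"]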
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import pytest

-from llama_stack.apis.common.type_system import JobStatus
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
     Checkpoint,
     DataConfig,
@@ -11,6 +11,7 @@ from pathlib import Path
 import pytest
 from llama_models.datatypes import CoreModelId
 from llama_models.sku_list import all_registered_models
+from pytest import ExitCode

 from pytest_html.basereport import _process_outcome
@@ -71,11 +72,22 @@ SUPPORTED_MODELS = {

 class Report:

-    def __init__(self, _config):
+    def __init__(self, output_path):
+
+        valid_file_format = (
+            output_path.split(".")[1] in ["md", "markdown"]
+            if len(output_path.split(".")) == 2
+            else False
+        )
+        if not valid_file_format:
+            raise ValueError(
+                f"Invalid output file {output_path}. Markdown file is required"
+            )
+        self.output_path = output_path
         self.test_data = defaultdict(dict)
         self.inference_tests = defaultdict(dict)

-    @pytest.hookimpl(tryfirst=True)
+    @pytest.hookimpl
     def pytest_runtest_logreport(self, report):
         # This hook is called in several phases, including setup, call and teardown
         # The test is considered failed / error if any of the outcomes is not "Passed"
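Restating the new filename check as a standalone sketch (same expression as above, lifted out of the class) to show which paths it accepts:

    def is_markdown_path(output_path: str) -> bool:
        parts = output_path.split(".")
        return parts[1] in ["md", "markdown"] if len(parts) == 2 else False


    assert is_markdown_path("pytest_report.md")
    assert is_markdown_path("report.markdown")
    assert not is_markdown_path("results.txt")
    assert not is_markdown_path("pytest.report.md")  # more than one dot fails the len == 2 check
    assert not is_markdown_path("report")            # no extension at all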
@@ -91,7 +103,9 @@ class Report:
         self.test_data[report.nodeid] = data

     @pytest.hookimpl
-    def pytest_sessionfinish(self, session):
+    def pytest_sessionfinish(self, session, exitstatus):
+        if exitstatus <= ExitCode.INTERRUPTED:
+            return
         report = []
         report.append("# Llama Stack Integration Test Results Report")
         report.append("\n## Summary")
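For reference, pytest's ExitCode is an IntEnum, which is what makes the <= comparison above well defined (the values are pytest's own, not introduced by this change):

    from pytest import ExitCode

    assert ExitCode.OK == 0
    assert ExitCode.TESTS_FAILED == 1
    assert ExitCode.INTERRUPTED == 2
    assert ExitCode.INTERNAL_ERROR == 3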
@@ -108,6 +122,11 @@ class Report:

         rows = []
         for model in all_registered_models():
+            if (
+                "Instruct" not in model.core_model_id.value
+                and "Guard" not in model.core_model_id.value
+            ):
+                continue
             row = f"| {model.core_model_id.value} |"
             for k in SUPPORTED_MODELS.keys():
                 if model.core_model_id.value in SUPPORTED_MODELS[k]:
@@ -149,7 +168,7 @@ class Report:
         report.extend(test_table)
         report.append("\n")

-        output_file = Path("pytest_report.md")
+        output_file = Path(self.output_path)
         output_file.write_text("\n".join(report))
         print(f"\n Report generated: {output_file.absolute()}")
@@ -1,76 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Dict, List, Optional
-
-import pytest
-import yaml
-from pydantic import BaseModel, Field
-
-
-@dataclass
-class APITestConfig(BaseModel):
-
-    class Fixtures(BaseModel):
-        # provider fixtures can be either a mark or a dictionary of api -> providers
-        provider_fixtures: List[Dict[str, str]] = Field(default_factory=list)
-        inference_models: List[str] = Field(default_factory=list)
-        safety_shield: Optional[str] = Field(default_factory=None)
-        embedding_model: Optional[str] = Field(default_factory=None)
-
-    fixtures: Fixtures
-    tests: List[str] = Field(default_factory=list)
-
-    # test name format should be <relative_path.py>::<test_name>
-
-
-class TestConfig(BaseModel):
-
-    inference: APITestConfig
-    agent: Optional[APITestConfig] = Field(default=None)
-    memory: Optional[APITestConfig] = Field(default=None)
-
-
-CONFIG_CACHE = None
-
-
-def try_load_config_file_cached(config_file):
-    if config_file is None:
-        return None
-    if CONFIG_CACHE is not None:
-        return CONFIG_CACHE
-
-    config_file_path = Path(__file__).parent / config_file
-    if not config_file_path.exists():
-        raise ValueError(
-            f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
-        )
-    with open(config_file_path, "r") as config_file:
-        config = yaml.safe_load(config_file)
-        return TestConfig(**config)
-
-
-def get_provider_fixtures_from_config(
-    provider_fixtures_config, default_fixture_combination
-):
-    custom_fixtures = []
-    selected_default_param_id = set()
-    for fixture_config in provider_fixtures_config:
-        if "default_fixture_param_id" in fixture_config:
-            selected_default_param_id.add(fixture_config["default_fixture_param_id"])
-        else:
-            custom_fixtures.append(
-                pytest.param(fixture_config, id=fixture_config.get("inference") or "")
-            )
-
-    if len(selected_default_param_id) > 0:
-        for default_fixture in default_fixture_combination:
-            if default_fixture.id in selected_default_param_id:
-                custom_fixtures.append(default_fixture)
-
-    return custom_fixtures
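For context, the deleted loader ran a YAML file through yaml.safe_load and passed the result to TestConfig(**config); a sketch of such a mapping, with keys taken from the removed models and every value invented for illustration:

    config = {
        "inference": {
            "fixtures": {
                "provider_fixtures": [{"inference": "ollama"}],  # hypothetical fixture id
                "inference_models": ["Llama3.1-8B-Instruct"],    # hypothetical model id
                "safety_shield": None,
                "embedding_model": None,
            },
            "tests": [
                "inference/test_text_inference.py::test_completion",  # hypothetical test id
            ],
        },
        "agent": None,
        "memory": None,
    }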