address comment

This commit is contained in:
Sixian Yi 2025-01-15 16:23:26 -08:00
parent 11b2fdd3d8
commit 901c10d444
5 changed files with 80 additions and 94 deletions

View file

@ -6,14 +6,15 @@
import pytest import pytest
from ..conftest import get_provider_fixture_overrides from ..conftest import (
from ..inference.fixtures import INFERENCE_FIXTURES get_provider_fixture_overrides,
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from ..test_config_helper import (
get_provider_fixtures_from_config, get_provider_fixtures_from_config,
try_load_config_file_cached, try_load_config_file_cached,
) )
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
from .fixtures import AGENTS_FIXTURES from .fixtures import AGENTS_FIXTURES
@ -86,7 +87,7 @@ def pytest_configure(config):
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
test_config = try_load_config_file_cached(metafunc.config.getoption("config")) test_config = try_load_config_file_cached(metafunc.config)
( (
config_override_inference_models, config_override_inference_models,
config_override_safety_shield, config_override_safety_shield,

View file

@ -6,13 +6,15 @@
import os import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import pytest import pytest
import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import BaseModel from pydantic import BaseModel, Field
from termcolor import colored from termcolor import colored
from llama_stack.distribution.datatypes import Provider from llama_stack.distribution.datatypes import Provider
@ -20,14 +22,74 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
from .env import get_env_or_fail from .env import get_env_or_fail
from .test_config_helper import try_load_config_file_cached
class ProviderFixture(BaseModel): class ProviderFixture(BaseModel):
providers: List[Provider] providers: List[Provider]
provider_data: Optional[Dict[str, Any]] = None provider_data: Optional[Dict[str, Any]] = None
class Fixtures(BaseModel):
# provider fixtures can be either a mark or a dictionary of api -> providers
provider_fixtures: List[Dict[str, str]] = Field(default_factory=list)
inference_models: List[str] = Field(default_factory=list)
safety_shield: Optional[str] = Field(default_factory=None)
embedding_model: Optional[str] = Field(default_factory=None)
class APITestConfig(BaseModel):
fixtures: Fixtures
# test name format should be <relative_path.py>::<test_name>
tests: List[str] = Field(default_factory=list)
class TestConfig(BaseModel):
inference: APITestConfig
agent: Optional[APITestConfig] = Field(default=None)
memory: Optional[APITestConfig] = Field(default=None)
CONFIG_CACHE = None
def try_load_config_file_cached(config):
config_file = config.getoption("--config")
if config_file is None:
return None
if CONFIG_CACHE is not None:
return CONFIG_CACHE
config_file_path = Path(__file__).parent / config_file
if not config_file_path.exists():
raise ValueError(
f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
)
with open(config_file_path, "r") as config_file:
config = yaml.safe_load(config_file)
return TestConfig(**config)
def get_provider_fixtures_from_config(
provider_fixtures_config, default_fixture_combination
):
custom_fixtures = []
selected_default_param_id = set()
for fixture_config in provider_fixtures_config:
if "default_fixture_param_id" in fixture_config:
selected_default_param_id.add(fixture_config["default_fixture_param_id"])
else:
custom_fixtures.append(
pytest.param(fixture_config, id=fixture_config.get("inference") or "")
)
if len(selected_default_param_id) > 0:
for default_fixture in default_fixture_combination:
if default_fixture.id in selected_default_param_id:
custom_fixtures.append(default_fixture)
return custom_fixtures
def remote_stack_fixture() -> ProviderFixture: def remote_stack_fixture() -> ProviderFixture:
if url := os.getenv("REMOTE_STACK_URL", None): if url := os.getenv("REMOTE_STACK_URL", None):
config = RemoteProviderConfig.from_url(url) config = RemoteProviderConfig.from_url(url)
@ -182,7 +244,7 @@ def pytest_itemcollected(item):
def pytest_collection_modifyitems(session, config, items): def pytest_collection_modifyitems(session, config, items):
test_config = try_load_config_file_cached(config.getoption("--config")) test_config = try_load_config_file_cached(config)
if test_config is None: if test_config is None:
return return

View file

@ -6,8 +6,7 @@
import pytest import pytest
from ..conftest import get_provider_fixture_overrides from ..conftest import get_provider_fixture_overrides, try_load_config_file_cached
from ..test_config_helper import try_load_config_file_cached
from .fixtures import INFERENCE_FIXTURES from .fixtures import INFERENCE_FIXTURES
@ -43,7 +42,7 @@ VISION_MODEL_PARAMS = [
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
test_config = try_load_config_file_cached(metafunc.config.getoption("config")) test_config = try_load_config_file_cached(metafunc.config)
if "inference_model" in metafunc.fixturenames: if "inference_model" in metafunc.fixturenames:
cls_name = metafunc.cls.__name__ cls_name = metafunc.cls.__name__
if test_config is not None: if test_config is not None:

View file

@ -6,13 +6,13 @@
import pytest import pytest
from ..conftest import get_provider_fixture_overrides from ..conftest import (
get_provider_fixture_overrides,
from ..inference.fixtures import INFERENCE_FIXTURES
from ..test_config_helper import (
get_provider_fixtures_from_config, get_provider_fixtures_from_config,
try_load_config_file_cached, try_load_config_file_cached,
) )
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import MEMORY_FIXTURES from .fixtures import MEMORY_FIXTURES
@ -69,7 +69,7 @@ def pytest_configure(config):
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
test_config = try_load_config_file_cached(metafunc.config.getoption("config")) test_config = try_load_config_file_cached(metafunc.config)
provider_fixtures_config = ( provider_fixtures_config = (
test_config.memory.fixtures.provider_fixtures test_config.memory.fixtures.provider_fixtures
if test_config is not None and test_config.memory is not None if test_config is not None and test_config.memory is not None

View file

@ -1,76 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
import pytest
import yaml
from pydantic import BaseModel, Field
@dataclass
class APITestConfig(BaseModel):
class Fixtures(BaseModel):
# provider fixtures can be either a mark or a dictionary of api -> providers
provider_fixtures: List[Dict[str, str]] = Field(default_factory=list)
inference_models: List[str] = Field(default_factory=list)
safety_shield: Optional[str] = Field(default_factory=None)
embedding_model: Optional[str] = Field(default_factory=None)
fixtures: Fixtures
tests: List[str] = Field(default_factory=list)
# test name format should be <relative_path.py>::<test_name>
class TestConfig(BaseModel):
inference: APITestConfig
agent: Optional[APITestConfig] = Field(default=None)
memory: Optional[APITestConfig] = Field(default=None)
CONFIG_CACHE = None
def try_load_config_file_cached(config_file):
if config_file is None:
return None
if CONFIG_CACHE is not None:
return CONFIG_CACHE
config_file_path = Path(__file__).parent / config_file
if not config_file_path.exists():
raise ValueError(
f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
)
with open(config_file_path, "r") as config_file:
config = yaml.safe_load(config_file)
return TestConfig(**config)
def get_provider_fixtures_from_config(
provider_fixtures_config, default_fixture_combination
):
custom_fixtures = []
selected_default_param_id = set()
for fixture_config in provider_fixtures_config:
if "default_fixture_param_id" in fixture_config:
selected_default_param_id.add(fixture_config["default_fixture_param_id"])
else:
custom_fixtures.append(
pytest.param(fixture_config, id=fixture_config.get("inference") or "")
)
if len(selected_default_param_id) > 0:
for default_fixture in default_fixture_combination:
if default_fixture.id in selected_default_param_id:
custom_fixtures.append(default_fixture)
return custom_fixtures