restructure config

Sixian Yi 2025-01-14 16:08:20 -08:00
parent 702cf2d563
commit 26d9804efd
8 changed files with 218 additions and 62 deletions

View file

@@ -246,7 +246,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                 provider_id = list(self.impls_by_provider_id.keys())[0]
             else:
                 raise ValueError(
-                    "No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
+                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                 )
         if metadata is None:
             metadata = {}
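The only change here is the added f prefix: without it, Python never interpolates the braces, so the error message printed the placeholder text literally instead of the provider list. A minimal sketch of the difference (the dict is an illustrative stand-in for self.impls_by_provider_id):

    impls_by_provider_id = {"ollama": None, "fireworks": None}

    print("Available providers: {impls_by_provider_id.keys()}")
    # Available providers: {impls_by_provider_id.keys()}

    print(f"Available providers: {impls_by_provider_id.keys()}")
    # Available providers: dict_keys(['ollama', 'fireworks'])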

View file

@@ -10,6 +10,10 @@ from ..conftest import get_provider_fixture_overrides
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
+from ..test_config_helper import (
+    get_provider_fixtures_from_config,
+    try_load_config_file_cached,
+)
 from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
 from .fixtures import AGENTS_FIXTURES
@@ -82,7 +86,25 @@ def pytest_configure(config):
 def pytest_generate_tests(metafunc):
     shield_id = metafunc.config.getoption("--safety-shield")
+    test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
+    (
+        config_override_inference_models,
+        config_override_safety_shield,
+        custom_provider_fixtures,
+    ) = (None, None, None)
+    if test_config is not None:
+        config_override_inference_models = test_config.agent.fixtures.inference_models
+        config_override_safety_shield = test_config.agent.fixtures.safety_shield
+        custom_provider_fixtures = get_provider_fixtures_from_config(
+            test_config.agent.fixtures.provider_fixtures, DEFAULT_PROVIDER_COMBINATIONS
+        )
+    shield_id = config_override_safety_shield or metafunc.config.getoption(
+        "--safety-shield"
+    )
+    inference_model = config_override_inference_models or [
+        metafunc.config.getoption("--inference-model")
+    ]
     if "safety_shield" in metafunc.fixturenames:
         metafunc.parametrize(
             "safety_shield",
@@ -90,8 +112,7 @@ def pytest_generate_tests(metafunc):
             indirect=True,
         )
     if "inference_model" in metafunc.fixturenames:
-        inference_model = metafunc.config.getoption("--inference-model")
-        models = set({inference_model})
+        models = set(inference_model)
         if safety_model := safety_model_from_shield(shield_id):
             models.add(safety_model)
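This works together with the new inference_model assignment above: the value is now always a list (either the config file's inference_models or a one-element list wrapping the CLI option), so set(inference_model) collects the model names directly. The list wrapping matters; a sketch of the pitfall it avoids:

    set(["meta-llama/Llama-3.1-8B-Instruct"])  # {'meta-llama/Llama-3.1-8B-Instruct'}
    set("llama")                               # {'l', 'a', 'm'} - a raw string is split into characters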
@@ -109,7 +130,8 @@ def pytest_generate_tests(metafunc):
         "tool_runtime": TOOL_RUNTIME_FIXTURES,
     }
     combinations = (
-        get_provider_fixture_overrides(metafunc.config, available_fixtures)
+        custom_provider_fixtures
+        or get_provider_fixture_overrides(metafunc.config, available_fixtures)
         or DEFAULT_PROVIDER_COMBINATIONS
     )
     metafunc.parametrize("agents_stack", combinations, indirect=True)
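The or-chain gives the config file first claim, then any CLI provider overrides, then the defaults; because None and an empty list are both falsy, each source cleanly falls through to the next. A self-contained sketch of the precedence (all values are illustrative stand-ins):

    custom_provider_fixtures = None          # no --config file given
    cli_overrides = []                       # no per-provider CLI flags either
    DEFAULT_PROVIDER_COMBINATIONS = ["ollama-default"]

    combinations = custom_provider_fixtures or cli_overrides or DEFAULT_PROVIDER_COMBINATIONS
    print(combinations)  # ['ollama-default']

Note that a config file whose provider_fixtures list is empty therefore falls back to the CLI overrides or defaults rather than selecting nothing.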

View file

@@ -1,24 +1,59 @@
-tests:
-- path: inference/test_vision_inference.py
-  functions:
-  - test_vision_chat_completion_streaming
-  - test_vision_chat_completion_non_streaming
+inference:
+  tests:
+  - inference/test_vision_inference.py::test_vision_chat_completion_streaming
+  - inference/test_vision_inference.py::test_vision_chat_completion_non_streaming
+  - inference/test_text_inference.py::test_structured_output
+  - inference/test_text_inference.py::test_chat_completion_streaming
+  - inference/test_text_inference.py::test_chat_completion_non_streaming
+  - inference/test_text_inference.py::test_chat_completion_with_tool_calling
+  - inference/test_text_inference.py::test_chat_completion_with_tool_calling_streaming
-- path: inference/test_text_inference.py
-  functions:
-  - test_structured_output
-  - test_chat_completion_streaming
-  - test_chat_completion_non_streaming
-  - test_chat_completion_with_tool_calling
-  - test_chat_completion_with_tool_calling_streaming
+  fixtures:
+    provider_fixtures:
+    - inference: ollama
+    - default_fixture_param_id: fireworks
+    - inference: together
+    # - inference: tgi
+    # - inference: vllm_remote
+    inference_models:
+    - meta-llama/Llama-3.1-8B-Instruct
+    - meta-llama/Llama-3.2-11B-Vision-Instruct
-inference_fixtures:
-- ollama
-- fireworks
-- together
-- tgi
-- vllm_remote
+    safety_shield: ~
+    embedding_model: ~
-test_models:
-  text: meta-llama/Llama-3.1-8B-Instruct
-  vision: meta-llama/Llama-3.2-11B-Vision-Instruct
+agent:
+  tests:
+  - agents/test_agents.py::test_agent_turns_with_safety
+  - agents/test_agents.py::test_rag_agent
+  fixtures:
+    provider_fixtures:
+    - default_fixture_param_id: ollama
+    - default_fixture_param_id: together
+    - default_fixture_param_id: fireworks
+    safety_shield: ~
+    embedding_model: ~
+    inference_models:
+    - meta-llama/Llama-3.2-1B-Instruct
+memory:
+  tests:
+  - memory/test_memory.py::test_query_documents
+  fixtures:
+    provider_fixtures:
+    - default_fixture_param_id: ollama
+    - inference: sentence_transformers
+      memory: faiss
+    - default_fixture_param_id: chroma
+    inference_models:
+    - meta-llama/Llama-3.2-1B-Instruct
+    safety_shield: ~
+    embedding_model: ~
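Each top-level API section (inference, agent, memory) now owns its test list, written as <relative_path.py>::<test_name>, plus a fixtures block; commented-out entries (tgi, vllm_remote) are simply skipped. The loader resolves the file name relative to llama_stack/providers/tests, so assuming the file above is saved there as, say, ci_test_config.yaml (the name is not shown in this diff), a run would look like:

    pytest llama_stack/providers/tests --config=ci_test_config.yaml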

View file

@@ -5,12 +5,12 @@
 # the root directory of this source tree.
 import os
+from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Optional

 import pytest
-import yaml
 from dotenv import load_dotenv
 from pydantic import BaseModel
 from termcolor import colored
@@ -20,6 +20,8 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
 from .env import get_env_or_fail
+from .test_config_helper import try_load_config_file_cached
+

 class ProviderFixture(BaseModel):
     providers: List[Provider]
@@ -180,34 +182,26 @@ def pytest_itemcollected(item):
 def pytest_collection_modifyitems(session, config, items):
-    if config.getoption("--config") is None:
+    test_config = try_load_config_file_cached(config.getoption("--config"))
+    if test_config is None:
         return

-    file_name = config.getoption("--config")
-    config_file_path = Path(__file__).parent / file_name
-    if not config_file_path.exists():
-        raise ValueError(
-            f"Test config {file_name} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
-        )
-
-    required_tests = dict()
-    inference_providers = set()
-    with open(config_file_path, "r") as config_file:
-        test_config = yaml.safe_load(config_file)
-        for test in test_config["tests"]:
-            required_tests[Path(__file__).parent / test["path"]] = set(
-                test["functions"]
-            )
-        inference_providers = set(test_config["inference_fixtures"])
+    required_tests = defaultdict(set)
+    test_configs = [test_config.inference, test_config.memory, test_config.agent]
+    for test_config in test_configs:
+        for test in test_config.tests:
+            arr = test.split("::")
+            if len(arr) != 2:
+                raise ValueError(f"Invalid format for test name {test}")
+            test_path, func_name = arr
+            required_tests[Path(__file__).parent / test_path].add(func_name)

     new_items, deselected_items = [], []
     for item in items:
         if item.fspath in required_tests:
-            func_name = getattr(item, "originalname", item.name)
-            if func_name in required_tests[item.fspath]:
-                inference = item.callspec.params.get("inference_stack")
-                if inference in inference_providers:
-                    new_items.append(item)
-                    continue
+            func_name = getattr(item, "originalname", item.name)
+            if func_name in required_tests[item.fspath]:
+                new_items.append(item)
+                continue
         deselected_items.append(item)

     items[:] = new_items
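Collection filtering now works purely on <path>::<function> entries gathered from all three API sections; the old per-provider inference_fixtures check is gone (provider selection happens in each suite's pytest_generate_tests instead). A standalone sketch of the parsing step, with an illustrative base directory standing in for Path(__file__).parent:

    from collections import defaultdict
    from pathlib import Path

    base = Path("llama_stack/providers/tests")
    required_tests = defaultdict(set)
    for test in ["memory/test_memory.py::test_query_documents"]:
        arr = test.split("::")
        if len(arr) != 2:
            raise ValueError(f"Invalid format for test name {test}")
        test_path, func_name = arr
        required_tests[base / test_path].add(func_name)

    # required_tests == {base / "memory/test_memory.py": {"test_query_documents"}}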

View file

@@ -7,7 +7,7 @@
 import pytest

 from ..conftest import get_provider_fixture_overrides
+from ..test_config_helper import try_load_config_file_cached
 from .fixtures import INFERENCE_FIXTURES
@@ -43,29 +43,43 @@ VISION_MODEL_PARAMS = [
 def pytest_generate_tests(metafunc):
+    test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
     if "inference_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--inference-model")
-        if model:
-            params = [pytest.param(model, id="")]
+        cls_name = metafunc.cls.__name__
+        if test_config is not None:
+            params = []
+            for model in test_config.inference.fixtures.inference_models:
+                if ("Vision" in cls_name and "Vision" in model) or (
+                    "Vision" not in cls_name and "Vision" not in model
+                ):
+                    params.append(pytest.param(model, id=model))
         else:
-            cls_name = metafunc.cls.__name__
-            if "Vision" in cls_name:
-                params = VISION_MODEL_PARAMS
+            model = metafunc.config.getoption("--inference-model")
+            if model:
+                params = [pytest.param(model, id="")]
             else:
-                params = MODEL_PARAMS
+                if "Vision" in cls_name:
+                    params = VISION_MODEL_PARAMS
+                else:
+                    params = MODEL_PARAMS
         metafunc.parametrize(
             "inference_model",
             params,
             indirect=True,
         )
     if "inference_stack" in metafunc.fixturenames:
-        fixtures = INFERENCE_FIXTURES
-        if filtered_stacks := get_provider_fixture_overrides(
+        if test_config is not None:
+            fixtures = [
+                (f.get("inference") or f.get("default_fixture_param_id"))
+                for f in test_config.inference.fixtures.provider_fixtures
+            ]
+        elif filtered_stacks := get_provider_fixture_overrides(
             metafunc.config,
             {
                 "inference": INFERENCE_FIXTURES,
             },
         ):
             fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
+        else:
+            fixtures = INFERENCE_FIXTURES
         metafunc.parametrize("inference_stack", fixtures, indirect=True)
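The model filter pairs vision models with vision test classes and text models with text test classes. The two-branch condition in the diff is equivalent to comparing the two membership tests for equality; a condensed sketch (class names are illustrative):

    def matches(cls_name: str, model: str) -> bool:
        # vision test classes get vision models, text test classes get text models
        return ("Vision" in cls_name) == ("Vision" in model)

    matches("TestVisionModelInference", "meta-llama/Llama-3.2-11B-Vision-Instruct")  # True
    matches("TestModelInference", "meta-llama/Llama-3.1-8B-Instruct")                # True
    matches("TestModelInference", "meta-llama/Llama-3.2-11B-Vision-Instruct")        # False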

View file

@@ -301,6 +301,7 @@ async def inference_stack(request, inference_model):
         inference_fixture.provider_data,
         models=[
             ModelInput(
+                provider_id=inference_fixture.providers[0].provider_id,
                 model_id=inference_model,
                 model_type=model_type,
                 metadata=metadata,
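Passing provider_id explicitly ties each registered model to the fixture's own provider, which sidesteps the multi-provider ValueError hardened in the first file. A rough sketch of the resulting registration payload (all values are illustrative):

    model_input = dict(
        provider_id="ollama",  # taken from inference_fixture.providers[0].provider_id
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        model_type="llm",      # assumption: the model_type value used by the fixture
        metadata={},
    )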

View file

@ -9,6 +9,10 @@ import pytest
from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..test_config_helper import (
get_provider_fixtures_from_config,
try_load_config_file_cached,
)
from .fixtures import MEMORY_FIXTURES
@ -65,6 +69,15 @@ def pytest_configure(config):
def pytest_generate_tests(metafunc):
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
provider_fixtures_config = (
test_config.memory.fixtures.provider_fixtures
if test_config is not None
else None
)
custom_fixtures = get_provider_fixtures_from_config(
provider_fixtures_config, DEFAULT_PROVIDER_COMBINATIONS
)
if "embedding_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--embedding-model")
if model:
@ -80,7 +93,8 @@ def pytest_generate_tests(metafunc):
"memory": MEMORY_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
custom_fixtures
or get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("memory_stack", combinations, indirect=True)

View file

@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import pytest
+import yaml
+from pydantic import BaseModel, Field
+
+
+@dataclass
+class APITestConfig(BaseModel):
+
+    class Fixtures(BaseModel):
+        # provider fixtures can be either a mark or a dictionary of api -> providers
+        provider_fixtures: List[Dict[str, str]]
+        inference_models: List[str] = Field(default_factory=list)
+        safety_shield: Optional[str]
+        embedding_model: Optional[str]
+
+    fixtures: Fixtures
+    tests: List[str] = Field(default_factory=list)
+
+    # test name format should be <relative_path.py>::<test_name>
+
+
+class TestConfig(BaseModel):
+    inference: APITestConfig
+    agent: APITestConfig
+    memory: APITestConfig
+
+
+CONFIG_CACHE = None
+
+
+def try_load_config_file_cached(config_file):
+    if config_file is None:
+        return None
+    if CONFIG_CACHE is not None:
+        return CONFIG_CACHE
+
+    config_file_path = Path(__file__).parent / config_file
+    if not config_file_path.exists():
+        raise ValueError(
+            f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
+        )
+    with open(config_file_path, "r") as config_file:
+        config = yaml.safe_load(config_file)
+        return TestConfig(**config)
+
+
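Two things stand out in this new helper as committed: the @dataclass decorator on a pydantic BaseModel is at best redundant (pydantic generates __init__ itself), and try_load_config_file_cached reads CONFIG_CACHE but never assigns it, so the cache is never populated and every call re-reads and re-parses the file; the with-target also shadows the config_file parameter. A corrected sketch keeping the same interface:

    def try_load_config_file_cached(config_file):
        global CONFIG_CACHE  # without this, the assignment below would be local
        if config_file is None:
            return None
        if CONFIG_CACHE is not None:
            return CONFIG_CACHE
        config_file_path = Path(__file__).parent / config_file
        if not config_file_path.exists():
            raise ValueError(
                f"Test config {config_file} was specified but not found. Please make sure "
                "it exists in the llama_stack/providers/tests directory."
            )
        with open(config_file_path, "r") as f:  # avoid shadowing the parameter
            CONFIG_CACHE = TestConfig(**yaml.safe_load(f))
        return CONFIG_CACHE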
+def get_provider_fixtures_from_config(
+    provider_fixtures_config, default_fixture_combination
+):
+    custom_fixtures = []
+    selected_default_param_id = set()
+    for fixture_config in provider_fixtures_config:
+        if "default_fixture_param_id" in fixture_config:
+            selected_default_param_id.add(fixture_config["default_fixture_param_id"])
+        else:
+            custom_fixtures.append(
+                pytest.param(fixture_config, id=fixture_config.get("inference") or "")
+            )
+
+    if len(selected_default_param_id) > 0:
+        for default_fixture in default_fixture_combination:
+            if default_fixture.id in selected_default_param_id:
+                custom_fixtures.append(default_fixture)
+
+    return custom_fixtures
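Entries carrying default_fixture_param_id select stock combinations by id; anything else becomes an ad-hoc pytest.param whose id comes from its inference key, with custom entries emitted first and the matched defaults appended after. One caveat: memory/conftest.py can call this with provider_fixtures_config=None (when no config file is given), which would raise a TypeError on iteration, so a guard such as `for fixture_config in provider_fixtures_config or []:` seems warranted. A usage sketch against an assumed two-entry default list (and assuming the helper above is importable):

    import pytest

    DEFAULT_PROVIDER_COMBINATIONS = [  # stand-in for a suite's real defaults
        pytest.param({"inference": "fireworks"}, id="fireworks"),
        pytest.param({"inference": "together"}, id="together"),
    ]

    fixtures = get_provider_fixtures_from_config(
        [{"inference": "ollama"}, {"default_fixture_param_id": "fireworks"}],
        DEFAULT_PROVIDER_COMBINATIONS,
    )
    # -> [pytest.param({'inference': 'ollama'}, id='ollama'),
    #     pytest.param({'inference': 'fireworks'}, id='fireworks')]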