restructure config

This commit is contained in:
Sixian Yi 2025-01-14 16:08:20 -08:00
parent 702cf2d563
commit 26d9804efd
8 changed files with 218 additions and 62 deletions

View file

@ -246,7 +246,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id = list(self.impls_by_provider_id.keys())[0] provider_id = list(self.impls_by_provider_id.keys())[0]
else: else:
raise ValueError( raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
) )
if metadata is None: if metadata is None:
metadata = {} metadata = {}

View file

@ -10,6 +10,10 @@ from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from ..test_config_helper import (
get_provider_fixtures_from_config,
try_load_config_file_cached,
)
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
from .fixtures import AGENTS_FIXTURES from .fixtures import AGENTS_FIXTURES
@ -82,7 +86,25 @@ def pytest_configure(config):
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
shield_id = metafunc.config.getoption("--safety-shield") test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
(
config_override_inference_models,
config_override_safety_shield,
custom_provider_fixtures,
) = (None, None, None)
if test_config is not None:
config_override_inference_models = test_config.agent.fixtures.inference_models
config_override_safety_shield = test_config.agent.fixtures.safety_shield
custom_provider_fixtures = get_provider_fixtures_from_config(
test_config.agent.fixtures.provider_fixtures, DEFAULT_PROVIDER_COMBINATIONS
)
shield_id = config_override_safety_shield or metafunc.config.getoption(
"--safety-shield"
)
inference_model = config_override_inference_models or [
metafunc.config.getoption("--inference-model")
]
if "safety_shield" in metafunc.fixturenames: if "safety_shield" in metafunc.fixturenames:
metafunc.parametrize( metafunc.parametrize(
"safety_shield", "safety_shield",
@ -90,8 +112,7 @@ def pytest_generate_tests(metafunc):
indirect=True, indirect=True,
) )
if "inference_model" in metafunc.fixturenames: if "inference_model" in metafunc.fixturenames:
inference_model = metafunc.config.getoption("--inference-model") models = set(inference_model)
models = set({inference_model})
if safety_model := safety_model_from_shield(shield_id): if safety_model := safety_model_from_shield(shield_id):
models.add(safety_model) models.add(safety_model)
@ -109,7 +130,8 @@ def pytest_generate_tests(metafunc):
"tool_runtime": TOOL_RUNTIME_FIXTURES, "tool_runtime": TOOL_RUNTIME_FIXTURES,
} }
combinations = ( combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures) custom_provider_fixtures
or get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS or DEFAULT_PROVIDER_COMBINATIONS
) )
metafunc.parametrize("agents_stack", combinations, indirect=True) metafunc.parametrize("agents_stack", combinations, indirect=True)

View file

@ -1,24 +1,59 @@
tests: inference:
- path: inference/test_vision_inference.py tests:
functions: - inference/test_vision_inference.py::test_vision_chat_completion_streaming
- test_vision_chat_completion_streaming - inference/test_vision_inference.py::test_vision_chat_completion_non_streaming
- test_vision_chat_completion_non_streaming - inference/test_text_inference.py::test_structured_output
- inference/test_text_inference.py::test_chat_completion_streaming
- inference/test_text_inference.py::test_chat_completion_non_streaming
- inference/test_text_inference.py::test_chat_completion_with_tool_calling
- inference/test_text_inference.py::test_chat_completion_with_tool_calling_streaming
- path: inference/test_text_inference.py fixtures:
functions: provider_fixtures:
- test_structured_output - inference: ollama
- test_chat_completion_streaming - default_fixture_param_id: fireworks
- test_chat_completion_non_streaming - inference: together
- test_chat_completion_with_tool_calling # - inference: tgi
- test_chat_completion_with_tool_calling_streaming # - inference: vllm_remote
inference_models:
- meta-llama/Llama-3.1-8B-Instruct
- meta-llama/Llama-3.2-11B-Vision-Instruct
inference_fixtures: safety_shield: ~
- ollama embedding_model: ~
- fireworks
- together
- tgi
- vllm_remote
test_models:
text: meta-llama/Llama-3.1-8B-Instruct agent:
vision: meta-llama/Llama-3.2-11B-Vision-Instruct tests:
- agents/test_agents.py::test_agent_turns_with_safety
- agents/test_agents.py::test_rag_agent
fixtures:
provider_fixtures:
- default_fixture_param_id: ollama
- default_fixture_param_id: together
- default_fixture_param_id: fireworks
safety_shield: ~
embedding_model: ~
inference_models:
- meta-llama/Llama-3.2-1B-Instruct
memory:
tests:
- memory/test_memory.py::test_query_documents
fixtures:
provider_fixtures:
- default_fixture_param_id: ollama
- inference: sentence_transformers
memory: faiss
- default_fixture_param_id: chroma
inference_models:
- meta-llama/Llama-3.2-1B-Instruct
safety_shield: ~
embedding_model: ~

View file

@ -5,12 +5,12 @@
# the root directory of this source tree. # the root directory of this source tree.
import os import os
from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import pytest import pytest
import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import colored from termcolor import colored
@ -20,6 +20,8 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
from .env import get_env_or_fail from .env import get_env_or_fail
from .test_config_helper import try_load_config_file_cached
class ProviderFixture(BaseModel): class ProviderFixture(BaseModel):
providers: List[Provider] providers: List[Provider]
@ -180,34 +182,26 @@ def pytest_itemcollected(item):
def pytest_collection_modifyitems(session, config, items): def pytest_collection_modifyitems(session, config, items):
if config.getoption("--config") is None: test_config = try_load_config_file_cached(config.getoption("--config"))
if test_config is None:
return return
file_name = config.getoption("--config")
config_file_path = Path(__file__).parent / file_name
if not config_file_path.exists():
raise ValueError(
f"Test config {file_name} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
)
required_tests = dict() required_tests = defaultdict(set)
inference_providers = set() test_configs = [test_config.inference, test_config.memory, test_config.agent]
with open(config_file_path, "r") as config_file: for test_config in test_configs:
test_config = yaml.safe_load(config_file) for test in test_config.tests:
for test in test_config["tests"]: arr = test.split("::")
required_tests[Path(__file__).parent / test["path"]] = set( if len(arr) != 2:
test["functions"] raise ValueError(f"Invalid format for test name {test}")
) test_path, func_name = arr
inference_providers = set(test_config["inference_fixtures"]) required_tests[Path(__file__).parent / test_path].add(func_name)
new_items, deselected_items = [], [] new_items, deselected_items = [], []
for item in items: for item in items:
if item.fspath in required_tests: func_name = getattr(item, "originalname", item.name)
func_name = getattr(item, "originalname", item.name) if func_name in required_tests[item.fspath]:
if func_name in required_tests[item.fspath]: new_items.append(item)
inference = item.callspec.params.get("inference_stack") continue
if inference in inference_providers:
new_items.append(item)
continue
deselected_items.append(item) deselected_items.append(item)
items[:] = new_items items[:] = new_items

View file

@ -7,7 +7,7 @@
import pytest import pytest
from ..conftest import get_provider_fixture_overrides from ..conftest import get_provider_fixture_overrides
from ..test_config_helper import try_load_config_file_cached
from .fixtures import INFERENCE_FIXTURES from .fixtures import INFERENCE_FIXTURES
@ -43,29 +43,43 @@ VISION_MODEL_PARAMS = [
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
if "inference_model" in metafunc.fixturenames: if "inference_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--inference-model") cls_name = metafunc.cls.__name__
if model: if test_config is not None:
params = [pytest.param(model, id="")] params = []
for model in test_config.inference.fixtures.inference_models:
if ("Vision" in cls_name and "Vision" in model) or (
"Vision" not in cls_name and "Vision" not in model
):
params.append(pytest.param(model, id=model))
else: else:
cls_name = metafunc.cls.__name__ model = metafunc.config.getoption("--inference-model")
if "Vision" in cls_name: if model:
params = VISION_MODEL_PARAMS params = [pytest.param(model, id="")]
else: else:
params = MODEL_PARAMS if "Vision" in cls_name:
params = VISION_MODEL_PARAMS
else:
params = MODEL_PARAMS
metafunc.parametrize( metafunc.parametrize(
"inference_model", "inference_model",
params, params,
indirect=True, indirect=True,
) )
if "inference_stack" in metafunc.fixturenames: if "inference_stack" in metafunc.fixturenames:
fixtures = INFERENCE_FIXTURES if test_config is not None:
if filtered_stacks := get_provider_fixture_overrides( fixtures = [
(f.get("inference") or f.get("default_fixture_param_id"))
for f in test_config.inference.fixtures.provider_fixtures
]
elif filtered_stacks := get_provider_fixture_overrides(
metafunc.config, metafunc.config,
{ {
"inference": INFERENCE_FIXTURES, "inference": INFERENCE_FIXTURES,
}, },
): ):
fixtures = [stack.values[0]["inference"] for stack in filtered_stacks] fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
else:
fixtures = INFERENCE_FIXTURES
metafunc.parametrize("inference_stack", fixtures, indirect=True) metafunc.parametrize("inference_stack", fixtures, indirect=True)

View file

@ -301,6 +301,7 @@ async def inference_stack(request, inference_model):
inference_fixture.provider_data, inference_fixture.provider_data,
models=[ models=[
ModelInput( ModelInput(
provider_id=inference_fixture.providers[0].provider_id,
model_id=inference_model, model_id=inference_model,
model_type=model_type, model_type=model_type,
metadata=metadata, metadata=metadata,

View file

@ -9,6 +9,10 @@ import pytest
from ..conftest import get_provider_fixture_overrides from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES from ..inference.fixtures import INFERENCE_FIXTURES
from ..test_config_helper import (
get_provider_fixtures_from_config,
try_load_config_file_cached,
)
from .fixtures import MEMORY_FIXTURES from .fixtures import MEMORY_FIXTURES
@ -65,6 +69,15 @@ def pytest_configure(config):
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
provider_fixtures_config = (
test_config.memory.fixtures.provider_fixtures
if test_config is not None
else None
)
custom_fixtures = get_provider_fixtures_from_config(
provider_fixtures_config, DEFAULT_PROVIDER_COMBINATIONS
)
if "embedding_model" in metafunc.fixturenames: if "embedding_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--embedding-model") model = metafunc.config.getoption("--embedding-model")
if model: if model:
@ -80,7 +93,8 @@ def pytest_generate_tests(metafunc):
"memory": MEMORY_FIXTURES, "memory": MEMORY_FIXTURES,
} }
combinations = ( combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures) custom_fixtures
or get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS or DEFAULT_PROVIDER_COMBINATIONS
) )
metafunc.parametrize("memory_stack", combinations, indirect=True) metafunc.parametrize("memory_stack", combinations, indirect=True)

View file

@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
import pytest
import yaml
from pydantic import BaseModel, Field
@dataclass
class APITestConfig(BaseModel):
class Fixtures(BaseModel):
# provider fixtures can be either a mark or a dictionary of api -> providers
provider_fixtures: List[Dict[str, str]]
inference_models: List[str] = Field(default_factory=list)
safety_shield: Optional[str]
embedding_model: Optional[str]
fixtures: Fixtures
tests: List[str] = Field(default_factory=list)
# test name format should be <relative_path.py>::<test_name>
class TestConfig(BaseModel):
inference: APITestConfig
agent: APITestConfig
memory: APITestConfig
CONFIG_CACHE = None
def try_load_config_file_cached(config_file):
if config_file is None:
return None
if CONFIG_CACHE is not None:
return CONFIG_CACHE
config_file_path = Path(__file__).parent / config_file
if not config_file_path.exists():
raise ValueError(
f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
)
with open(config_file_path, "r") as config_file:
config = yaml.safe_load(config_file)
return TestConfig(**config)
def get_provider_fixtures_from_config(
provider_fixtures_config, default_fixture_combination
):
custom_fixtures = []
selected_default_param_id = set()
for fixture_config in provider_fixtures_config:
if "default_fixture_param_id" in fixture_config:
selected_default_param_id.add(fixture_config["default_fixture_param_id"])
else:
custom_fixtures.append(
pytest.param(fixture_config, id=fixture_config.get("inference") or "")
)
if len(selected_default_param_id) > 0:
for default_fixture in default_fixture_combination:
if default_fixture.id in selected_default_param_id:
custom_fixtures.append(default_fixture)
return custom_fixtures