mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
restructure config
This commit is contained in:
parent
702cf2d563
commit
26d9804efd
8 changed files with 218 additions and 62 deletions
|
@ -246,7 +246,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
|
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
|
||||||
)
|
)
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
|
@ -10,6 +10,10 @@ from ..conftest import get_provider_fixture_overrides
|
||||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||||
from ..memory.fixtures import MEMORY_FIXTURES
|
from ..memory.fixtures import MEMORY_FIXTURES
|
||||||
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
||||||
|
from ..test_config_helper import (
|
||||||
|
get_provider_fixtures_from_config,
|
||||||
|
try_load_config_file_cached,
|
||||||
|
)
|
||||||
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
|
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
|
||||||
from .fixtures import AGENTS_FIXTURES
|
from .fixtures import AGENTS_FIXTURES
|
||||||
|
|
||||||
|
@ -82,7 +86,25 @@ def pytest_configure(config):
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
shield_id = metafunc.config.getoption("--safety-shield")
|
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
|
||||||
|
(
|
||||||
|
config_override_inference_models,
|
||||||
|
config_override_safety_shield,
|
||||||
|
custom_provider_fixtures,
|
||||||
|
) = (None, None, None)
|
||||||
|
if test_config is not None:
|
||||||
|
config_override_inference_models = test_config.agent.fixtures.inference_models
|
||||||
|
config_override_safety_shield = test_config.agent.fixtures.safety_shield
|
||||||
|
custom_provider_fixtures = get_provider_fixtures_from_config(
|
||||||
|
test_config.agent.fixtures.provider_fixtures, DEFAULT_PROVIDER_COMBINATIONS
|
||||||
|
)
|
||||||
|
|
||||||
|
shield_id = config_override_safety_shield or metafunc.config.getoption(
|
||||||
|
"--safety-shield"
|
||||||
|
)
|
||||||
|
inference_model = config_override_inference_models or [
|
||||||
|
metafunc.config.getoption("--inference-model")
|
||||||
|
]
|
||||||
if "safety_shield" in metafunc.fixturenames:
|
if "safety_shield" in metafunc.fixturenames:
|
||||||
metafunc.parametrize(
|
metafunc.parametrize(
|
||||||
"safety_shield",
|
"safety_shield",
|
||||||
|
@ -90,8 +112,7 @@ def pytest_generate_tests(metafunc):
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
if "inference_model" in metafunc.fixturenames:
|
if "inference_model" in metafunc.fixturenames:
|
||||||
inference_model = metafunc.config.getoption("--inference-model")
|
models = set(inference_model)
|
||||||
models = set({inference_model})
|
|
||||||
if safety_model := safety_model_from_shield(shield_id):
|
if safety_model := safety_model_from_shield(shield_id):
|
||||||
models.add(safety_model)
|
models.add(safety_model)
|
||||||
|
|
||||||
|
@ -109,7 +130,8 @@ def pytest_generate_tests(metafunc):
|
||||||
"tool_runtime": TOOL_RUNTIME_FIXTURES,
|
"tool_runtime": TOOL_RUNTIME_FIXTURES,
|
||||||
}
|
}
|
||||||
combinations = (
|
combinations = (
|
||||||
get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
custom_provider_fixtures
|
||||||
|
or get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
||||||
or DEFAULT_PROVIDER_COMBINATIONS
|
or DEFAULT_PROVIDER_COMBINATIONS
|
||||||
)
|
)
|
||||||
metafunc.parametrize("agents_stack", combinations, indirect=True)
|
metafunc.parametrize("agents_stack", combinations, indirect=True)
|
||||||
|
|
|
@ -1,24 +1,59 @@
|
||||||
tests:
|
inference:
|
||||||
- path: inference/test_vision_inference.py
|
tests:
|
||||||
functions:
|
- inference/test_vision_inference.py::test_vision_chat_completion_streaming
|
||||||
- test_vision_chat_completion_streaming
|
- inference/test_vision_inference.py::test_vision_chat_completion_non_streaming
|
||||||
- test_vision_chat_completion_non_streaming
|
- inference/test_text_inference.py::test_structured_output
|
||||||
|
- inference/test_text_inference.py::test_chat_completion_streaming
|
||||||
|
- inference/test_text_inference.py::test_chat_completion_non_streaming
|
||||||
|
- inference/test_text_inference.py::test_chat_completion_with_tool_calling
|
||||||
|
- inference/test_text_inference.py::test_chat_completion_with_tool_calling_streaming
|
||||||
|
|
||||||
- path: inference/test_text_inference.py
|
fixtures:
|
||||||
functions:
|
provider_fixtures:
|
||||||
- test_structured_output
|
- inference: ollama
|
||||||
- test_chat_completion_streaming
|
- default_fixture_param_id: fireworks
|
||||||
- test_chat_completion_non_streaming
|
- inference: together
|
||||||
- test_chat_completion_with_tool_calling
|
# - inference: tgi
|
||||||
- test_chat_completion_with_tool_calling_streaming
|
# - inference: vllm_remote
|
||||||
|
inference_models:
|
||||||
|
- meta-llama/Llama-3.1-8B-Instruct
|
||||||
|
- meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||||
|
|
||||||
inference_fixtures:
|
safety_shield: ~
|
||||||
- ollama
|
embedding_model: ~
|
||||||
- fireworks
|
|
||||||
- together
|
|
||||||
- tgi
|
|
||||||
- vllm_remote
|
|
||||||
|
|
||||||
test_models:
|
|
||||||
text: meta-llama/Llama-3.1-8B-Instruct
|
agent:
|
||||||
vision: meta-llama/Llama-3.2-11B-Vision-Instruct
|
tests:
|
||||||
|
- agents/test_agents.py::test_agent_turns_with_safety
|
||||||
|
- agents/test_agents.py::test_rag_agent
|
||||||
|
|
||||||
|
fixtures:
|
||||||
|
provider_fixtures:
|
||||||
|
- default_fixture_param_id: ollama
|
||||||
|
- default_fixture_param_id: together
|
||||||
|
- default_fixture_param_id: fireworks
|
||||||
|
|
||||||
|
safety_shield: ~
|
||||||
|
embedding_model: ~
|
||||||
|
|
||||||
|
inference_models:
|
||||||
|
- meta-llama/Llama-3.2-1B-Instruct
|
||||||
|
|
||||||
|
|
||||||
|
memory:
|
||||||
|
tests:
|
||||||
|
- memory/test_memory.py::test_query_documents
|
||||||
|
|
||||||
|
fixtures:
|
||||||
|
provider_fixtures:
|
||||||
|
- default_fixture_param_id: ollama
|
||||||
|
- inference: sentence_transformers
|
||||||
|
memory: faiss
|
||||||
|
- default_fixture_param_id: chroma
|
||||||
|
|
||||||
|
inference_models:
|
||||||
|
- meta-llama/Llama-3.2-1B-Instruct
|
||||||
|
|
||||||
|
safety_shield: ~
|
||||||
|
embedding_model: ~
|
||||||
|
|
|
@ -5,12 +5,12 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import yaml
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
@ -20,6 +20,8 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||||
|
|
||||||
from .env import get_env_or_fail
|
from .env import get_env_or_fail
|
||||||
|
|
||||||
|
from .test_config_helper import try_load_config_file_cached
|
||||||
|
|
||||||
|
|
||||||
class ProviderFixture(BaseModel):
|
class ProviderFixture(BaseModel):
|
||||||
providers: List[Provider]
|
providers: List[Provider]
|
||||||
|
@ -180,34 +182,26 @@ def pytest_itemcollected(item):
|
||||||
|
|
||||||
|
|
||||||
def pytest_collection_modifyitems(session, config, items):
|
def pytest_collection_modifyitems(session, config, items):
|
||||||
if config.getoption("--config") is None:
|
test_config = try_load_config_file_cached(config.getoption("--config"))
|
||||||
|
if test_config is None:
|
||||||
return
|
return
|
||||||
file_name = config.getoption("--config")
|
|
||||||
config_file_path = Path(__file__).parent / file_name
|
|
||||||
if not config_file_path.exists():
|
|
||||||
raise ValueError(
|
|
||||||
f"Test config {file_name} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
|
|
||||||
)
|
|
||||||
|
|
||||||
required_tests = dict()
|
required_tests = defaultdict(set)
|
||||||
inference_providers = set()
|
test_configs = [test_config.inference, test_config.memory, test_config.agent]
|
||||||
with open(config_file_path, "r") as config_file:
|
for test_config in test_configs:
|
||||||
test_config = yaml.safe_load(config_file)
|
for test in test_config.tests:
|
||||||
for test in test_config["tests"]:
|
arr = test.split("::")
|
||||||
required_tests[Path(__file__).parent / test["path"]] = set(
|
if len(arr) != 2:
|
||||||
test["functions"]
|
raise ValueError(f"Invalid format for test name {test}")
|
||||||
)
|
test_path, func_name = arr
|
||||||
inference_providers = set(test_config["inference_fixtures"])
|
required_tests[Path(__file__).parent / test_path].add(func_name)
|
||||||
|
|
||||||
new_items, deselected_items = [], []
|
new_items, deselected_items = [], []
|
||||||
for item in items:
|
for item in items:
|
||||||
if item.fspath in required_tests:
|
func_name = getattr(item, "originalname", item.name)
|
||||||
func_name = getattr(item, "originalname", item.name)
|
if func_name in required_tests[item.fspath]:
|
||||||
if func_name in required_tests[item.fspath]:
|
new_items.append(item)
|
||||||
inference = item.callspec.params.get("inference_stack")
|
continue
|
||||||
if inference in inference_providers:
|
|
||||||
new_items.append(item)
|
|
||||||
continue
|
|
||||||
deselected_items.append(item)
|
deselected_items.append(item)
|
||||||
|
|
||||||
items[:] = new_items
|
items[:] = new_items
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ..conftest import get_provider_fixture_overrides
|
from ..conftest import get_provider_fixture_overrides
|
||||||
|
from ..test_config_helper import try_load_config_file_cached
|
||||||
from .fixtures import INFERENCE_FIXTURES
|
from .fixtures import INFERENCE_FIXTURES
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,29 +43,43 @@ VISION_MODEL_PARAMS = [
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
|
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
|
||||||
if "inference_model" in metafunc.fixturenames:
|
if "inference_model" in metafunc.fixturenames:
|
||||||
model = metafunc.config.getoption("--inference-model")
|
cls_name = metafunc.cls.__name__
|
||||||
if model:
|
if test_config is not None:
|
||||||
params = [pytest.param(model, id="")]
|
params = []
|
||||||
|
for model in test_config.inference.fixtures.inference_models:
|
||||||
|
if ("Vision" in cls_name and "Vision" in model) or (
|
||||||
|
"Vision" not in cls_name and "Vision" not in model
|
||||||
|
):
|
||||||
|
params.append(pytest.param(model, id=model))
|
||||||
else:
|
else:
|
||||||
cls_name = metafunc.cls.__name__
|
model = metafunc.config.getoption("--inference-model")
|
||||||
if "Vision" in cls_name:
|
if model:
|
||||||
params = VISION_MODEL_PARAMS
|
params = [pytest.param(model, id="")]
|
||||||
else:
|
else:
|
||||||
params = MODEL_PARAMS
|
if "Vision" in cls_name:
|
||||||
|
params = VISION_MODEL_PARAMS
|
||||||
|
else:
|
||||||
|
params = MODEL_PARAMS
|
||||||
metafunc.parametrize(
|
metafunc.parametrize(
|
||||||
"inference_model",
|
"inference_model",
|
||||||
params,
|
params,
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
if "inference_stack" in metafunc.fixturenames:
|
if "inference_stack" in metafunc.fixturenames:
|
||||||
fixtures = INFERENCE_FIXTURES
|
if test_config is not None:
|
||||||
if filtered_stacks := get_provider_fixture_overrides(
|
fixtures = [
|
||||||
|
(f.get("inference") or f.get("default_fixture_param_id"))
|
||||||
|
for f in test_config.inference.fixtures.provider_fixtures
|
||||||
|
]
|
||||||
|
elif filtered_stacks := get_provider_fixture_overrides(
|
||||||
metafunc.config,
|
metafunc.config,
|
||||||
{
|
{
|
||||||
"inference": INFERENCE_FIXTURES,
|
"inference": INFERENCE_FIXTURES,
|
||||||
},
|
},
|
||||||
):
|
):
|
||||||
fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
|
fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
|
||||||
|
else:
|
||||||
|
fixtures = INFERENCE_FIXTURES
|
||||||
metafunc.parametrize("inference_stack", fixtures, indirect=True)
|
metafunc.parametrize("inference_stack", fixtures, indirect=True)
|
||||||
|
|
|
@ -301,6 +301,7 @@ async def inference_stack(request, inference_model):
|
||||||
inference_fixture.provider_data,
|
inference_fixture.provider_data,
|
||||||
models=[
|
models=[
|
||||||
ModelInput(
|
ModelInput(
|
||||||
|
provider_id=inference_fixture.providers[0].provider_id,
|
||||||
model_id=inference_model,
|
model_id=inference_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
|
|
@ -9,6 +9,10 @@ import pytest
|
||||||
from ..conftest import get_provider_fixture_overrides
|
from ..conftest import get_provider_fixture_overrides
|
||||||
|
|
||||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||||
|
from ..test_config_helper import (
|
||||||
|
get_provider_fixtures_from_config,
|
||||||
|
try_load_config_file_cached,
|
||||||
|
)
|
||||||
from .fixtures import MEMORY_FIXTURES
|
from .fixtures import MEMORY_FIXTURES
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,6 +69,15 @@ def pytest_configure(config):
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
|
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
|
||||||
|
provider_fixtures_config = (
|
||||||
|
test_config.memory.fixtures.provider_fixtures
|
||||||
|
if test_config is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
custom_fixtures = get_provider_fixtures_from_config(
|
||||||
|
provider_fixtures_config, DEFAULT_PROVIDER_COMBINATIONS
|
||||||
|
)
|
||||||
if "embedding_model" in metafunc.fixturenames:
|
if "embedding_model" in metafunc.fixturenames:
|
||||||
model = metafunc.config.getoption("--embedding-model")
|
model = metafunc.config.getoption("--embedding-model")
|
||||||
if model:
|
if model:
|
||||||
|
@ -80,7 +93,8 @@ def pytest_generate_tests(metafunc):
|
||||||
"memory": MEMORY_FIXTURES,
|
"memory": MEMORY_FIXTURES,
|
||||||
}
|
}
|
||||||
combinations = (
|
combinations = (
|
||||||
get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
custom_fixtures
|
||||||
|
or get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
||||||
or DEFAULT_PROVIDER_COMBINATIONS
|
or DEFAULT_PROVIDER_COMBINATIONS
|
||||||
)
|
)
|
||||||
metafunc.parametrize("memory_stack", combinations, indirect=True)
|
metafunc.parametrize("memory_stack", combinations, indirect=True)
|
||||||
|
|
76
llama_stack/providers/tests/test_config_helper.py
Normal file
76
llama_stack/providers/tests/test_config_helper.py
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class APITestConfig(BaseModel):
|
||||||
|
|
||||||
|
class Fixtures(BaseModel):
|
||||||
|
# provider fixtures can be either a mark or a dictionary of api -> providers
|
||||||
|
provider_fixtures: List[Dict[str, str]]
|
||||||
|
inference_models: List[str] = Field(default_factory=list)
|
||||||
|
safety_shield: Optional[str]
|
||||||
|
embedding_model: Optional[str]
|
||||||
|
|
||||||
|
fixtures: Fixtures
|
||||||
|
tests: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
# test name format should be <relative_path.py>::<test_name>
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfig(BaseModel):
|
||||||
|
|
||||||
|
inference: APITestConfig
|
||||||
|
agent: APITestConfig
|
||||||
|
memory: APITestConfig
|
||||||
|
|
||||||
|
|
||||||
|
CONFIG_CACHE = None
|
||||||
|
|
||||||
|
|
||||||
|
def try_load_config_file_cached(config_file):
|
||||||
|
if config_file is None:
|
||||||
|
return None
|
||||||
|
if CONFIG_CACHE is not None:
|
||||||
|
return CONFIG_CACHE
|
||||||
|
|
||||||
|
config_file_path = Path(__file__).parent / config_file
|
||||||
|
if not config_file_path.exists():
|
||||||
|
raise ValueError(
|
||||||
|
f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
|
||||||
|
)
|
||||||
|
with open(config_file_path, "r") as config_file:
|
||||||
|
config = yaml.safe_load(config_file)
|
||||||
|
return TestConfig(**config)
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_fixtures_from_config(
|
||||||
|
provider_fixtures_config, default_fixture_combination
|
||||||
|
):
|
||||||
|
custom_fixtures = []
|
||||||
|
selected_default_param_id = set()
|
||||||
|
for fixture_config in provider_fixtures_config:
|
||||||
|
if "default_fixture_param_id" in fixture_config:
|
||||||
|
selected_default_param_id.add(fixture_config["default_fixture_param_id"])
|
||||||
|
else:
|
||||||
|
custom_fixtures.append(
|
||||||
|
pytest.param(fixture_config, id=fixture_config.get("inference") or "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(selected_default_param_id) > 0:
|
||||||
|
for default_fixture in default_fixture_combination:
|
||||||
|
if default_fixture.id in selected_default_param_id:
|
||||||
|
custom_fixtures.append(default_fixture)
|
||||||
|
|
||||||
|
return custom_fixtures
|
Loading…
Add table
Add a link
Reference in a new issue