Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 10:42:39 +00:00)

Commit d19526ecd7 ("address comments"), parent c03f7fe9be
7 changed files with 127 additions and 124 deletions
File 1 (agents conftest):

@@ -8,8 +8,8 @@ import pytest

 from ..conftest import (
     get_provider_fixture_overrides,
-    get_provider_fixtures_from_config,
-    try_load_config_file_cached,
+    get_provider_fixture_overrides_from_test_config,
+    get_test_config_for_api,
 )

 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
@@ -87,25 +87,14 @@ def pytest_configure(config):


 def pytest_generate_tests(metafunc):
-    test_config = try_load_config_file_cached(metafunc.config)
-    (
-        config_override_inference_models,
-        config_override_safety_shield,
-        custom_provider_fixtures,
-    ) = (None, None, None)
-    if test_config is not None and test_config.agent is not None:
-        config_override_inference_models = test_config.agent.fixtures.inference_models
-        config_override_safety_shield = test_config.agent.fixtures.safety_shield
-        custom_provider_fixtures = get_provider_fixtures_from_config(
-            test_config.agent.fixtures.provider_fixtures, DEFAULT_PROVIDER_COMBINATIONS
-        )
-
-    shield_id = config_override_safety_shield or metafunc.config.getoption(
-        "--safety-shield"
-    )
-    inference_model = config_override_inference_models or [
+    test_config = get_test_config_for_api(metafunc.config, "agents")
+    shield_id = getattr(
+        test_config, "safety_shield", None
+    ) or metafunc.config.getoption("--safety-shield")
+    inference_models = getattr(test_config, "inference_models", None) or [
         metafunc.config.getoption("--inference-model")
     ]

     if "safety_shield" in metafunc.fixturenames:
         metafunc.parametrize(
             "safety_shield",
@@ -113,7 +102,7 @@ def pytest_generate_tests(metafunc):
             indirect=True,
         )
     if "inference_model" in metafunc.fixturenames:
-        models = set(inference_model)
+        models = set(inference_models)
         if safety_model := safety_model_from_shield(shield_id):
             models.add(safety_model)

@@ -131,7 +120,9 @@ def pytest_generate_tests(metafunc):
         "tool_runtime": TOOL_RUNTIME_FIXTURES,
     }
     combinations = (
-        custom_provider_fixtures
+        get_provider_fixture_overrides_from_test_config(
+            metafunc.config, "agents", DEFAULT_PROVIDER_COMBINATIONS
+        )
         or get_provider_fixture_overrides(metafunc.config, available_fixtures)
         or DEFAULT_PROVIDER_COMBINATIONS
     )
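The core of this refactor is the pattern visible above: per-API settings come from the optional test config file when present, and fall back to the pytest CLI options otherwise. A minimal sketch of that fallback, using invented stand-ins for the parsed config and for pytest's config object (none of these names are from the diff):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class AgentsSettings:
    # Stand-in for the "agents" section of the parsed test config file.
    safety_shield: Optional[str] = None
    inference_models: List[str] = field(default_factory=list)


class FakePytestConfig:
    # Stand-in for metafunc.config; returns CLI option values.
    def getoption(self, name):
        return {"--safety-shield": "cli-shield", "--inference-model": "cli-model"}[name]


test_config = AgentsSettings(inference_models=["meta-llama/Llama-3.2-1B-Instruct"])
config = FakePytestConfig()

# The config value wins when set; otherwise the CLI option is used.
shield_id = getattr(test_config, "safety_shield", None) or config.getoption("--safety-shield")
inference_models = getattr(test_config, "inference_models", None) or [
    config.getoption("--inference-model")
]
print(shield_id)         # cli-shield (the config did not set a shield)
print(inference_models)  # ['meta-llama/Llama-3.2-1B-Instruct'] (from the config)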
File 2 (agents test imports, kvstore split):

@@ -9,7 +9,8 @@ import pytest
 from llama_stack.apis.agents import AgentConfig, Turn
 from llama_stack.apis.inference import SamplingParams, UserMessage
 from llama_stack.providers.datatypes import Api
-from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

 from .fixtures import pick_inference_model

File 3 (CI test config, YAML):

@@ -8,52 +8,48 @@ inference:
   - inference/test_text_inference.py::test_chat_completion_with_tool_calling
   - inference/test_text_inference.py::test_chat_completion_with_tool_calling_streaming

-  fixtures:
-    provider_fixtures:
-    - inference: ollama
-    - default_fixture_param_id: fireworks
-    - inference: together
+  scenarios:
+  - provider_fixtures:
+      inference: ollama
+  - fixture_combo_id: fireworks
+  - provider_fixtures:
+      inference: together
   # - inference: tgi
   # - inference: vllm_remote
-    inference_models:
-    - meta-llama/Llama-3.1-8B-Instruct
-    - meta-llama/Llama-3.2-11B-Vision-Instruct

-    safety_shield: ~
-    embedding_model: ~
+  inference_models:
+  - meta-llama/Llama-3.1-8B-Instruct
+  - meta-llama/Llama-3.2-11B-Vision-Instruct


-agent:
+agents:
   tests:
   - agents/test_agents.py::test_agent_turns_with_safety
   - agents/test_agents.py::test_rag_agent

-  fixtures:
-    provider_fixtures:
-    - default_fixture_param_id: ollama
-    - default_fixture_param_id: together
-    - default_fixture_param_id: fireworks
+  scenarios:
+  - fixture_combo_id: ollama
+  - fixture_combo_id: together
+  - fixture_combo_id: fireworks

-    safety_shield: ~
-    embedding_model: ~
+  inference_models:
+  - meta-llama/Llama-3.2-1B-Instruct

-    inference_models:
-    - meta-llama/Llama-3.2-1B-Instruct
+  safety_shield: meta-llama/Llama-Guard-3-1B


 memory:
   tests:
   - memory/test_memory.py::test_query_documents

-  fixtures:
-    provider_fixtures:
-    - default_fixture_param_id: ollama
-    - inference: sentence_transformers
+  scenarios:
+  - fixture_combo_id: ollama
+  - provider_fixtures:
+      inference: sentence_transformers
       memory: faiss
-    - default_fixture_param_id: chroma
+  - fixture_combo_id: chroma

   inference_models:
   - meta-llama/Llama-3.2-1B-Instruct

-  safety_shield: ~
-  embedding_model: ~
+  embedding_model: all-MiniLM-L6-v2
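For orientation, a rough sketch of how a config file in the new scenario-based format could be loaded into the pydantic models introduced later in this commit. The inline YAML and the final validation call are illustrative; only the field names come from the diff:

import yaml  # assumes PyYAML is available, as it is for the real --config option

config_text = """
memory:
  tests:
  - memory/test_memory.py::test_query_documents
  scenarios:
  - fixture_combo_id: ollama
  - provider_fixtures:
      inference: sentence_transformers
      memory: faiss
  inference_models:
  - meta-llama/Llama-3.2-1B-Instruct
  embedding_model: all-MiniLM-L6-v2
"""

data = yaml.safe_load(config_text)
# With the models defined below this would be validated as TestConfig(**data).
# Each scenario either names a pre-built fixture combination (fixture_combo_id)
# or spells out the provider fixture per API (provider_fixtures).
for scenario in data["memory"]["scenarios"]:
    print(scenario.get("fixture_combo_id") or scenario["provider_fixtures"])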
File 4 (shared tests conftest: config models and helpers):

@@ -28,28 +28,35 @@ class ProviderFixture(BaseModel):
     provider_data: Optional[Dict[str, Any]] = None


-class Fixtures(BaseModel):
+class TestScenario(BaseModel):
     # provider fixtures can be either a mark or a dictionary of api -> providers
-    provider_fixtures: List[Dict[str, str]] = Field(default_factory=list)
-    inference_models: List[str] = Field(default_factory=list)
-    safety_shield: Optional[str] = Field(default_factory=None)
-    embedding_model: Optional[str] = Field(default_factory=None)
+    provider_fixtures: Dict[str, str] = Field(default_factory=dict)
+    fixture_combo_id: Optional[str] = None


 class APITestConfig(BaseModel):
-    fixtures: Fixtures
+    scenarios: List[TestScenario] = Field(default_factory=list)
+    inference_models: List[str] = Field(default_factory=list)
+
     # test name format should be <relative_path.py>::<test_name>
     tests: List[str] = Field(default_factory=list)


+class MemoryApiTestConfig(APITestConfig):
+    embedding_model: Optional[str] = Field(default_factory=None)
+
+
+class AgentsApiTestConfig(APITestConfig):
+    safety_shield: Optional[str] = Field(default_factory=None)
+
+
 class TestConfig(BaseModel):
-    inference: APITestConfig
-    agent: Optional[APITestConfig] = Field(default=None)
-    memory: Optional[APITestConfig] = Field(default=None)
+    inference: APITestConfig = Field(default=None)
+    agents: AgentsApiTestConfig = Field(default=None)
+    memory: MemoryApiTestConfig = Field(default=None)


-def try_load_config_file_cached(config):
+def get_test_config_from_config_file(config):
     config_file = config.getoption("--config")
     if config_file is None:
         return None
@@ -64,25 +71,38 @@ def try_load_config_file_cached(config):
     return TestConfig(**config)


-def get_provider_fixtures_from_config(
-    provider_fixtures_config, default_fixture_combination
+def get_test_config_for_api(config, api):
+    test_config = get_test_config_from_config_file(config)
+    if test_config is None:
+        return None
+    return getattr(test_config, api)
+
+
+def get_provider_fixture_overrides_from_test_config(
+    config, api, default_provider_fixture_combination
 ):
-    custom_fixtures = []
-    selected_default_param_id = set()
-    for fixture_config in provider_fixtures_config:
-        if "default_fixture_param_id" in fixture_config:
-            selected_default_param_id.add(fixture_config["default_fixture_param_id"])
+    api_config = get_test_config_for_api(config, api)
+    if api_config is None:
+        return None
+
+    fixture_combo_ids = set()
+    custom_provider_fixture_combos = []
+    for scenario in api_config.scenarios:
+        if scenario.fixture_combo_id:
+            fixture_combo_ids.add(scenario.fixture_combo_id)
         else:
-            custom_fixtures.append(
-                pytest.param(fixture_config, id=fixture_config.get("inference") or "")
+            custom_provider_fixture_combos.append(
+                pytest.param(
+                    scenario.provider_fixtures,
+                    id=scenario.provider_fixtures.get("inference") or "",
+                )
             )

-    if len(selected_default_param_id) > 0:
-        for default_fixture in default_fixture_combination:
-            if default_fixture.id in selected_default_param_id:
-                custom_fixtures.append(default_fixture)
-
-    return custom_fixtures
+    if len(fixture_combo_ids) > 0:
+        for default_fixture in default_provider_fixture_combination:
+            if default_fixture.id in fixture_combo_ids:
+                custom_provider_fixture_combos.append(default_fixture)
+    return custom_provider_fixture_combos


 def remote_stack_fixture() -> ProviderFixture:
@@ -239,16 +259,19 @@ def pytest_itemcollected(item):


 def pytest_collection_modifyitems(session, config, items):
-    test_config = try_load_config_file_cached(config)
+    test_config = get_test_config_from_config_file(config)
     if test_config is None:
         return

     required_tests = defaultdict(set)
-    test_configs = [test_config.inference, test_config.memory, test_config.agent]
-    for test_config in test_configs:
-        if test_config is None:
+    for api_test_config in [
+        test_config.inference,
+        test_config.memory,
+        test_config.agents,
+    ]:
+        if api_test_config is None:
             continue
-        for test in test_config.tests:
+        for test in api_test_config.tests:
             arr = test.split("::")
             if len(arr) != 2:
                 raise ValueError(f"Invalid format for test name {test}")
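To make the new override helper concrete, here is a small self-contained sketch of the same selection logic, operating on plain dictionaries instead of TestScenario models and a real pytest config; the scenario data and combo ids below are invented for the example:

import pytest

# Invented defaults standing in for DEFAULT_PROVIDER_COMBINATIONS.
DEFAULT_COMBOS = [
    pytest.param({"inference": "ollama", "memory": "faiss"}, id="ollama"),
    pytest.param({"inference": "fireworks", "memory": "chroma"}, id="fireworks"),
]

# Invented scenarios mirroring the YAML: one picks a pre-built combination by
# id, the other spells out its provider fixtures explicitly.
scenarios = [
    {"fixture_combo_id": "fireworks"},
    {"provider_fixtures": {"inference": "together", "memory": "faiss"}},
]

combo_ids = {s["fixture_combo_id"] for s in scenarios if s.get("fixture_combo_id")}
overrides = [
    pytest.param(s["provider_fixtures"], id=s["provider_fixtures"].get("inference") or "")
    for s in scenarios
    if not s.get("fixture_combo_id")
]
# Pre-built combinations are pulled in from the defaults by their id.
overrides += [combo for combo in DEFAULT_COMBOS if combo.id in combo_ids]
print([p.id for p in overrides])  # ['together', 'fireworks']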
File 5 (inference conftest):

@@ -6,7 +6,7 @@

 import pytest

-from ..conftest import get_provider_fixture_overrides, try_load_config_file_cached
+from ..conftest import get_provider_fixture_overrides, get_test_config_for_api
 from .fixtures import INFERENCE_FIXTURES


@@ -42,43 +42,43 @@ VISION_MODEL_PARAMS = [


 def pytest_generate_tests(metafunc):
-    test_config = try_load_config_file_cached(metafunc.config)
+    test_config = get_test_config_for_api(metafunc.config, "inference")

     if "inference_model" in metafunc.fixturenames:
         cls_name = metafunc.cls.__name__
-        if test_config is not None:
-            params = []
-            for model in test_config.inference.fixtures.inference_models:
-                if ("Vision" in cls_name and "Vision" in model) or (
-                    "Vision" not in cls_name and "Vision" not in model
-                ):
-                    params.append(pytest.param(model, id=model))
-        else:
+        params = []
+        inference_models = getattr(test_config, "inference_models", [])
+        for model in inference_models:
+            if ("Vision" in cls_name and "Vision" in model) or (
+                "Vision" not in cls_name and "Vision" not in model
+            ):
+                params.append(pytest.param(model, id=model))
+
+        if not params:
             model = metafunc.config.getoption("--inference-model")
-            if model:
-                params = [pytest.param(model, id="")]
-            else:
-                if "Vision" in cls_name:
-                    params = VISION_MODEL_PARAMS
-                else:
-                    params = MODEL_PARAMS
+            params = [pytest.param(model, id="")]
+
         metafunc.parametrize(
             "inference_model",
             params,
             indirect=True,
         )
     if "inference_stack" in metafunc.fixturenames:
-        if test_config is not None:
-            fixtures = [
-                (f.get("inference") or f.get("default_fixture_param_id"))
-                for f in test_config.inference.fixtures.provider_fixtures
-            ]
-        elif filtered_stacks := get_provider_fixture_overrides(
+        fixtures = INFERENCE_FIXTURES
+        if filtered_stacks := get_provider_fixture_overrides(
             metafunc.config,
             {
                 "inference": INFERENCE_FIXTURES,
             },
         ):
             fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
-        else:
-            fixtures = INFERENCE_FIXTURES
+        if test_config:
+            if custom_fixtures := [
+                (
+                    scenario.fixture_combo_id
+                    or scenario.provider_fixtures.get("inference")
+                )
+                for scenario in test_config.scenarios
+            ]:
+                fixtures = custom_fixtures
         metafunc.parametrize("inference_stack", fixtures, indirect=True)
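One detail worth noting above is the Vision/non-Vision filter applied when building inference_model params from the configured model list. A standalone sketch with an invented test class name and the models from the sample config:

import pytest

# Invented inputs: a Vision test class and the models from the sample config.
cls_name = "TestVisionModelInference"
inference_models = [
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
]

params = []
for model in inference_models:
    # Vision test classes only get Vision models, and vice versa.
    if ("Vision" in cls_name and "Vision" in model) or (
        "Vision" not in cls_name and "Vision" not in model
    ):
        params.append(pytest.param(model, id=model))

print([p.id for p in params])  # only the Vision-Instruct model matches this class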
File 6 (memory conftest):

@@ -8,8 +8,8 @@ import pytest

 from ..conftest import (
     get_provider_fixture_overrides,
-    get_provider_fixtures_from_config,
-    try_load_config_file_cached,
+    get_provider_fixture_overrides_from_test_config,
+    get_test_config_for_api,
 )

 from ..inference.fixtures import INFERENCE_FIXTURES
@@ -69,21 +69,11 @@ def pytest_configure(config):


 def pytest_generate_tests(metafunc):
-    test_config = try_load_config_file_cached(metafunc.config)
-    provider_fixtures_config = (
-        test_config.memory.fixtures.provider_fixtures
-        if test_config is not None and test_config.memory is not None
-        else None
-    )
-    custom_fixtures = (
-        get_provider_fixtures_from_config(
-            provider_fixtures_config, DEFAULT_PROVIDER_COMBINATIONS
-        )
-        if provider_fixtures_config is not None
-        else None
-    )
+    test_config = get_test_config_for_api(metafunc.config, "memory")
     if "embedding_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--embedding-model")
+        model = getattr(test_config, "embedding_model", None)
+        # Fall back to the default if not specified by the config file
+        model = model or metafunc.config.getoption("--embedding-model")
         if model:
             params = [pytest.param(model, id="")]
         else:
@@ -97,7 +87,9 @@ def pytest_generate_tests(metafunc):
         "memory": MEMORY_FIXTURES,
     }
     combinations = (
-        custom_fixtures
+        get_provider_fixture_overrides_from_test_config(
+            metafunc.config, "memory", DEFAULT_PROVIDER_COMBINATIONS
+        )
         or get_provider_fixture_overrides(metafunc.config, available_fixtures)
         or DEFAULT_PROVIDER_COMBINATIONS
    )
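The memory conftest, like the agents one, now resolves provider combinations through a three-way fallback: config-file scenarios first, then any CLI provider overrides, then the defaults. A tiny sketch of that "first non-empty source wins" chain, with placeholder values in place of the real helper calls:

# Placeholders for the three sources; in the real conftest these are the two
# helper calls and DEFAULT_PROVIDER_COMBINATIONS.
from_config_file = None   # get_provider_fixture_overrides_from_test_config(...)
from_cli_options = None   # get_provider_fixture_overrides(...)
defaults = ["sentence_transformers/faiss", "ollama", "chroma"]

combinations = from_config_file or from_cli_options or defaults
print(combinations)  # falls through to the defaults when nothing overrides them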
File 7 (post-training test imports):

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import pytest

-from llama_stack.apis.common.type_system import JobStatus
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
     Checkpoint,
     DataConfig,