mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00
address comment
This commit is contained in:
parent
11b2fdd3d8
commit
901c10d444
5 changed files with 80 additions and 94 deletions
|
@ -6,14 +6,15 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ..conftest import get_provider_fixture_overrides
|
from ..conftest import (
|
||||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
get_provider_fixture_overrides,
|
||||||
from ..memory.fixtures import MEMORY_FIXTURES
|
|
||||||
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
|
||||||
from ..test_config_helper import (
|
|
||||||
get_provider_fixtures_from_config,
|
get_provider_fixtures_from_config,
|
||||||
try_load_config_file_cached,
|
try_load_config_file_cached,
|
||||||
)
|
)
|
||||||
|
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||||
|
from ..memory.fixtures import MEMORY_FIXTURES
|
||||||
|
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
||||||
|
|
||||||
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
|
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
|
||||||
from .fixtures import AGENTS_FIXTURES
|
from .fixtures import AGENTS_FIXTURES
|
||||||
|
|
||||||
|
@ -86,7 +87,7 @@ def pytest_configure(config):
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
|
test_config = try_load_config_file_cached(metafunc.config)
|
||||||
(
|
(
|
||||||
config_override_inference_models,
|
config_override_inference_models,
|
||||||
config_override_safety_shield,
|
config_override_safety_shield,
|
||||||
|
|
|
@ -6,13 +6,15 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Provider
|
from llama_stack.distribution.datatypes import Provider
|
||||||
|
@ -20,14 +22,74 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||||
|
|
||||||
from .env import get_env_or_fail
|
from .env import get_env_or_fail
|
||||||
|
|
||||||
from .test_config_helper import try_load_config_file_cached
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderFixture(BaseModel):
|
class ProviderFixture(BaseModel):
|
||||||
providers: List[Provider]
|
providers: List[Provider]
|
||||||
provider_data: Optional[Dict[str, Any]] = None
|
provider_data: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Fixtures(BaseModel):
|
||||||
|
# provider fixtures can be either a mark or a dictionary of api -> providers
|
||||||
|
provider_fixtures: List[Dict[str, str]] = Field(default_factory=list)
|
||||||
|
inference_models: List[str] = Field(default_factory=list)
|
||||||
|
safety_shield: Optional[str] = Field(default_factory=None)
|
||||||
|
embedding_model: Optional[str] = Field(default_factory=None)
|
||||||
|
|
||||||
|
|
||||||
|
class APITestConfig(BaseModel):
|
||||||
|
fixtures: Fixtures
|
||||||
|
|
||||||
|
# test name format should be <relative_path.py>::<test_name>
|
||||||
|
tests: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfig(BaseModel):
|
||||||
|
inference: APITestConfig
|
||||||
|
agent: Optional[APITestConfig] = Field(default=None)
|
||||||
|
memory: Optional[APITestConfig] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
CONFIG_CACHE = None
|
||||||
|
|
||||||
|
|
||||||
|
def try_load_config_file_cached(config):
|
||||||
|
config_file = config.getoption("--config")
|
||||||
|
if config_file is None:
|
||||||
|
return None
|
||||||
|
if CONFIG_CACHE is not None:
|
||||||
|
return CONFIG_CACHE
|
||||||
|
|
||||||
|
config_file_path = Path(__file__).parent / config_file
|
||||||
|
if not config_file_path.exists():
|
||||||
|
raise ValueError(
|
||||||
|
f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
|
||||||
|
)
|
||||||
|
with open(config_file_path, "r") as config_file:
|
||||||
|
config = yaml.safe_load(config_file)
|
||||||
|
return TestConfig(**config)
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_fixtures_from_config(
|
||||||
|
provider_fixtures_config, default_fixture_combination
|
||||||
|
):
|
||||||
|
custom_fixtures = []
|
||||||
|
selected_default_param_id = set()
|
||||||
|
for fixture_config in provider_fixtures_config:
|
||||||
|
if "default_fixture_param_id" in fixture_config:
|
||||||
|
selected_default_param_id.add(fixture_config["default_fixture_param_id"])
|
||||||
|
else:
|
||||||
|
custom_fixtures.append(
|
||||||
|
pytest.param(fixture_config, id=fixture_config.get("inference") or "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(selected_default_param_id) > 0:
|
||||||
|
for default_fixture in default_fixture_combination:
|
||||||
|
if default_fixture.id in selected_default_param_id:
|
||||||
|
custom_fixtures.append(default_fixture)
|
||||||
|
|
||||||
|
return custom_fixtures
|
||||||
|
|
||||||
|
|
||||||
def remote_stack_fixture() -> ProviderFixture:
|
def remote_stack_fixture() -> ProviderFixture:
|
||||||
if url := os.getenv("REMOTE_STACK_URL", None):
|
if url := os.getenv("REMOTE_STACK_URL", None):
|
||||||
config = RemoteProviderConfig.from_url(url)
|
config = RemoteProviderConfig.from_url(url)
|
||||||
|
@ -182,7 +244,7 @@ def pytest_itemcollected(item):
|
||||||
|
|
||||||
|
|
||||||
def pytest_collection_modifyitems(session, config, items):
|
def pytest_collection_modifyitems(session, config, items):
|
||||||
test_config = try_load_config_file_cached(config.getoption("--config"))
|
test_config = try_load_config_file_cached(config)
|
||||||
if test_config is None:
|
if test_config is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -6,8 +6,7 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ..conftest import get_provider_fixture_overrides
|
from ..conftest import get_provider_fixture_overrides, try_load_config_file_cached
|
||||||
from ..test_config_helper import try_load_config_file_cached
|
|
||||||
from .fixtures import INFERENCE_FIXTURES
|
from .fixtures import INFERENCE_FIXTURES
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,7 +42,7 @@ VISION_MODEL_PARAMS = [
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
|
test_config = try_load_config_file_cached(metafunc.config)
|
||||||
if "inference_model" in metafunc.fixturenames:
|
if "inference_model" in metafunc.fixturenames:
|
||||||
cls_name = metafunc.cls.__name__
|
cls_name = metafunc.cls.__name__
|
||||||
if test_config is not None:
|
if test_config is not None:
|
||||||
|
|
|
@ -6,13 +6,13 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ..conftest import get_provider_fixture_overrides
|
from ..conftest import (
|
||||||
|
get_provider_fixture_overrides,
|
||||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
|
||||||
from ..test_config_helper import (
|
|
||||||
get_provider_fixtures_from_config,
|
get_provider_fixtures_from_config,
|
||||||
try_load_config_file_cached,
|
try_load_config_file_cached,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||||
from .fixtures import MEMORY_FIXTURES
|
from .fixtures import MEMORY_FIXTURES
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ def pytest_configure(config):
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
|
test_config = try_load_config_file_cached(metafunc.config)
|
||||||
provider_fixtures_config = (
|
provider_fixtures_config = (
|
||||||
test_config.memory.fixtures.provider_fixtures
|
test_config.memory.fixtures.provider_fixtures
|
||||||
if test_config is not None and test_config.memory is not None
|
if test_config is not None and test_config.memory is not None
|
||||||
|
|
|
@ -1,76 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class APITestConfig(BaseModel):
|
|
||||||
|
|
||||||
class Fixtures(BaseModel):
|
|
||||||
# provider fixtures can be either a mark or a dictionary of api -> providers
|
|
||||||
provider_fixtures: List[Dict[str, str]] = Field(default_factory=list)
|
|
||||||
inference_models: List[str] = Field(default_factory=list)
|
|
||||||
safety_shield: Optional[str] = Field(default_factory=None)
|
|
||||||
embedding_model: Optional[str] = Field(default_factory=None)
|
|
||||||
|
|
||||||
fixtures: Fixtures
|
|
||||||
tests: List[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
# test name format should be <relative_path.py>::<test_name>
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfig(BaseModel):
|
|
||||||
|
|
||||||
inference: APITestConfig
|
|
||||||
agent: Optional[APITestConfig] = Field(default=None)
|
|
||||||
memory: Optional[APITestConfig] = Field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
CONFIG_CACHE = None
|
|
||||||
|
|
||||||
|
|
||||||
def try_load_config_file_cached(config_file):
|
|
||||||
if config_file is None:
|
|
||||||
return None
|
|
||||||
if CONFIG_CACHE is not None:
|
|
||||||
return CONFIG_CACHE
|
|
||||||
|
|
||||||
config_file_path = Path(__file__).parent / config_file
|
|
||||||
if not config_file_path.exists():
|
|
||||||
raise ValueError(
|
|
||||||
f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
|
|
||||||
)
|
|
||||||
with open(config_file_path, "r") as config_file:
|
|
||||||
config = yaml.safe_load(config_file)
|
|
||||||
return TestConfig(**config)
|
|
||||||
|
|
||||||
|
|
||||||
def get_provider_fixtures_from_config(
|
|
||||||
provider_fixtures_config, default_fixture_combination
|
|
||||||
):
|
|
||||||
custom_fixtures = []
|
|
||||||
selected_default_param_id = set()
|
|
||||||
for fixture_config in provider_fixtures_config:
|
|
||||||
if "default_fixture_param_id" in fixture_config:
|
|
||||||
selected_default_param_id.add(fixture_config["default_fixture_param_id"])
|
|
||||||
else:
|
|
||||||
custom_fixtures.append(
|
|
||||||
pytest.param(fixture_config, id=fixture_config.get("inference") or "")
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(selected_default_param_id) > 0:
|
|
||||||
for default_fixture in default_fixture_combination:
|
|
||||||
if default_fixture.id in selected_default_param_id:
|
|
||||||
custom_fixtures.append(default_fixture)
|
|
||||||
|
|
||||||
return custom_fixtures
|
|
Loading…
Add table
Add a link
Reference in a new issue