mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
# Context For test automation, the end goal is to run a single pytest command from the root test directory (llama_stack/providers/tests/.) such that we execute push-blocking tests. The work plan: 1) trigger pytest from llama_stack/providers/tests/. 2) use a config file to determine what tests and parametrization we want to run # What does this PR do? 1) consolidates the "inference-models" / "embedding-model" / "judge-model" ... options in root conftest.py. Without this change, we will hit an error when trying to run `pytest /Users/sxyi/llama-stack/llama_stack/providers/tests/.` because of duplicated `addoptions` definitions across child conftest files. 2) Add a `config` option to specify test config in YAML. (see [`ci_test_config.yaml`](https://gist.github.com/sixianyi0721/5b37fbce4069139445c2f06f6e42f87e) for example config file) For provider_fixtures, we allow users to use either a default fixture combination or define their own {api:provider} combinations. ``` memory: .... fixtures: provider_fixtures: - default_fixture_param_id: ollama // use default fixture combination with param_id="ollama" in [providers/tests/memory/conftest.py](https://fburl.com/mtjzwsmk) - inference: sentence_transformers memory: faiss - default_fixture_param_id: chroma ``` 3) generate tests according to the config. Logic lives in two places: a) in `{api}/conftest.py::pytest_generate_tests`, we read from config to do parametrization. b) after test collection, in `pytest_collection_modifyitems`, we filter the tests to include only functions listed in config. ## Test Plan 1) `pytest /Users/sxyi/llama-stack/llama_stack/providers/tests/. --collect-only --config=ci_test_config.yaml` Using the `--collect-only` flag to print the tests listed in the config file (`ci_test_config.yaml`). 
output: [gist](https://gist.github.com/sixianyi0721/05145e60d4d085c17cfb304beeb1e60e) 2) sanity check on `--inference-model` option ``` pytest -v -s -k "ollama" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py ``` ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
316 lines
9.3 KiB
Python
316 lines
9.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import os
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from llama_stack.apis.models import ModelInput, ModelType
|
|
from llama_stack.distribution.datatypes import Api, Provider
|
|
|
|
from llama_stack.providers.inline.inference.meta_reference import (
|
|
MetaReferenceInferenceConfig,
|
|
)
|
|
from llama_stack.providers.inline.inference.vllm import VLLMConfig
|
|
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
|
|
|
|
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
|
|
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
|
|
from llama_stack.providers.remote.inference.groq import GroqConfig
|
|
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
|
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
|
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
|
|
from llama_stack.providers.remote.inference.together import TogetherImplConfig
|
|
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
|
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
|
|
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
|
from ..env import get_env_or_fail
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_model(request):
    """Resolve the model (or models) the inference tests should run against.

    A value injected through fixture parametrization (``request.param``) wins.
    Otherwise the ``--inference-model`` command-line option is looked up, with
    ``None`` passed to ``getoption`` as the fallback.
    """
    if not hasattr(request, "param"):
        return request.config.getoption("--inference-model", None)
    return request.param
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_remote() -> ProviderFixture:
    """Return the provider fixture produced by ``remote_stack_fixture()``."""
    remote_fixture = remote_stack_fixture()
    return remote_fixture
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_meta_reference(inference_model) -> ProviderFixture:
    """Build one inline meta-reference provider per model under test.

    Args:
        inference_model: A single model name, or an iterable of model names.

    Returns:
        A ProviderFixture whose providers are named ``meta-reference-<index>``,
        one for each model, in order.
    """
    if os.getenv("EMBEDDING_DIMENSION"):
        # If embedding dimension is set, use the 8B model for testing
        models = ["meta-llama/Llama-3.1-8B-Instruct"]
    elif isinstance(inference_model, str):
        models = [inference_model]
    else:
        models = inference_model

    providers = []
    for index, model_name in enumerate(models):
        provider_config = MetaReferenceInferenceConfig(
            model=model_name,
            max_seq_len=4096,
            create_distributed_process_group=False,
            checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
        )
        providers.append(
            Provider(
                provider_id=f"meta-reference-{index}",
                provider_type="inline::meta-reference",
                config=provider_config.model_dump(),
            )
        )
    return ProviderFixture(providers=providers)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_cerebras() -> ProviderFixture:
    """Provider fixture for the ``remote::cerebras`` inference provider.

    The API key is obtained from ``CEREBRAS_API_KEY`` through
    ``get_env_or_fail``.
    """
    cerebras_config = CerebrasImplConfig(
        api_key=get_env_or_fail("CEREBRAS_API_KEY"),
    )
    cerebras_provider = Provider(
        provider_id="cerebras",
        provider_type="remote::cerebras",
        config=cerebras_config.model_dump(),
    )
    return ProviderFixture(providers=[cerebras_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_ollama(inference_model) -> ProviderFixture:
    """Provider fixture for the ``remote::ollama`` inference provider.

    The provider targets ``localhost`` on the port given by ``OLLAMA_PORT``
    (11434 when the variable is unset).

    Args:
        inference_model: A single model name, or a collection of model names.

    Returns:
        A ProviderFixture containing the single ``ollama`` provider. The test
        is skipped instead when ``Llama3.1-8B-Instruct`` is requested.
    """
    inference_model = (
        [inference_model] if isinstance(inference_model, str) else inference_model
    )
    # A single name has been wrapped in a list above, so for that case this is
    # an exact-name membership test rather than a substring match.
    if inference_model and "Llama3.1-8B-Instruct" in inference_model:
        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")

    # os.getenv returns a str when the variable is set but the int default
    # otherwise; normalise so the config always receives an int port.
    ollama_port = int(os.getenv("OLLAMA_PORT", 11434))

    return ProviderFixture(
        providers=[
            Provider(
                provider_id="ollama",
                provider_type="remote::ollama",
                config=OllamaImplConfig(
                    host="localhost", port=ollama_port
                ).model_dump(),
            )
        ],
    )
|
|
|
|
|
|
# This fixture is synchronous, so it uses the plain pytest decorator like the
# other provider fixtures here; pytest_asyncio.fixture is for coroutine and
# async-generator fixtures.
@pytest.fixture(scope="session")
def inference_vllm(inference_model) -> ProviderFixture:
    """Build one inline vLLM provider per model under test.

    Args:
        inference_model: A single model name, or an iterable of model names.

    Returns:
        A ProviderFixture whose providers are named ``vllm-<index>``, one for
        each model, in order.
    """
    inference_model = (
        [inference_model] if isinstance(inference_model, str) else inference_model
    )
    return ProviderFixture(
        providers=[
            Provider(
                provider_id=f"vllm-{i}",
                provider_type="inline::vllm",
                config=VLLMConfig(
                    model=m,
                    enforce_eager=True,  # Make test run faster
                ).model_dump(),
            )
            for i, m in enumerate(inference_model)
        ]
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_vllm_remote() -> ProviderFixture:
    """Provider fixture for the ``remote::vllm`` inference provider.

    The server URL is obtained from ``VLLM_URL`` through ``get_env_or_fail``;
    ``VLLM_MAX_TOKENS`` sets the token limit and defaults to 2048.
    """
    adapter_config = VLLMInferenceAdapterConfig(
        url=get_env_or_fail("VLLM_URL"),
        max_tokens=int(os.getenv("VLLM_MAX_TOKENS", 2048)),
    )
    vllm_provider = Provider(
        provider_id="remote::vllm",
        provider_type="remote::vllm",
        config=adapter_config.model_dump(),
    )
    return ProviderFixture(providers=[vllm_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_fireworks() -> ProviderFixture:
    """Provider fixture for the ``remote::fireworks`` inference provider.

    The API key is obtained from ``FIREWORKS_API_KEY`` through
    ``get_env_or_fail``.
    """
    fireworks_config = FireworksImplConfig(
        api_key=get_env_or_fail("FIREWORKS_API_KEY"),
    )
    fireworks_provider = Provider(
        provider_id="fireworks",
        provider_type="remote::fireworks",
        config=fireworks_config.model_dump(),
    )
    return ProviderFixture(providers=[fireworks_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_together() -> ProviderFixture:
    """Provider fixture for the ``remote::together`` inference provider.

    The provider uses a default ``TogetherImplConfig``. The API key from
    ``TOGETHER_API_KEY`` is passed separately as provider data under
    ``together_api_key``.
    """
    together_provider = Provider(
        provider_id="together",
        provider_type="remote::together",
        config=TogetherImplConfig().model_dump(),
    )
    # Read the key only after the provider is built, as the fixture always has.
    together_data = {"together_api_key": get_env_or_fail("TOGETHER_API_KEY")}
    return ProviderFixture(
        providers=[together_provider],
        provider_data=together_data,
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_groq() -> ProviderFixture:
    """Provider fixture for the ``remote::groq`` inference provider.

    The provider uses a default ``GroqConfig``. The API key from
    ``GROQ_API_KEY`` is passed separately as provider data under
    ``groq_api_key``.
    """
    groq_provider = Provider(
        provider_id="groq",
        provider_type="remote::groq",
        config=GroqConfig().model_dump(),
    )
    # Read the key only after the provider is built, as the fixture always has.
    groq_data = {"groq_api_key": get_env_or_fail("GROQ_API_KEY")}
    return ProviderFixture(
        providers=[groq_provider],
        provider_data=groq_data,
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_bedrock() -> ProviderFixture:
    """Provider fixture for ``remote::bedrock`` using a default ``BedrockConfig``."""
    bedrock_provider = Provider(
        provider_id="bedrock",
        provider_type="remote::bedrock",
        config=BedrockConfig().model_dump(),
    )
    return ProviderFixture(providers=[bedrock_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_nvidia() -> ProviderFixture:
    """Provider fixture for ``remote::nvidia`` using a default ``NVIDIAConfig``."""
    nvidia_provider = Provider(
        provider_id="nvidia",
        provider_type="remote::nvidia",
        config=NVIDIAConfig().model_dump(),
    )
    return ProviderFixture(providers=[nvidia_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_tgi() -> ProviderFixture:
    """Provider fixture for the ``remote::tgi`` inference provider.

    The endpoint URL is obtained from ``TGI_URL`` through ``get_env_or_fail``.
    ``TGI_API_TOKEN`` is optional and is ``None`` when unset.
    """
    tgi_config = TGIImplConfig(
        url=get_env_or_fail("TGI_URL"),
        api_token=os.getenv("TGI_API_TOKEN", None),
    )
    tgi_provider = Provider(
        provider_id="tgi",
        provider_type="remote::tgi",
        config=tgi_config.model_dump(),
    )
    return ProviderFixture(providers=[tgi_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_sentence_transformers() -> ProviderFixture:
    """Provider fixture for ``inline::sentence-transformers`` with an empty config."""
    sentence_transformers_provider = Provider(
        provider_id="sentence_transformers",
        provider_type="inline::sentence-transformers",
        config={},
    )
    return ProviderFixture(providers=[sentence_transformers_provider])
|
|
|
|
|
|
def get_model_short_name(model_name: str) -> str:
    """Convert model name to a short test identifier.

    Matching is case-insensitive and uses substrings, checked in the order
    "vision", "3b", "8b"; the first match wins. A name that matches none of
    them is returned lower-cased with "." and "-" replaced by "_".

    Args:
        model_name: Full model name like "Llama3.1-8B-Instruct"

    Returns:
        Short name like "llama_8b" suitable for test markers
    """
    lowered = model_name.lower()
    # Order matters: a vision model whose name also contains "3b" or "8b"
    # must still map to "llama_vision".
    known_markers = (
        ("vision", "llama_vision"),
        ("3b", "llama_3b"),
        ("8b", "llama_8b"),
    )
    for marker, short_name in known_markers:
        if marker in lowered:
            return short_name
    return lowered.replace(".", "_").replace("-", "_")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def model_id(inference_model) -> str:
    """Short identifier for the model under test, from ``get_model_short_name``."""
    short_name = get_model_short_name(inference_model)
    return short_name
|
|
|
|
|
|
# Names of the available inference provider fixtures. Each entry ``name``
# corresponds to an ``inference_<name>`` fixture defined in this module.
INFERENCE_FIXTURES = [
    "meta_reference", "ollama", "fireworks", "together",
    "vllm", "groq", "vllm_remote", "remote",
    "bedrock", "cerebras", "nvidia", "tgi",
]
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session")
async def inference_stack(request, inference_model):
    """Construct a test stack exposing the inference API for one provider fixture.

    ``request.param`` selects the provider fixture: the fixture named
    ``inference_<param>`` is resolved dynamically. The model under test is
    registered against the first provider of that fixture. When
    ``EMBEDDING_DIMENSION`` is set, the model is registered as an embedding
    model with that dimension in its metadata; otherwise it is an LLM with
    empty metadata.

    Yields:
        A ``(inference_impl, models_impl)`` tuple taken from the stack's impls.
    """
    provider_fixture = request.getfixturevalue(f"inference_{request.param}")

    if os.getenv("EMBEDDING_DIMENSION"):
        registered_type = ModelType.embedding
        model_metadata = {
            "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
        }
    else:
        registered_type = ModelType.llm
        model_metadata = {}

    stack = await construct_stack_for_test(
        [Api.inference],
        {"inference": provider_fixture.providers},
        provider_fixture.provider_data,
        models=[
            ModelInput(
                provider_id=provider_fixture.providers[0].provider_id,
                model_id=inference_model,
                model_type=registered_type,
                metadata=model_metadata,
            )
        ],
    )

    # Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended
    yield stack.impls[Api.inference], stack.impls[Api.models]

    # Teardown: runs once the consumers of this fixture are finished.
    await stack.impls[Api.inference].shutdown()
|