diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index 0afc894cf..581a0d428 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import itertools +import os import pytest import pytest_asyncio @@ -50,14 +51,17 @@ def get_expected_stop_reason(model: str): return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn +if "MODEL_IDS" not in os.environ: + MODEL_IDS = [Llama_8B, Llama_3B] +else: + MODEL_IDS = os.environ["MODEL_IDS"].split(",") + + # This is going to create multiple Stack impls without tearing down the previous one # Fix that! @pytest_asyncio.fixture( scope="session", - params=[ - {"model": Llama_8B}, - {"model": Llama_3B}, - ], + params=[{"model": m} for m in MODEL_IDS], ids=lambda d: d["model"], ) async def inference_settings(request):