From 9fcf5d58e0aefea19700344424745d45c08e1ddf Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 17 Oct 2024 10:03:27 -0700 Subject: [PATCH] Allow overriding MODEL_IDS for inference test --- .../providers/tests/inference/test_inference.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index 0afc894cf..581a0d428 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import itertools +import os import pytest import pytest_asyncio @@ -50,14 +51,17 @@ def get_expected_stop_reason(model: str): return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn +if "MODEL_IDS" not in os.environ: + MODEL_IDS = [Llama_8B, Llama_3B] +else: + MODEL_IDS = os.environ["MODEL_IDS"].split(",") + + # This is going to create multiple Stack impls without tearing down the previous one # Fix that! @pytest_asyncio.fixture( scope="session", - params=[ - {"model": Llama_8B}, - {"model": Llama_3B}, - ], + params=[{"model": m} for m in MODEL_IDS], ids=lambda d: d["model"], ) async def inference_settings(request):