diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index e0d36124f..ed0b0302d 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -47,6 +47,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
+    # If embedding dimension is set, use the 8B model for testing
+    if os.getenv("EMBEDDING_DIMENSION"):
+        inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
 
     return ProviderFixture(
         providers=[