diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index a90e5afc1..da075f47f 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -82,18 +82,17 @@ async def agents_stack(request, inference_model, safety_shield): inference_model if isinstance(inference_model, list) else [inference_model] ) - inference_provider = providers["inference"][0] - provider_id = inference_provider.provider_id - if inference_provider.config and "model" in inference_provider.config: - model_to_provider_id = { - provider.config["model"]: provider.provider_id - for provider in providers["inference"] - } + model_to_provider_id = {} + for provider in providers["inference"]: + if provider.config and "model" in provider.config: + model_to_provider_id[provider.config["model"]] = provider.provider_id models = [] for model in inference_models: if model in model_to_provider_id: provider_id = model_to_provider_id[model] + else: + provider_id = providers["inference"][0].provider_id models.append( ModelInput(