diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 13c250439..9f8e7a12b 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -81,14 +81,28 @@ async def agents_stack(request, inference_model, safety_shield): inference_models = ( inference_model if isinstance(inference_model, list) else [inference_model] ) - models = [ - ModelInput( - model_id=model, - model_type=ModelType.llm, - provider_id=providers["inference"][0].provider_id, + + # NOTE: meta-reference provider needs 1 provider per model, lookup provider_id from provider config + model_to_provider_id = {} + for provider in providers["inference"]: + if "model" in provider.config: + model_to_provider_id[provider.config["model"]] = provider.provider_id + + models = [] + for model in inference_models: + if model in model_to_provider_id: + provider_id = model_to_provider_id[model] + else: + provider_id = providers["inference"][0].provider_id + + models.append( + ModelInput( + model_id=model, + model_type=ModelType.llm, + provider_id=provider_id, + ) ) - for model in inference_models - ] + models.append( ModelInput( model_id="all-MiniLM-L6-v2",