From ef1dd433517a499231a65f9eaf0dd121f0a451b7 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 30 Dec 2024 15:02:59 -0800 Subject: [PATCH] fix provider lookup --- .../providers/tests/agents/fixtures.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 23175e544..6d92bcf6c 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -82,17 +82,28 @@ async def agents_stack(request, inference_model, safety_shield): inference_models = ( inference_model if isinstance(inference_model, list) else [inference_model] ) - print(providers) - print(inference_models, safety_shield) - models = [ - ModelInput( - model_id=model, - model_type=ModelType.llm, - provider_id=providers["inference"][i].provider_id, + inference_provider = providers["inference"][0] + provider_id = inference_provider.provider_id + if inference_provider.config and "model" in inference_provider.config: + model_to_provider_id = { + provider.config.model: provider.provider_id + for provider in providers["inference"] + } + + models = [] + for model in inference_models: + if model in model_to_provider_id: + provider_id = model_to_provider_id[model] + + models.append( + ModelInput( + model_id=model, + model_type=ModelType.llm, + provider_id=provider_id, + ) ) - for i, model in enumerate(inference_models) - ] + models.append( ModelInput( model_id="all-MiniLM-L6-v2",