diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py index 14095b526..ee6999043 100644 --- a/llama_stack/providers/tests/scoring/fixtures.py +++ b/llama_stack/providers/tests/scoring/fixtures.py @@ -7,6 +7,8 @@ import pytest import pytest_asyncio +from llama_stack.apis.models import ModelInput + from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 @@ -76,20 +78,14 @@ async def scoring_stack(request, inference_model): [Api.scoring, Api.datasetio, Api.inference], providers, provider_data, - ) - - provider_id = providers["inference"][0].provider_id - await impls[Api.models].register_model( - model_id=inference_model, - provider_id=provider_id, - ) - await impls[Api.models].register_model( - model_id="Llama3.1-405B-Instruct", - provider_id=provider_id, - ) - await impls[Api.models].register_model( - model_id="Llama3.1-8B-Instruct", - provider_id=provider_id, + models=[ + ModelInput(model_id=model) + for model in [ + inference_model, + "Llama3.1-405B-Instruct", + "Llama3.1-8B-Instruct", + ] + ], ) return impls