diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py
index 7c41b07ef..4a5c6a259 100644
--- a/llama_stack/providers/tests/inference/test_model_registration.py
+++ b/llama_stack/providers/tests/inference/test_model_registration.py
@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from unittest.mock import AsyncMock, patch
-
 import pytest
 
 # How to run this test:
@@ -15,6 +13,9 @@ import pytest
 
 
 class TestModelRegistration:
+    def provider_supports_custom_names(self, provider) -> bool:
+        return "remote::ollama" not in provider.__provider_spec__.provider_type
+
     @pytest.mark.asyncio
     async def test_register_unsupported_model(self, inference_stack, inference_model):
         inference_impl, models_impl = inference_stack
@@ -47,7 +48,12 @@ class TestModelRegistration:
         )
 
     @pytest.mark.asyncio
-    async def test_register_with_llama_model(self, inference_stack):
+    async def test_register_with_llama_model(self, inference_stack, inference_model):
+        inference_impl, models_impl = inference_stack
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if not self.provider_supports_custom_names(provider):
+            pytest.skip("Provider does not support custom model names")
+
         _, models_impl = inference_stack
 
         _ = await models_impl.register_model(
@@ -67,22 +73,6 @@ class TestModelRegistration:
             provider_model_id="custom-model",
         )
 
-    @pytest.mark.asyncio
-    async def test_initialize_model_during_registering(self, inference_stack):
-        _, models_impl = inference_stack
-
-        with patch(
-            "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
-            new_callable=AsyncMock,
-        ) as mock_load_model:
-            _ = await models_impl.register_model(
-                model_id="Llama3.1-8B-Instruct",
-                metadata={
-                    "llama_model": "meta-llama/Llama-3.1-8B-Instruct",
-                },
-            )
-            mock_load_model.assert_called_once()
-
     @pytest.mark.asyncio
     async def test_register_with_invalid_llama_model(self, inference_stack):
         _, models_impl = inference_stack
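
A minimal standalone sketch of the capability-gating pattern this diff introduces, for illustration only. The _FakeSpec and _FakeProvider classes are hypothetical stand-ins for the real resolved provider objects, which (per the diff) expose a __provider_spec__.provider_type string; "remote::vllm" is used here purely as an example of a non-Ollama provider type.

import pytest

class _FakeSpec:
    # Hypothetical stand-in for a provider spec carrying provider_type.
    def __init__(self, provider_type: str):
        self.provider_type = provider_type

class _FakeProvider:
    # Hypothetical stand-in for a resolved provider implementation.
    def __init__(self, provider_type: str):
        self.__provider_spec__ = _FakeSpec(provider_type)

def provider_supports_custom_names(provider) -> bool:
    # Mirrors the helper added above: Ollama resolves model names
    # itself, so custom provider_model_id values are not honored there
    # and tests that depend on them should be skipped.
    return "remote::ollama" not in provider.__provider_spec__.provider_type

def test_custom_name_gate():
    # Non-Ollama providers pass the gate; Ollama-style providers do not.
    assert provider_supports_custom_names(_FakeProvider("remote::vllm"))
    assert not provider_supports_custom_names(_FakeProvider("remote::ollama"))

In the real test above, the same check drives pytest.skip("Provider does not support custom model names") before the custom-name registration is attempted, rather than failing the test outright.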