From 897ee1ffcb90e3ed5f45ed43e6d31f18cd589540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 10 Feb 2025 17:40:23 +0100 Subject: [PATCH] test: skip model registration for unsupported providers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Updated `test_register_with_llama_model` to skip tests when using the Ollama provider, as it does not support custom model names. - Delete `test_initialize_model_during_registering` since there is no "load_model" semantic that is exposed publicly on a provider. These changes ensure that tests do not fail for providers with incompatible behaviors. Signed-off-by: Sébastien Han --- .../inference/test_model_registration.py | 28 ++++++------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 7c41b07ef..4a5c6a259 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -4,8 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from unittest.mock import AsyncMock, patch - import pytest # How to run this test: @@ -15,6 +13,9 @@ import pytest class TestModelRegistration: + def provider_supports_custom_names(self, provider) -> bool: + return "remote::ollama" not in provider.__provider_spec__.provider_type + @pytest.mark.asyncio async def test_register_unsupported_model(self, inference_stack, inference_model): inference_impl, models_impl = inference_stack @@ -47,7 +48,12 @@ class TestModelRegistration: ) @pytest.mark.asyncio - async def test_register_with_llama_model(self, inference_stack): + async def test_register_with_llama_model(self, inference_stack, inference_model): + inference_impl, models_impl = inference_stack + provider = inference_impl.routing_table.get_provider_impl(inference_model) + if not self.provider_supports_custom_names(provider): + pytest.skip("Provider does not support custom model names") + _, models_impl = inference_stack _ = await models_impl.register_model( @@ -67,22 +73,6 @@ class TestModelRegistration: provider_model_id="custom-model", ) - @pytest.mark.asyncio - async def test_initialize_model_during_registering(self, inference_stack): - _, models_impl = inference_stack - - with patch( - "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model", - new_callable=AsyncMock, - ) as mock_load_model: - _ = await models_impl.register_model( - model_id="Llama3.1-8B-Instruct", - metadata={ - "llama_model": "meta-llama/Llama-3.1-8B-Instruct", - }, - ) - mock_load_model.assert_called_once() - @pytest.mark.asyncio async def test_register_with_invalid_llama_model(self, inference_stack): _, models_impl = inference_stack