From 4694780d23e7b873ba2d519a371d6b3d44f437b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 20 Feb 2025 07:39:13 +0100 Subject: [PATCH] test: skip model registration for unsupported providers (#1030) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? - Updated `test_register_with_llama_model` to skip tests when using the Ollama provider, as it does not support custom model names. - Delete `test_initialize_model_during_registering` since there is no "load_model" semantic that is exposed publicly on a provider. These changes ensure that tests do not fail for providers with incompatible behaviors. Signed-off-by: Sébastien Han [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan Run Ollama: ``` uv run pytest -v -s -k "ollama" llama_stack/providers/tests/inference/test_model_registration.py /Users/leseb/Documents/AI/llama-stack/.venv/lib/python3.13/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ========================================== test session starts ========================================== platform darwin -- Python 3.13.1, pytest-8.3.4, pluggy-1.5.0 -- /Users/leseb/Documents/AI/llama-stack/.venv/bin/python3 cachedir: .pytest_cache metadata: {'Python': '3.13.1', 'Platform': 'macOS-15.3-arm64-arm-64bit-Mach-O', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0', 'nbval': '0.11.0'}} rootdir: /Users/leseb/Documents/AI/llama-stack configfile: pyproject.toml plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0, nbval-0.11.0 asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None collected 65 items / 60 deselected / 5 selected llama_stack/providers/tests/inference/test_model_registration.py::TestModelRegistration::test_register_unsupported_model[-ollama] PASSED llama_stack/providers/tests/inference/test_model_registration.py::TestModelRegistration::test_register_nonexistent_model[-ollama] PASSED llama_stack/providers/tests/inference/test_model_registration.py::TestModelRegistration::test_register_with_llama_model[-ollama] SKIPPED llama_stack/providers/tests/inference/test_model_registration.py::TestModelRegistration::test_register_with_invalid_llama_model[-ollama] PASSED ======================== 3 passed, 1 skipped, 60 deselected, 2 warnings in 0.22s ======================== ``` [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant) Signed-off-by: Sébastien Han --- .../inference/test_model_registration.py | 28 ++++++------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 7c41b07ef..4a5c6a259 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -4,8 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from unittest.mock import AsyncMock, patch - import pytest # How to run this test: @@ -15,6 +13,9 @@ import pytest class TestModelRegistration: + def provider_supports_custom_names(self, provider) -> bool: + return "remote::ollama" not in provider.__provider_spec__.provider_type + @pytest.mark.asyncio async def test_register_unsupported_model(self, inference_stack, inference_model): inference_impl, models_impl = inference_stack @@ -47,7 +48,12 @@ class TestModelRegistration: ) @pytest.mark.asyncio - async def test_register_with_llama_model(self, inference_stack): + async def test_register_with_llama_model(self, inference_stack, inference_model): + inference_impl, models_impl = inference_stack + provider = inference_impl.routing_table.get_provider_impl(inference_model) + if not self.provider_supports_custom_names(provider): + pytest.skip("Provider does not support custom model names") + _, models_impl = inference_stack _ = await models_impl.register_model( @@ -67,22 +73,6 @@ class TestModelRegistration: provider_model_id="custom-model", ) - @pytest.mark.asyncio - async def test_initialize_model_during_registering(self, inference_stack): - _, models_impl = inference_stack - - with patch( - "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model", - new_callable=AsyncMock, - ) as mock_load_model: - _ = await models_impl.register_model( - model_id="Llama3.1-8B-Instruct", - metadata={ - "llama_model": "meta-llama/Llama-3.1-8B-Instruct", - }, - ) - mock_load_model.assert_called_once() - @pytest.mark.asyncio async def test_register_with_invalid_llama_model(self, inference_stack): _, models_impl = inference_stack