test: skip model registration for unsupported providers

- Update `test_register_with_llama_model` to skip the test when the
  Ollama provider is in use, since it does not support custom model names.
- Delete `test_initialize_model_during_registering`, since providers do
  not publicly expose a `load_model` semantic.

These changes ensure that tests do not fail for providers with
incompatible behaviors.

Signed-off-by: Sébastien Han <seb@redhat.com>
commit 897ee1ffcb (parent 00613d9014)
Author: Sébastien Han <seb@redhat.com>
Date:   2025-02-10 17:40:23 +01:00


@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from unittest.mock import AsyncMock, patch
-
 import pytest
 
 # How to run this test:
@@ -15,6 +13,9 @@ import pytest
 
 
 class TestModelRegistration:
+    def provider_supports_custom_names(self, provider) -> bool:
+        return "remote::ollama" not in provider.__provider_spec__.provider_type
+
     @pytest.mark.asyncio
     async def test_register_unsupported_model(self, inference_stack, inference_model):
         inference_impl, models_impl = inference_stack
@@ -47,7 +48,12 @@ class TestModelRegistration:
             )
 
     @pytest.mark.asyncio
-    async def test_register_with_llama_model(self, inference_stack):
+    async def test_register_with_llama_model(self, inference_stack, inference_model):
+        inference_impl, models_impl = inference_stack
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if not self.provider_supports_custom_names(provider):
+            pytest.skip("Provider does not support custom model names")
+
         _, models_impl = inference_stack
 
         _ = await models_impl.register_model(
@@ -67,22 +73,6 @@ class TestModelRegistration:
             provider_model_id="custom-model",
         )
 
-    @pytest.mark.asyncio
-    async def test_initialize_model_during_registering(self, inference_stack):
-        _, models_impl = inference_stack
-
-        with patch(
-            "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
-            new_callable=AsyncMock,
-        ) as mock_load_model:
-            _ = await models_impl.register_model(
-                model_id="Llama3.1-8B-Instruct",
-                metadata={
-                    "llama_model": "meta-llama/Llama-3.1-8B-Instruct",
-                },
-            )
-            mock_load_model.assert_called_once()
-
     @pytest.mark.asyncio
     async def test_register_with_invalid_llama_model(self, inference_stack):
         _, models_impl = inference_stack
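
For reference, a minimal self-contained sketch of the capability-based skip pattern this commit introduces. The `provider_supports_custom_names` check mirrors the helper added above; the fake provider objects and the simplified test are illustrative stand-ins for the real `inference_stack`/`inference_model` fixtures, not part of the suite.

```python
import pytest


class _FakeProviderSpec:
    """Stand-in for a provider spec; real specs carry more metadata."""

    def __init__(self, provider_type: str):
        self.provider_type = provider_type


class _FakeProvider:
    """Fake provider exposing the __provider_spec__ attribute the helper inspects."""

    def __init__(self, provider_type: str):
        self.__provider_spec__ = _FakeProviderSpec(provider_type)


def provider_supports_custom_names(provider) -> bool:
    # Same check as the helper added in this commit: Ollama cannot register
    # models under arbitrary (custom) names, so such tests must be skipped.
    return "remote::ollama" not in provider.__provider_spec__.provider_type


@pytest.mark.parametrize("provider_type", ["remote::ollama", "inline::meta-reference"])
def test_register_with_custom_name(provider_type):
    provider = _FakeProvider(provider_type)
    if not provider_supports_custom_names(provider):
        pytest.skip("Provider does not support custom model names")
    # ...model registration with a custom provider_model_id would go here...
    assert provider_supports_custom_names(provider)
```

With this pattern, providers that support custom names proceed to the registration path, while Ollama-backed runs report the test as skipped rather than failed.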