mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-11 04:28:02 +00:00)
test: skip model registration for unsupported providers
- Updated `test_register_with_llama_model` to skip when using the Ollama provider, as it does not support custom model names.
- Deleted `test_initialize_model_during_registering`, since there is no "load_model" semantic exposed publicly on a provider.

These changes ensure that tests do not fail for providers with incompatible behaviors.

Signed-off-by: Sébastien Han <seb@redhat.com>
parent 00613d9014
commit 897ee1ffcb
1 changed file with 9 additions and 19 deletions
@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from unittest.mock import AsyncMock, patch
-
 import pytest
 
 # How to run this test:
@@ -15,6 +13,9 @@ import pytest
 
 
 class TestModelRegistration:
+    def provider_supports_custom_names(self, provider) -> bool:
+        return "remote::ollama" not in provider.__provider_spec__.provider_type
+
     @pytest.mark.asyncio
     async def test_register_unsupported_model(self, inference_stack, inference_model):
         inference_impl, models_impl = inference_stack
@@ -47,7 +48,12 @@ class TestModelRegistration:
         )
 
     @pytest.mark.asyncio
-    async def test_register_with_llama_model(self, inference_stack):
+    async def test_register_with_llama_model(self, inference_stack, inference_model):
+        inference_impl, models_impl = inference_stack
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if not self.provider_supports_custom_names(provider):
+            pytest.skip("Provider does not support custom model names")
+
         _, models_impl = inference_stack
 
         _ = await models_impl.register_model(
@@ -67,22 +73,6 @@ class TestModelRegistration:
             provider_model_id="custom-model",
         )
 
-    @pytest.mark.asyncio
-    async def test_initialize_model_during_registering(self, inference_stack):
-        _, models_impl = inference_stack
-
-        with patch(
-            "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
-            new_callable=AsyncMock,
-        ) as mock_load_model:
-            _ = await models_impl.register_model(
-                model_id="Llama3.1-8B-Instruct",
-                metadata={
-                    "llama_model": "meta-llama/Llama-3.1-8B-Instruct",
-                },
-            )
-            mock_load_model.assert_called_once()
-
     @pytest.mark.asyncio
     async def test_register_with_invalid_llama_model(self, inference_stack):
         _, models_impl = inference_stack
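For reference, here is a minimal, self-contained sketch of the skip pattern this change introduces. The `FakeProviderSpec` and `FakeProvider` classes are hypothetical stand-ins for the object returned by `inference_impl.routing_table.get_provider_impl()`; only the helper logic and the skip message mirror the diff above.

```python
# Minimal sketch of the capability-check-and-skip pattern from this commit.
# FakeProviderSpec/FakeProvider are stand-ins (not real llama-stack classes)
# for the provider object the routing table would normally return.
import pytest


class FakeProviderSpec:
    def __init__(self, provider_type: str):
        self.provider_type = provider_type


class FakeProvider:
    def __init__(self, provider_type: str):
        self.__provider_spec__ = FakeProviderSpec(provider_type)


def provider_supports_custom_names(provider) -> bool:
    # Same check as the helper added in the diff: Ollama cannot register
    # arbitrary custom model names, so it is treated as unsupported.
    return "remote::ollama" not in provider.__provider_spec__.provider_type


def test_register_with_custom_name_skips_on_ollama():
    provider = FakeProvider("remote::ollama")
    if not provider_supports_custom_names(provider):
        pytest.skip("Provider does not support custom model names")
    # ...registration with a custom provider_model_id would follow here...
```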