llama-stack/llama_stack/providers/tests/inference/test_model_registration.py
Sébastien Han 840344975d
test: rm unused exception alias in pytest.raises (#991)
# What does this PR do?

Refactor the tests to remove the unused exception alias (`as exc_info`)
from `pytest.raises` calls, since `exc_info` was never used. This improves
code clarity and reduces lint warnings.

Signed-off-by: Sébastien Han <seb@redhat.com>

## Test Plan

Please describe:
 - tests you ran to verify your changes with result summaries.
 - provide instructions so it can be reproduced.


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-02-07 08:04:25 -08:00

95 lines
3.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, patch
import pytest
# How to run this test (a single command; the test-file path is its last argument):
#
# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct" \
#   ./llama_stack/providers/tests/inference/test_model_registration.py
class TestModelRegistration:
    """Exercise ``register_model`` through the ``inference_stack`` fixture.

    Every test unpacks ``inference_stack`` into a pair of (inference router,
    models API implementation) and calls ``register_model`` on the latter.
    """

    @pytest.mark.asyncio
    async def test_register_unsupported_model(self, inference_stack, inference_model):
        """Expect ``ValueError`` when registering a 70B model on the listed provider types.

        The provider serving ``inference_model`` is looked up first; any
        provider type outside the listed set causes the test to be skipped.
        """
        router, models = inference_stack
        provider_type = router.routing_table.get_provider_impl(
            inference_model
        ).__provider_spec__.provider_type

        # Provider types for which this test asserts that the 70B model is rejected.
        rejecting_types = (
            "meta-reference",
            "remote::ollama",
            "remote::vllm",
            "remote::tgi",
        )
        if provider_type not in rejecting_types:
            pytest.skip(
                "Skipping test for remote inference providers since they can handle large models like 70B instruct"
            )

        # The 70B model is treated here as too large for the providers above.
        with pytest.raises(ValueError):
            await models.register_model(model_id="Llama3.1-70B-Instruct")

    @pytest.mark.asyncio
    async def test_register_nonexistent_model(self, inference_stack):
        """Expect ``ValueError`` when registering a model id that does not exist."""
        _router, models = inference_stack

        with pytest.raises(ValueError):
            await models.register_model(model_id="Llama3-NonExistent-Model")

    @pytest.mark.asyncio
    async def test_register_with_llama_model(self, inference_stack):
        """Register a custom model, then expect a dependent registration to fail.

        The first registration uses ``skip_load``. The second registration
        sets ``provider_model_id`` to the first model's id and omits
        ``skip_load``; it must raise ``ValueError``.
        """
        _router, models = inference_stack
        llama_model = "meta-llama/Llama-2-7b"

        await models.register_model(
            model_id="custom-model",
            metadata={"llama_model": llama_model, "skip_load": True},
        )

        # NOTE(review): the exact reason register_model rejects this call is
        # decided by the models implementation, which is not visible here.
        with pytest.raises(ValueError):
            await models.register_model(
                model_id="custom-model-2",
                metadata={"llama_model": llama_model},
                provider_model_id="custom-model",
            )

    @pytest.mark.asyncio
    async def test_initialize_model_during_registering(self, inference_stack):
        """Check that registering a model calls the meta-reference ``load_model`` exactly once.

        ``MetaReferenceInferenceImpl.load_model`` is replaced with an
        ``AsyncMock`` for the duration of the registration.
        """
        _router, models = inference_stack
        llama_metadata = {"llama_model": "meta-llama/Llama-3.1-8B-Instruct"}

        # NOTE(review): unlike test_register_unsupported_model there is no
        # provider-type guard here; with a non-meta-reference provider the
        # patched method may never be invoked — confirm the intended scope.
        with patch(
            "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
            new_callable=AsyncMock,
        ) as load_model_mock:
            await models.register_model(
                model_id="Llama3.1-8B-Instruct",
                metadata=llama_metadata,
            )
            load_model_mock.assert_called_once()

    @pytest.mark.asyncio
    async def test_register_with_invalid_llama_model(self, inference_stack):
        """Expect ``ValueError`` when ``metadata["llama_model"]`` names an invalid model."""
        _router, models = inference_stack

        with pytest.raises(ValueError):
            await models.register_model(
                model_id="custom-model-2",
                metadata={"llama_model": "invalid-llama-model"},
            )