forked from phoenix-oss/llama-stack-mirror
# What does this PR do? - Updated `test_register_with_llama_model` to skip tests when using the Ollama provider, as it does not support custom model names. - Deleted `test_initialize_model_during_registering` since there is no "load_model" semantic that is exposed publicly on a provider. These changes ensure that tests do not fail for providers with incompatible behaviors. Signed-off-by: Sébastien Han <seb@redhat.com> [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan Run Ollama: ``` uv run pytest -v -s -k "ollama" llama_stack/providers/tests/inference/test_model_registration.py /Users/leseb/Documents/AI/llama-stack/.venv/lib/python3.13/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. 
Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ========================================== test session starts ========================================== platform darwin -- Python 3.13.1, pytest-8.3.4, pluggy-1.5.0 -- /Users/leseb/Documents/AI/llama-stack/.venv/bin/python3 cachedir: .pytest_cache metadata: {'Python': '3.13.1', 'Platform': 'macOS-15.3-arm64-arm-64bit-Mach-O', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0', 'nbval': '0.11.0'}} rootdir: /Users/leseb/Documents/AI/llama-stack configfile: pyproject.toml plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0, nbval-0.11.0 asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None collected 65 items / 60 deselected / 5 selected llama_stack/providers/tests/inference/test_model_registration.py::TestModelRegistration::test_register_unsupported_model[-ollama] PASSED llama_stack/providers/tests/inference/test_model_registration.py::TestModelRegistration::test_register_nonexistent_model[-ollama] PASSED llama_stack/providers/tests/inference/test_model_registration.py::TestModelRegistration::test_register_with_llama_model[-ollama] SKIPPED llama_stack/providers/tests/inference/test_model_registration.py::TestModelRegistration::test_register_with_invalid_llama_model[-ollama] PASSED ======================== 3 passed, 1 skipped, 60 deselected, 2 warnings in 0.22s ======================== ``` [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant) Signed-off-by: Sébastien Han <seb@redhat.com>
84 lines
3 KiB
Python
84 lines
3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import pytest
|
|
|
|
# How to run this test:
|
|
#
|
|
# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct"
|
|
# ./llama_stack/providers/tests/inference/test_model_registration.py
|
|
|
|
|
|
class TestModelRegistration:
    """Tests for ``models_impl.register_model`` against the configured inference provider.

    Every test receives the ``inference_stack`` fixture and unpacks it as
    ``(inference_impl, models_impl)``; the fixtures themselves are defined
    outside this file.
    """

    def provider_supports_custom_names(self, provider) -> bool:
        """Return whether *provider* accepts registering models under custom names.

        Returns False when the provider's ``provider_type`` contains
        ``"remote::ollama"`` (Ollama does not support custom model names),
        True otherwise.
        """
        return "remote::ollama" not in provider.__provider_spec__.provider_type

    @pytest.mark.asyncio
    async def test_register_unsupported_model(self, inference_stack, inference_model):
        """Registering ``Llama3.1-70B-Instruct`` must raise ValueError.

        Only runs for the provider types listed below; all other providers
        are skipped.
        """
        inference_impl, models_impl = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        # NOTE(review): if the provider registry names the meta-reference
        # provider "inline::meta-reference", the "meta-reference" entry below
        # never matches and that provider is silently skipped — confirm.
        if provider.__provider_spec__.provider_type not in (
            "meta-reference",
            "remote::ollama",
            "remote::vllm",
            "remote::tgi",
        ):
            pytest.skip(
                "Skipping test for remote inference providers since they can handle large models like 70B instruct"
            )

        # Try to register a model that's too large for local inference
        with pytest.raises(ValueError):
            await models_impl.register_model(
                model_id="Llama3.1-70B-Instruct",
            )

    @pytest.mark.asyncio
    async def test_register_nonexistent_model(self, inference_stack):
        """Registering an unknown model id must raise ValueError."""
        _, models_impl = inference_stack

        # Try to register a non-existent model
        with pytest.raises(ValueError):
            await models_impl.register_model(
                model_id="Llama3-NonExistent-Model",
            )

    @pytest.mark.asyncio
    async def test_register_with_llama_model(self, inference_stack, inference_model):
        """Register a custom model id, then check a second registration is rejected.

        Skipped for providers that do not support custom model names.
        """
        inference_impl, models_impl = inference_stack
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if not self.provider_supports_custom_names(provider):
            pytest.skip("Provider does not support custom model names")

        # The first registration is expected to succeed; its result is not inspected.
        await models_impl.register_model(
            model_id="custom-model",
            metadata={
                "llama_model": "meta-llama/Llama-2-7b",
                "skip_load": True,
            },
        )

        # Registering a second id with provider_model_id="custom-model" is
        # expected to raise ValueError.
        with pytest.raises(ValueError):
            await models_impl.register_model(
                model_id="custom-model-2",
                metadata={
                    "llama_model": "meta-llama/Llama-2-7b",
                },
                provider_model_id="custom-model",
            )

    @pytest.mark.asyncio
    async def test_register_with_invalid_llama_model(self, inference_stack):
        """Registering with an unrecognized ``llama_model`` in metadata must raise ValueError."""
        _, models_impl = inference_stack

        with pytest.raises(ValueError):
            await models_impl.register_model(
                model_id="custom-model-2",
                metadata={"llama_model": "invalid-llama-model"},
            )