diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index ae24602d7..5aac113eb 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -106,6 +106,40 @@ def api_directories(tmp_path): return remote_inference_dir, inline_inference_dir +def make_import_module_side_effect( + builtin_provider_spec=None, + external_module=None, + raise_for_external=False, + missing_get_provider_spec=False, +): + from types import SimpleNamespace + + def import_module_side_effect(name): + if name == "llama_stack.providers.registry.inference": + mock_builtin = SimpleNamespace( + available_providers=lambda: [ + builtin_provider_spec + or ProviderSpec( + api=Api.inference, + provider_type="test_provider", + config_class="test_provider.config.TestProviderConfig", + module="test_provider", + ) + ] + ) + return mock_builtin + elif name == "external_test.provider": + if raise_for_external: + raise ModuleNotFoundError(name) + if missing_get_provider_spec: + return SimpleNamespace() + return external_module + else: + raise ModuleNotFoundError(name) + + return import_module_side_effect + + class TestProviderRegistry: """Test suite for provider registry functionality.""" @@ -221,3 +255,124 @@ pip_packages: with pytest.raises(KeyError) as exc_info: get_provider_registry(base_config) assert "config_class" in str(exc_info.value) + + def test_external_provider_from_module_success(self, mock_providers): + """Test loading an external provider from a module (success path).""" + from types import SimpleNamespace + + from llama_stack.distribution.datatypes import Provider, StackRunConfig + from llama_stack.providers.datatypes import Api, ProviderSpec + + # Simulate a provider module with get_provider_spec + fake_spec = ProviderSpec( + api=Api.inference, + provider_type="external_test", + config_class="external_test.config.ExternalTestConfig", + module="external_test", + ) + fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec) + + import_module_side_effect = make_import_module_side_effect(external_module=fake_module) + + with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import: + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + registry = get_provider_registry(config) + assert Api.inference in registry + assert "external_test" in registry[Api.inference] + provider = registry[Api.inference]["external_test"] + assert provider.module == "external_test" + assert provider.config_class == "external_test.config.ExternalTestConfig" + mock_import.assert_any_call("llama_stack.providers.registry.inference") + mock_import.assert_any_call("external_test.provider") + + def test_external_provider_from_module_not_found(self, mock_providers): + """Test handling ModuleNotFoundError for missing provider module.""" + from llama_stack.distribution.datatypes import Provider, StackRunConfig + + import_module_side_effect = make_import_module_side_effect(raise_for_external=True) + + with patch("importlib.import_module", side_effect=import_module_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + with pytest.raises(ValueError) as exc_info: + get_provider_registry(config) + assert "get_provider_spec not found" in str(exc_info.value) + + def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers): + """Test handling missing get_provider_spec in provider module (should raise ValueError).""" + from llama_stack.distribution.datatypes import Provider, StackRunConfig + + import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True) + + with patch("importlib.import_module", side_effect=import_module_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + with pytest.raises(AttributeError): + get_provider_registry(config) + + def test_external_provider_from_module_building(self, mock_providers): + """Test loading an external provider from a module during build (building=True, partial spec).""" + from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec, Provider + from llama_stack.providers.datatypes import Api + + # No importlib patch needed, should not import module when type of `config` is BuildConfig or DistributionSpec + build_config = BuildConfig( + version=2, + image_type="container", + image_name="test_image", + distribution_spec=DistributionSpec( + description="test", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ), + ) + registry = get_provider_registry(build_config) + assert Api.inference in registry + assert "external_test" in registry[Api.inference] + provider = registry[Api.inference]["external_test"] + assert provider.module == "external_test" + assert provider.is_external is True + # config_class is empty string in partial spec + assert provider.config_class == ""