diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index ae24602d7..781a7b6a9 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -221,3 +221,182 @@ pip_packages: with pytest.raises(KeyError) as exc_info: get_provider_registry(base_config) assert "config_class" in str(exc_info.value) + + def test_external_provider_from_module_success(self, mock_providers): + """Test loading an external provider from a module (success path).""" + from types import SimpleNamespace + + from llama_stack.distribution.datatypes import Provider, StackRunConfig + from llama_stack.providers.datatypes import Api, ProviderSpec + + # Simulate a provider module with get_provider_spec + fake_spec = ProviderSpec( + api=Api.inference, + provider_type="external_test", + config_class="external_test.config.ExternalTestConfig", + module="external_test", + ) + fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec) + + def import_module_side_effect(name): + if name == "llama_stack.providers.registry.inference": + # Return a mock with available_providers for built-in providers + mock_builtin = SimpleNamespace( + available_providers=lambda: [ + ProviderSpec( + api=Api.inference, + provider_type="test_provider", + config_class="test_provider.config.TestProviderConfig", + module="test_provider", + ) + ] + ) + return mock_builtin + elif name == "external_test.provider": + return fake_module + else: + raise ModuleNotFoundError(name) + + with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import: + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + registry = get_provider_registry(config) + assert Api.inference in registry + assert "external_test" in registry[Api.inference] + provider = registry[Api.inference]["external_test"] + assert provider.module == "external_test" + assert provider.config_class == "external_test.config.ExternalTestConfig" + mock_import.assert_any_call("llama_stack.providers.registry.inference") + mock_import.assert_any_call("external_test.provider") + + def test_external_provider_from_module_not_found(self, mock_providers): + """Test handling ModuleNotFoundError for missing provider module.""" + from types import SimpleNamespace + + from llama_stack.distribution.datatypes import Provider, StackRunConfig + from llama_stack.providers.datatypes import Api + + def import_module_side_effect(name): + if name == "llama_stack.providers.registry.inference": + mock_builtin = SimpleNamespace( + available_providers=lambda: [ + ProviderSpec( + api=Api.inference, + provider_type="test_provider", + config_class="test_provider.config.TestProviderConfig", + module="test_provider", + ) + ] + ) + return mock_builtin + elif name == "external_test.provider": + raise ModuleNotFoundError(name) + else: + raise ModuleNotFoundError(name) + + with patch("importlib.import_module", side_effect=import_module_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + with pytest.raises(ValueError) as exc_info: + get_provider_registry(config) + assert "get_provider_spec not found" in str(exc_info.value) + + def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers): + """Test handling missing get_provider_spec in provider module (should raise ValueError).""" + from types import SimpleNamespace + + from llama_stack.distribution.datatypes import Provider, StackRunConfig + from llama_stack.providers.datatypes import Api + + # Simulate a provider module without get_provider_spec + fake_module = SimpleNamespace() + + def import_module_side_effect(name): + if name == "llama_stack.providers.registry.inference": + mock_builtin = SimpleNamespace( + available_providers=lambda: [ + ProviderSpec( + api=Api.inference, + provider_type="test_provider", + config_class="test_provider.config.TestProviderConfig", + module="test_provider", + ) + ] + ) + return mock_builtin + elif name == "external_test.provider": + return fake_module + else: + raise ModuleNotFoundError(name) + + with patch("importlib.import_module", side_effect=import_module_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + with pytest.raises(AttributeError): + get_provider_registry(config) + + def test_external_provider_from_module_building(self, mock_providers): + """Test loading an external provider from a module during build (building=True, partial spec).""" + from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec, Provider + from llama_stack.providers.datatypes import Api + + # No importlib patch needed, should not import module when building=True + build_config = BuildConfig( + version=2, + image_type="container", + image_name="test_image", + distribution_spec=DistributionSpec( + description="test", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ), + ) + registry = get_provider_registry(build_config) + assert Api.inference in registry + assert "external_test" in registry[Api.inference] + provider = registry[Api.inference]["external_test"] + assert provider.module == "external_test" + assert provider.is_external is True + # config_class is empty string in partial spec + assert provider.config_class == ""