diff --git a/llama_stack/core/distribution.py b/llama_stack/core/distribution.py index f44967aaf..0ebb847af 100644 --- a/llama_stack/core/distribution.py +++ b/llama_stack/core/distribution.py @@ -243,6 +243,7 @@ def get_external_providers_from_module( spec = module.get_provider_spec() else: # pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run + # in the case we are building we CANNOT import this module of course because it has not been installed. spec = ProviderSpec( api=Api(provider_api), provider_type=provider.provider_type, @@ -251,9 +252,20 @@ def get_external_providers_from_module( config_class="", ) provider_type = provider.provider_type - # in the case we are building we CANNOT import this module of course because it has not been installed. - # return a partially filled out spec that the build script will populate. - registry[Api(provider_api)][provider_type] = spec + if isinstance(spec, list): + # optionally allow people to pass inline and remote provider specs as a returned list. + # with the old method, users could pass in directories of specs using overlapping code + # we want to ensure we preserve that flexibility in this method. + logger.info( + f"Detected a list of external provider specs from {provider.module} adding all to the registry" + ) + for provider_spec in spec: + if provider_spec.provider_type != provider.provider_type: + continue + logger.info(f"Adding {provider.provider_type} to registry") + registry[Api(provider_api)][provider.provider_type] = provider_spec + else: + registry[Api(provider_api)][provider_type] = spec except ModuleNotFoundError as exc: raise ValueError( "get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available" diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index f24de0644..8d75a2312 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -390,3 +390,493 @@ pip_packages: assert provider.is_external is True # config_class is empty string in partial spec assert provider.config_class == "" + + +class TestGetExternalProvidersFromModule: + """Test suite for installing external providers from module.""" + + def test_stackrunconfig_provider_without_module(self, mock_providers): + """Test that providers without module attribute are skipped.""" + from llama_stack.core.datatypes import Provider, StackRunConfig + from llama_stack.core.distribution import get_external_providers_from_module + + import_module_side_effect = make_import_module_side_effect() + + with patch("importlib.import_module", side_effect=import_module_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="no_module", + provider_type="no_module", + config={}, + ) + ] + }, + ) + registry = {Api.inference: {}} + result = get_external_providers_from_module(registry, config, building=False) + # Should not add anything to registry + assert len(result[Api.inference]) == 0 + + def test_stackrunconfig_provider_with_none_module(self, mock_providers): + """Test that providers with None module value are skipped.""" + from llama_stack.core.datatypes import Provider, StackRunConfig + from llama_stack.core.distribution import get_external_providers_from_module + + import_module_side_effect = make_import_module_side_effect() + + with patch("importlib.import_module", side_effect=import_module_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="none_module", + provider_type="none_module", + config={}, + module=None, + ) + ] + }, + ) + registry = {Api.inference: {}} + result = get_external_providers_from_module(registry, config, building=False) + # Should not add anything to registry + assert len(result[Api.inference]) == 0 + + def test_stackrunconfig_with_version_spec(self, mock_providers): + """Test provider with module containing version spec (e.g., package==1.0.0).""" + from types import SimpleNamespace + + from llama_stack.core.datatypes import Provider, StackRunConfig + from llama_stack.core.distribution import get_external_providers_from_module + from llama_stack.providers.datatypes import ProviderSpec + + fake_spec = ProviderSpec( + api=Api.inference, + provider_type="versioned_test", + config_class="versioned_test.config.VersionedTestConfig", + module="versioned_test==1.0.0", + ) + fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec) + + def import_side_effect(name): + if name == "versioned_test.provider": + return fake_module + raise ModuleNotFoundError(name) + + with patch("importlib.import_module", side_effect=import_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="versioned", + provider_type="versioned_test", + config={}, + module="versioned_test==1.0.0", + ) + ] + }, + ) + registry = {Api.inference: {}} + result = get_external_providers_from_module(registry, config, building=False) + assert "versioned_test" in result[Api.inference] + assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0" + + def test_buildconfig_does_not_import_module(self, mock_providers): + """Test that BuildConfig does not import the module (building=True).""" + from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec + from llama_stack.core.distribution import get_external_providers_from_module + + build_config = BuildConfig( + version=2, + image_type="container", + image_name="test_image", + distribution_spec=DistributionSpec( + description="test", + providers={ + "inference": [ + BuildProvider( + provider_type="build_test", + module="build_test==1.0.0", + ) + ] + }, + ), + ) + + # Should not call import_module at all when building + with patch("importlib.import_module") as mock_import: + registry = {Api.inference: {}} + result = get_external_providers_from_module(registry, build_config, building=True) + + # Verify module was NOT imported + mock_import.assert_not_called() + + # Verify partial spec was created + assert "build_test" in result[Api.inference] + provider = result[Api.inference]["build_test"] + assert provider.module == "build_test==1.0.0" + assert provider.is_external is True + assert provider.config_class == "" + assert provider.api == Api.inference + + def test_buildconfig_multiple_providers(self, mock_providers): + """Test BuildConfig with multiple providers for the same API.""" + from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec + from llama_stack.core.distribution import get_external_providers_from_module + + build_config = BuildConfig( + version=2, + image_type="container", + image_name="test_image", + distribution_spec=DistributionSpec( + description="test", + providers={ + "inference": [ + BuildProvider(provider_type="provider1", module="provider1"), + BuildProvider(provider_type="provider2", module="provider2"), + ] + }, + ), + ) + + with patch("importlib.import_module") as mock_import: + registry = {Api.inference: {}} + result = get_external_providers_from_module(registry, build_config, building=True) + + mock_import.assert_not_called() + assert "provider1" in result[Api.inference] + assert "provider2" in result[Api.inference] + + def test_distributionspec_does_not_import_module(self, mock_providers): + """Test that DistributionSpec does not import the module (building=True).""" + from llama_stack.core.datatypes import BuildProvider, DistributionSpec + from llama_stack.core.distribution import get_external_providers_from_module + + dist_spec = DistributionSpec( + description="test distribution", + providers={ + "inference": [ + BuildProvider( + provider_type="dist_test", + module="dist_test==2.0.0", + ) + ] + }, + ) + + # Should not call import_module at all when building + with patch("importlib.import_module") as mock_import: + registry = {Api.inference: {}} + result = get_external_providers_from_module(registry, dist_spec, building=True) + + # Verify module was NOT imported + mock_import.assert_not_called() + + # Verify partial spec was created + assert "dist_test" in result[Api.inference] + provider = result[Api.inference]["dist_test"] + assert provider.module == "dist_test==2.0.0" + assert provider.is_external is True + assert provider.config_class == "" + + def test_list_return_from_get_provider_spec(self, mock_providers): + """Test when get_provider_spec returns a list of specs.""" + from types import SimpleNamespace + + from llama_stack.core.datatypes import Provider, StackRunConfig + from llama_stack.core.distribution import get_external_providers_from_module + from llama_stack.providers.datatypes import ProviderSpec + + spec1 = ProviderSpec( + api=Api.inference, + provider_type="list_test", + config_class="list_test.config.Config1", + module="list_test", + ) + spec2 = ProviderSpec( + api=Api.inference, + provider_type="list_test_remote", + config_class="list_test.config.Config2", + module="list_test", + ) + + fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2]) + + def import_side_effect(name): + if name == "list_test.provider": + return fake_module + raise ModuleNotFoundError(name) + + with patch("importlib.import_module", side_effect=import_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="list_test", + provider_type="list_test", + config={}, + module="list_test", + ) + ] + }, + ) + registry = {Api.inference: {}} + result = get_external_providers_from_module(registry, config, building=False) + + # Only the matching provider_type should be added + assert "list_test" in result[Api.inference] + assert result[Api.inference]["list_test"].config_class == "list_test.config.Config1" + + def test_list_return_filters_by_provider_type(self, mock_providers): + """Test that list return filters specs by provider_type.""" + from types import SimpleNamespace + + from llama_stack.core.datatypes import Provider, StackRunConfig + from llama_stack.core.distribution import get_external_providers_from_module + from llama_stack.providers.datatypes import ProviderSpec + + spec1 = ProviderSpec( + api=Api.inference, + provider_type="wanted", + config_class="test.Config1", + module="test", + ) + spec2 = ProviderSpec( + api=Api.inference, + provider_type="unwanted", + config_class="test.Config2", + module="test", + ) + + fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2]) + + def import_side_effect(name): + if name == "test.provider": + return fake_module + raise ModuleNotFoundError(name) + + with patch("importlib.import_module", side_effect=import_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="wanted", + provider_type="wanted", + config={}, + module="test", + ) + ] + }, + ) + registry = {Api.inference: {}} + result = get_external_providers_from_module(registry, config, building=False) + + # Only the matching provider_type should be added + assert "wanted" in result[Api.inference] + assert "unwanted" not in result[Api.inference] + + def test_list_return_adds_multiple_provider_types(self, mock_providers): + """Test that list return adds multiple different provider_types when config requests them.""" + from types import SimpleNamespace + + from llama_stack.core.datatypes import Provider, StackRunConfig + from llama_stack.core.distribution import get_external_providers_from_module + from llama_stack.providers.datatypes import ProviderSpec + + # Module returns both inline and remote variants + spec1 = ProviderSpec( + api=Api.inference, + provider_type="remote::ollama", + config_class="test.RemoteConfig", + module="test", + ) + spec2 = ProviderSpec( + api=Api.inference, + provider_type="inline::ollama", + config_class="test.InlineConfig", + module="test", + ) + + fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2]) + + def import_side_effect(name): + if name == "test.provider": + return fake_module + raise ModuleNotFoundError(name) + + with patch("importlib.import_module", side_effect=import_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="remote_ollama", + provider_type="remote::ollama", + config={}, + module="test", + ), + Provider( + provider_id="inline_ollama", + provider_type="inline::ollama", + config={}, + module="test", + ), + ] + }, + ) + registry = {Api.inference: {}} + result = get_external_providers_from_module(registry, config, building=False) + + # Both provider types should be added to registry + assert "remote::ollama" in result[Api.inference] + assert "inline::ollama" in result[Api.inference] + assert result[Api.inference]["remote::ollama"].config_class == "test.RemoteConfig" + assert result[Api.inference]["inline::ollama"].config_class == "test.InlineConfig" + + def test_module_not_found_raises_value_error(self, mock_providers): + """Test that ModuleNotFoundError raises ValueError with helpful message.""" + from llama_stack.core.datatypes import Provider, StackRunConfig + from llama_stack.core.distribution import get_external_providers_from_module + + def import_side_effect(name): + if name == "missing_module.provider": + raise ModuleNotFoundError(name) + raise ModuleNotFoundError(name) + + with patch("importlib.import_module", side_effect=import_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="missing", + provider_type="missing", + config={}, + module="missing_module", + ) + ] + }, + ) + registry = {Api.inference: {}} + + with pytest.raises(ValueError) as exc_info: + get_external_providers_from_module(registry, config, building=False) + + assert "get_provider_spec not found" in str(exc_info.value) + + def test_generic_exception_is_raised(self, mock_providers): + """Test that generic exceptions are properly raised.""" + from types import SimpleNamespace + + from llama_stack.core.datatypes import Provider, StackRunConfig + from llama_stack.core.distribution import get_external_providers_from_module + + def bad_spec(): + raise RuntimeError("Something went wrong") + + fake_module = SimpleNamespace(get_provider_spec=bad_spec) + + def import_side_effect(name): + if name == "error_module.provider": + return fake_module + raise ModuleNotFoundError(name) + + with patch("importlib.import_module", side_effect=import_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="error", + provider_type="error", + config={}, + module="error_module", + ) + ] + }, + ) + registry = {Api.inference: {}} + + with pytest.raises(RuntimeError) as exc_info: + get_external_providers_from_module(registry, config, building=False) + + assert "Something went wrong" in str(exc_info.value) + + def test_empty_provider_list(self, mock_providers): + """Test with empty provider list.""" + from llama_stack.core.datatypes import StackRunConfig + from llama_stack.core.distribution import get_external_providers_from_module + + config = StackRunConfig( + image_name="test_image", + providers={}, + ) + registry = {Api.inference: {}} + result = get_external_providers_from_module(registry, config, building=False) + + # Should return registry unchanged + assert result == registry + assert len(result[Api.inference]) == 0 + + def test_multiple_apis_with_providers(self, mock_providers): + """Test multiple APIs with providers.""" + from types import SimpleNamespace + + from llama_stack.core.datatypes import Provider, StackRunConfig + from llama_stack.core.distribution import get_external_providers_from_module + from llama_stack.providers.datatypes import ProviderSpec + + inference_spec = ProviderSpec( + api=Api.inference, + provider_type="inf_test", + config_class="inf.Config", + module="inf_test", + ) + safety_spec = ProviderSpec( + api=Api.safety, + provider_type="safe_test", + config_class="safe.Config", + module="safe_test", + ) + + def import_side_effect(name): + if name == "inf_test.provider": + return SimpleNamespace(get_provider_spec=lambda: inference_spec) + elif name == "safe_test.provider": + return SimpleNamespace(get_provider_spec=lambda: safety_spec) + raise ModuleNotFoundError(name) + + with patch("importlib.import_module", side_effect=import_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="inf", + provider_type="inf_test", + config={}, + module="inf_test", + ) + ], + "safety": [ + Provider( + provider_id="safe", + provider_type="safe_test", + config={}, + module="safe_test", + ) + ], + }, + ) + registry = {Api.inference: {}, Api.safety: {}} + result = get_external_providers_from_module(registry, config, building=False) + + assert "inf_test" in result[Api.inference] + assert "safe_test" in result[Api.safety]