This commit is contained in:
Charlie Doern 2025-10-03 15:18:59 -04:00 committed by GitHub
commit 5c644927df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 505 additions and 3 deletions

View file

@ -243,6 +243,7 @@ def get_external_providers_from_module(
spec = module.get_provider_spec()
else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
# in the case we are building we CANNOT import this module of course because it has not been installed.
spec = ProviderSpec(
api=Api(provider_api),
provider_type=provider.provider_type,
@ -251,8 +252,19 @@ def get_external_providers_from_module(
config_class="",
)
provider_type = provider.provider_type
# in the case we are building we CANNOT import this module of course because it has not been installed.
# return a partially filled out spec that the build script will populate.
if isinstance(spec, list):
# optionally allow people to pass inline and remote provider specs as a returned list.
# with the old method, users could pass in directories of specs using overlapping code
# we want to ensure we preserve that flexibility in this method.
logger.info(
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
)
for provider_spec in spec:
if provider_spec.provider_type != provider.provider_type:
continue
logger.info(f"Adding {provider.provider_type} to registry")
registry[Api(provider_api)][provider.provider_type] = provider_spec
else:
registry[Api(provider_api)][provider_type] = spec
except ModuleNotFoundError as exc:
raise ValueError(

View file

@ -390,3 +390,493 @@ pip_packages:
assert provider.is_external is True
# config_class is empty string in partial spec
assert provider.config_class == ""
class TestGetExternalProvidersFromModule:
"""Test suite for installing external providers from module."""
def test_stackrunconfig_provider_without_module(self, mock_providers):
"""Test that providers without module attribute are skipped."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
import_module_side_effect = make_import_module_side_effect()
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="no_module",
provider_type="no_module",
config={},
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Should not add anything to registry
assert len(result[Api.inference]) == 0
def test_stackrunconfig_provider_with_none_module(self, mock_providers):
"""Test that providers with None module value are skipped."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
import_module_side_effect = make_import_module_side_effect()
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="none_module",
provider_type="none_module",
config={},
module=None,
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Should not add anything to registry
assert len(result[Api.inference]) == 0
def test_stackrunconfig_with_version_spec(self, mock_providers):
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
fake_spec = ProviderSpec(
api=Api.inference,
provider_type="versioned_test",
config_class="versioned_test.config.VersionedTestConfig",
module="versioned_test==1.0.0",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)
def import_side_effect(name):
if name == "versioned_test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="versioned",
provider_type="versioned_test",
config={},
module="versioned_test==1.0.0",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
assert "versioned_test" in result[Api.inference]
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
def test_buildconfig_does_not_import_module(self, mock_providers):
"""Test that BuildConfig does not import the module (building=True)."""
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
build_config = BuildConfig(
version=2,
image_type="container",
image_name="test_image",
distribution_spec=DistributionSpec(
description="test",
providers={
"inference": [
BuildProvider(
provider_type="build_test",
module="build_test==1.0.0",
)
]
},
),
)
# Should not call import_module at all when building
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, build_config, building=True)
# Verify module was NOT imported
mock_import.assert_not_called()
# Verify partial spec was created
assert "build_test" in result[Api.inference]
provider = result[Api.inference]["build_test"]
assert provider.module == "build_test==1.0.0"
assert provider.is_external is True
assert provider.config_class == ""
assert provider.api == Api.inference
def test_buildconfig_multiple_providers(self, mock_providers):
"""Test BuildConfig with multiple providers for the same API."""
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
build_config = BuildConfig(
version=2,
image_type="container",
image_name="test_image",
distribution_spec=DistributionSpec(
description="test",
providers={
"inference": [
BuildProvider(provider_type="provider1", module="provider1"),
BuildProvider(provider_type="provider2", module="provider2"),
]
},
),
)
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, build_config, building=True)
mock_import.assert_not_called()
assert "provider1" in result[Api.inference]
assert "provider2" in result[Api.inference]
def test_distributionspec_does_not_import_module(self, mock_providers):
"""Test that DistributionSpec does not import the module (building=True)."""
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
dist_spec = DistributionSpec(
description="test distribution",
providers={
"inference": [
BuildProvider(
provider_type="dist_test",
module="dist_test==2.0.0",
)
]
},
)
# Should not call import_module at all when building
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, dist_spec, building=True)
# Verify module was NOT imported
mock_import.assert_not_called()
# Verify partial spec was created
assert "dist_test" in result[Api.inference]
provider = result[Api.inference]["dist_test"]
assert provider.module == "dist_test==2.0.0"
assert provider.is_external is True
assert provider.config_class == ""
def test_list_return_from_get_provider_spec(self, mock_providers):
"""Test when get_provider_spec returns a list of specs."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
spec1 = ProviderSpec(
api=Api.inference,
provider_type="list_test",
config_class="list_test.config.Config1",
module="list_test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="list_test_remote",
config_class="list_test.config.Config2",
module="list_test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "list_test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="list_test",
provider_type="list_test",
config={},
module="list_test",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Only the matching provider_type should be added
assert "list_test" in result[Api.inference]
assert result[Api.inference]["list_test"].config_class == "list_test.config.Config1"
def test_list_return_filters_by_provider_type(self, mock_providers):
"""Test that list return filters specs by provider_type."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
spec1 = ProviderSpec(
api=Api.inference,
provider_type="wanted",
config_class="test.Config1",
module="test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="unwanted",
config_class="test.Config2",
module="test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="wanted",
provider_type="wanted",
config={},
module="test",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Only the matching provider_type should be added
assert "wanted" in result[Api.inference]
assert "unwanted" not in result[Api.inference]
def test_list_return_adds_multiple_provider_types(self, mock_providers):
"""Test that list return adds multiple different provider_types when config requests them."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
# Module returns both inline and remote variants
spec1 = ProviderSpec(
api=Api.inference,
provider_type="remote::ollama",
config_class="test.RemoteConfig",
module="test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="inline::ollama",
config_class="test.InlineConfig",
module="test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="remote_ollama",
provider_type="remote::ollama",
config={},
module="test",
),
Provider(
provider_id="inline_ollama",
provider_type="inline::ollama",
config={},
module="test",
),
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Both provider types should be added to registry
assert "remote::ollama" in result[Api.inference]
assert "inline::ollama" in result[Api.inference]
assert result[Api.inference]["remote::ollama"].config_class == "test.RemoteConfig"
assert result[Api.inference]["inline::ollama"].config_class == "test.InlineConfig"
def test_module_not_found_raises_value_error(self, mock_providers):
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def import_side_effect(name):
if name == "missing_module.provider":
raise ModuleNotFoundError(name)
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="missing",
provider_type="missing",
config={},
module="missing_module",
)
]
},
)
registry = {Api.inference: {}}
with pytest.raises(ValueError) as exc_info:
get_external_providers_from_module(registry, config, building=False)
assert "get_provider_spec not found" in str(exc_info.value)
def test_generic_exception_is_raised(self, mock_providers):
"""Test that generic exceptions are properly raised."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def bad_spec():
raise RuntimeError("Something went wrong")
fake_module = SimpleNamespace(get_provider_spec=bad_spec)
def import_side_effect(name):
if name == "error_module.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="error",
provider_type="error",
config={},
module="error_module",
)
]
},
)
registry = {Api.inference: {}}
with pytest.raises(RuntimeError) as exc_info:
get_external_providers_from_module(registry, config, building=False)
assert "Something went wrong" in str(exc_info.value)
def test_empty_provider_list(self, mock_providers):
"""Test with empty provider list."""
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
config = StackRunConfig(
image_name="test_image",
providers={},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Should return registry unchanged
assert result == registry
assert len(result[Api.inference]) == 0
def test_multiple_apis_with_providers(self, mock_providers):
"""Test multiple APIs with providers."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
inference_spec = ProviderSpec(
api=Api.inference,
provider_type="inf_test",
config_class="inf.Config",
module="inf_test",
)
safety_spec = ProviderSpec(
api=Api.safety,
provider_type="safe_test",
config_class="safe.Config",
module="safe_test",
)
def import_side_effect(name):
if name == "inf_test.provider":
return SimpleNamespace(get_provider_spec=lambda: inference_spec)
elif name == "safe_test.provider":
return SimpleNamespace(get_provider_spec=lambda: safety_spec)
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="inf",
provider_type="inf_test",
config={},
module="inf_test",
)
],
"safety": [
Provider(
provider_id="safe",
provider_type="safe_test",
config={},
module="safe_test",
)
],
},
)
registry = {Api.inference: {}, Api.safety: {}}
result = get_external_providers_from_module(registry, config, building=False)
assert "inf_test" in result[Api.inference]
assert "safe_test" in result[Api.safety]