test: add more coverage for get_provider_registry

since the scope of what get_provider_registry does is expanding, add tests to cover providers with `module`

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-07-24 10:51:18 -04:00
parent 09472b63c9
commit 434cb10fbb

View file

@ -221,3 +221,182 @@ pip_packages:
with pytest.raises(KeyError) as exc_info:
get_provider_registry(base_config)
assert "config_class" in str(exc_info.value)
def test_external_provider_from_module_success(self, mock_providers):
"""Test loading an external provider from a module (success path)."""
from types import SimpleNamespace
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
# Simulate a provider module with get_provider_spec
fake_spec = ProviderSpec(
api=Api.inference,
provider_type="external_test",
config_class="external_test.config.ExternalTestConfig",
module="external_test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)
def import_module_side_effect(name):
if name == "llama_stack.providers.registry.inference":
# Return a mock with available_providers for built-in providers
mock_builtin = SimpleNamespace(
available_providers=lambda: [
ProviderSpec(
api=Api.inference,
provider_type="test_provider",
config_class="test_provider.config.TestProviderConfig",
module="test_provider",
)
]
)
return mock_builtin
elif name == "external_test.provider":
return fake_module
else:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import:
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="external_test",
provider_type="external_test",
config={},
module="external_test",
)
]
},
)
registry = get_provider_registry(config)
assert Api.inference in registry
assert "external_test" in registry[Api.inference]
provider = registry[Api.inference]["external_test"]
assert provider.module == "external_test"
assert provider.config_class == "external_test.config.ExternalTestConfig"
mock_import.assert_any_call("llama_stack.providers.registry.inference")
mock_import.assert_any_call("external_test.provider")
def test_external_provider_from_module_not_found(self, mock_providers):
"""Test handling ModuleNotFoundError for missing provider module."""
from types import SimpleNamespace
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.providers.datatypes import Api
def import_module_side_effect(name):
if name == "llama_stack.providers.registry.inference":
mock_builtin = SimpleNamespace(
available_providers=lambda: [
ProviderSpec(
api=Api.inference,
provider_type="test_provider",
config_class="test_provider.config.TestProviderConfig",
module="test_provider",
)
]
)
return mock_builtin
elif name == "external_test.provider":
raise ModuleNotFoundError(name)
else:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="external_test",
provider_type="external_test",
config={},
module="external_test",
)
]
},
)
with pytest.raises(ValueError) as exc_info:
get_provider_registry(config)
assert "get_provider_spec not found" in str(exc_info.value)
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
"""Test handling missing get_provider_spec in provider module (should raise ValueError)."""
from types import SimpleNamespace
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.providers.datatypes import Api
# Simulate a provider module without get_provider_spec
fake_module = SimpleNamespace()
def import_module_side_effect(name):
if name == "llama_stack.providers.registry.inference":
mock_builtin = SimpleNamespace(
available_providers=lambda: [
ProviderSpec(
api=Api.inference,
provider_type="test_provider",
config_class="test_provider.config.TestProviderConfig",
module="test_provider",
)
]
)
return mock_builtin
elif name == "external_test.provider":
return fake_module
else:
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="external_test",
provider_type="external_test",
config={},
module="external_test",
)
]
},
)
with pytest.raises(AttributeError):
get_provider_registry(config)
def test_external_provider_from_module_building(self, mock_providers):
"""Test loading an external provider from a module during build (building=True, partial spec)."""
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec, Provider
from llama_stack.providers.datatypes import Api
# No importlib patch needed, should not import module when building=True
build_config = BuildConfig(
version=2,
image_type="container",
image_name="test_image",
distribution_spec=DistributionSpec(
description="test",
providers={
"inference": [
Provider(
provider_id="external_test",
provider_type="external_test",
config={},
module="external_test",
)
]
},
),
)
registry = get_provider_registry(build_config)
assert Api.inference in registry
assert "external_test" in registry[Api.inference]
provider = registry[Api.inference]["external_test"]
assert provider.module == "external_test"
assert provider.is_external is True
# config_class is empty string in partial spec
assert provider.config_class == ""