mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-26 22:19:49 +00:00
test: add more coverage for get_provider_registry
since the scope of what get_provider_registry does is expanding, add tests to cover providers with `module` Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
09472b63c9
commit
434cb10fbb
1 changed files with 179 additions and 0 deletions
|
@ -221,3 +221,182 @@ pip_packages:
|
||||||
with pytest.raises(KeyError) as exc_info:
|
with pytest.raises(KeyError) as exc_info:
|
||||||
get_provider_registry(base_config)
|
get_provider_registry(base_config)
|
||||||
assert "config_class" in str(exc_info.value)
|
assert "config_class" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_external_provider_from_module_success(self, mock_providers):
|
||||||
|
"""Test loading an external provider from a module (success path)."""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||||
|
|
||||||
|
# Simulate a provider module with get_provider_spec
|
||||||
|
fake_spec = ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="external_test",
|
||||||
|
config_class="external_test.config.ExternalTestConfig",
|
||||||
|
module="external_test",
|
||||||
|
)
|
||||||
|
fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)
|
||||||
|
|
||||||
|
def import_module_side_effect(name):
|
||||||
|
if name == "llama_stack.providers.registry.inference":
|
||||||
|
# Return a mock with available_providers for built-in providers
|
||||||
|
mock_builtin = SimpleNamespace(
|
||||||
|
available_providers=lambda: [
|
||||||
|
ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="test_provider",
|
||||||
|
config_class="test_provider.config.TestProviderConfig",
|
||||||
|
module="test_provider",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return mock_builtin
|
||||||
|
elif name == "external_test.provider":
|
||||||
|
return fake_module
|
||||||
|
else:
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import:
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="external_test",
|
||||||
|
provider_type="external_test",
|
||||||
|
config={},
|
||||||
|
module="external_test",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
registry = get_provider_registry(config)
|
||||||
|
assert Api.inference in registry
|
||||||
|
assert "external_test" in registry[Api.inference]
|
||||||
|
provider = registry[Api.inference]["external_test"]
|
||||||
|
assert provider.module == "external_test"
|
||||||
|
assert provider.config_class == "external_test.config.ExternalTestConfig"
|
||||||
|
mock_import.assert_any_call("llama_stack.providers.registry.inference")
|
||||||
|
mock_import.assert_any_call("external_test.provider")
|
||||||
|
|
||||||
|
def test_external_provider_from_module_not_found(self, mock_providers):
|
||||||
|
"""Test handling ModuleNotFoundError for missing provider module."""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
def import_module_side_effect(name):
|
||||||
|
if name == "llama_stack.providers.registry.inference":
|
||||||
|
mock_builtin = SimpleNamespace(
|
||||||
|
available_providers=lambda: [
|
||||||
|
ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="test_provider",
|
||||||
|
config_class="test_provider.config.TestProviderConfig",
|
||||||
|
module="test_provider",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return mock_builtin
|
||||||
|
elif name == "external_test.provider":
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
else:
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="external_test",
|
||||||
|
provider_type="external_test",
|
||||||
|
config={},
|
||||||
|
module="external_test",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
get_provider_registry(config)
|
||||||
|
assert "get_provider_spec not found" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
|
||||||
|
"""Test handling missing get_provider_spec in provider module (should raise ValueError)."""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
# Simulate a provider module without get_provider_spec
|
||||||
|
fake_module = SimpleNamespace()
|
||||||
|
|
||||||
|
def import_module_side_effect(name):
|
||||||
|
if name == "llama_stack.providers.registry.inference":
|
||||||
|
mock_builtin = SimpleNamespace(
|
||||||
|
available_providers=lambda: [
|
||||||
|
ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="test_provider",
|
||||||
|
config_class="test_provider.config.TestProviderConfig",
|
||||||
|
module="test_provider",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return mock_builtin
|
||||||
|
elif name == "external_test.provider":
|
||||||
|
return fake_module
|
||||||
|
else:
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="external_test",
|
||||||
|
provider_type="external_test",
|
||||||
|
config={},
|
||||||
|
module="external_test",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
get_provider_registry(config)
|
||||||
|
|
||||||
|
def test_external_provider_from_module_building(self, mock_providers):
|
||||||
|
"""Test loading an external provider from a module during build (building=True, partial spec)."""
|
||||||
|
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec, Provider
|
||||||
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
# No importlib patch needed, should not import module when building=True
|
||||||
|
build_config = BuildConfig(
|
||||||
|
version=2,
|
||||||
|
image_type="container",
|
||||||
|
image_name="test_image",
|
||||||
|
distribution_spec=DistributionSpec(
|
||||||
|
description="test",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="external_test",
|
||||||
|
provider_type="external_test",
|
||||||
|
config={},
|
||||||
|
module="external_test",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
registry = get_provider_registry(build_config)
|
||||||
|
assert Api.inference in registry
|
||||||
|
assert "external_test" in registry[Api.inference]
|
||||||
|
provider = registry[Api.inference]["external_test"]
|
||||||
|
assert provider.module == "external_test"
|
||||||
|
assert provider.is_external is True
|
||||||
|
# config_class is empty string in partial spec
|
||||||
|
assert provider.config_class == ""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue