mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Merge 92d0470f74
into 188a56af5c
This commit is contained in:
commit
5c644927df
2 changed files with 505 additions and 3 deletions
|
@ -243,6 +243,7 @@ def get_external_providers_from_module(
|
||||||
spec = module.get_provider_spec()
|
spec = module.get_provider_spec()
|
||||||
else:
|
else:
|
||||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
||||||
|
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||||
spec = ProviderSpec(
|
spec = ProviderSpec(
|
||||||
api=Api(provider_api),
|
api=Api(provider_api),
|
||||||
provider_type=provider.provider_type,
|
provider_type=provider.provider_type,
|
||||||
|
@ -251,9 +252,20 @@ def get_external_providers_from_module(
|
||||||
config_class="",
|
config_class="",
|
||||||
)
|
)
|
||||||
provider_type = provider.provider_type
|
provider_type = provider.provider_type
|
||||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
if isinstance(spec, list):
|
||||||
# return a partially filled out spec that the build script will populate.
|
# optionally allow people to pass inline and remote provider specs as a returned list.
|
||||||
registry[Api(provider_api)][provider_type] = spec
|
# with the old method, users could pass in directories of specs using overlapping code
|
||||||
|
# we want to ensure we preserve that flexibility in this method.
|
||||||
|
logger.info(
|
||||||
|
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
|
||||||
|
)
|
||||||
|
for provider_spec in spec:
|
||||||
|
if provider_spec.provider_type != provider.provider_type:
|
||||||
|
continue
|
||||||
|
logger.info(f"Adding {provider.provider_type} to registry")
|
||||||
|
registry[Api(provider_api)][provider.provider_type] = provider_spec
|
||||||
|
else:
|
||||||
|
registry[Api(provider_api)][provider_type] = spec
|
||||||
except ModuleNotFoundError as exc:
|
except ModuleNotFoundError as exc:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
||||||
|
|
|
@ -390,3 +390,493 @@ pip_packages:
|
||||||
assert provider.is_external is True
|
assert provider.is_external is True
|
||||||
# config_class is empty string in partial spec
|
# config_class is empty string in partial spec
|
||||||
assert provider.config_class == ""
|
assert provider.config_class == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetExternalProvidersFromModule:
|
||||||
|
"""Test suite for installing external providers from module."""
|
||||||
|
|
||||||
|
def test_stackrunconfig_provider_without_module(self, mock_providers):
|
||||||
|
"""Test that providers without module attribute are skipped."""
|
||||||
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
|
||||||
|
import_module_side_effect = make_import_module_side_effect()
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="no_module",
|
||||||
|
provider_type="no_module",
|
||||||
|
config={},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
result = get_external_providers_from_module(registry, config, building=False)
|
||||||
|
# Should not add anything to registry
|
||||||
|
assert len(result[Api.inference]) == 0
|
||||||
|
|
||||||
|
def test_stackrunconfig_provider_with_none_module(self, mock_providers):
|
||||||
|
"""Test that providers with None module value are skipped."""
|
||||||
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
|
||||||
|
import_module_side_effect = make_import_module_side_effect()
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="none_module",
|
||||||
|
provider_type="none_module",
|
||||||
|
config={},
|
||||||
|
module=None,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
result = get_external_providers_from_module(registry, config, building=False)
|
||||||
|
# Should not add anything to registry
|
||||||
|
assert len(result[Api.inference]) == 0
|
||||||
|
|
||||||
|
def test_stackrunconfig_with_version_spec(self, mock_providers):
|
||||||
|
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
from llama_stack.providers.datatypes import ProviderSpec
|
||||||
|
|
||||||
|
fake_spec = ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="versioned_test",
|
||||||
|
config_class="versioned_test.config.VersionedTestConfig",
|
||||||
|
module="versioned_test==1.0.0",
|
||||||
|
)
|
||||||
|
fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)
|
||||||
|
|
||||||
|
def import_side_effect(name):
|
||||||
|
if name == "versioned_test.provider":
|
||||||
|
return fake_module
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="versioned",
|
||||||
|
provider_type="versioned_test",
|
||||||
|
config={},
|
||||||
|
module="versioned_test==1.0.0",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
result = get_external_providers_from_module(registry, config, building=False)
|
||||||
|
assert "versioned_test" in result[Api.inference]
|
||||||
|
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
|
||||||
|
|
||||||
|
def test_buildconfig_does_not_import_module(self, mock_providers):
|
||||||
|
"""Test that BuildConfig does not import the module (building=True)."""
|
||||||
|
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
|
||||||
|
build_config = BuildConfig(
|
||||||
|
version=2,
|
||||||
|
image_type="container",
|
||||||
|
image_name="test_image",
|
||||||
|
distribution_spec=DistributionSpec(
|
||||||
|
description="test",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
BuildProvider(
|
||||||
|
provider_type="build_test",
|
||||||
|
module="build_test==1.0.0",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not call import_module at all when building
|
||||||
|
with patch("importlib.import_module") as mock_import:
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
result = get_external_providers_from_module(registry, build_config, building=True)
|
||||||
|
|
||||||
|
# Verify module was NOT imported
|
||||||
|
mock_import.assert_not_called()
|
||||||
|
|
||||||
|
# Verify partial spec was created
|
||||||
|
assert "build_test" in result[Api.inference]
|
||||||
|
provider = result[Api.inference]["build_test"]
|
||||||
|
assert provider.module == "build_test==1.0.0"
|
||||||
|
assert provider.is_external is True
|
||||||
|
assert provider.config_class == ""
|
||||||
|
assert provider.api == Api.inference
|
||||||
|
|
||||||
|
def test_buildconfig_multiple_providers(self, mock_providers):
|
||||||
|
"""Test BuildConfig with multiple providers for the same API."""
|
||||||
|
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
|
||||||
|
build_config = BuildConfig(
|
||||||
|
version=2,
|
||||||
|
image_type="container",
|
||||||
|
image_name="test_image",
|
||||||
|
distribution_spec=DistributionSpec(
|
||||||
|
description="test",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
BuildProvider(provider_type="provider1", module="provider1"),
|
||||||
|
BuildProvider(provider_type="provider2", module="provider2"),
|
||||||
|
]
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("importlib.import_module") as mock_import:
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
result = get_external_providers_from_module(registry, build_config, building=True)
|
||||||
|
|
||||||
|
mock_import.assert_not_called()
|
||||||
|
assert "provider1" in result[Api.inference]
|
||||||
|
assert "provider2" in result[Api.inference]
|
||||||
|
|
||||||
|
def test_distributionspec_does_not_import_module(self, mock_providers):
|
||||||
|
"""Test that DistributionSpec does not import the module (building=True)."""
|
||||||
|
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
|
||||||
|
dist_spec = DistributionSpec(
|
||||||
|
description="test distribution",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
BuildProvider(
|
||||||
|
provider_type="dist_test",
|
||||||
|
module="dist_test==2.0.0",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not call import_module at all when building
|
||||||
|
with patch("importlib.import_module") as mock_import:
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
result = get_external_providers_from_module(registry, dist_spec, building=True)
|
||||||
|
|
||||||
|
# Verify module was NOT imported
|
||||||
|
mock_import.assert_not_called()
|
||||||
|
|
||||||
|
# Verify partial spec was created
|
||||||
|
assert "dist_test" in result[Api.inference]
|
||||||
|
provider = result[Api.inference]["dist_test"]
|
||||||
|
assert provider.module == "dist_test==2.0.0"
|
||||||
|
assert provider.is_external is True
|
||||||
|
assert provider.config_class == ""
|
||||||
|
|
||||||
|
def test_list_return_from_get_provider_spec(self, mock_providers):
|
||||||
|
"""Test when get_provider_spec returns a list of specs."""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
from llama_stack.providers.datatypes import ProviderSpec
|
||||||
|
|
||||||
|
spec1 = ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="list_test",
|
||||||
|
config_class="list_test.config.Config1",
|
||||||
|
module="list_test",
|
||||||
|
)
|
||||||
|
spec2 = ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="list_test_remote",
|
||||||
|
config_class="list_test.config.Config2",
|
||||||
|
module="list_test",
|
||||||
|
)
|
||||||
|
|
||||||
|
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||||
|
|
||||||
|
def import_side_effect(name):
|
||||||
|
if name == "list_test.provider":
|
||||||
|
return fake_module
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="list_test",
|
||||||
|
provider_type="list_test",
|
||||||
|
config={},
|
||||||
|
module="list_test",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
result = get_external_providers_from_module(registry, config, building=False)
|
||||||
|
|
||||||
|
# Only the matching provider_type should be added
|
||||||
|
assert "list_test" in result[Api.inference]
|
||||||
|
assert result[Api.inference]["list_test"].config_class == "list_test.config.Config1"
|
||||||
|
|
||||||
|
def test_list_return_filters_by_provider_type(self, mock_providers):
|
||||||
|
"""Test that list return filters specs by provider_type."""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
from llama_stack.providers.datatypes import ProviderSpec
|
||||||
|
|
||||||
|
spec1 = ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="wanted",
|
||||||
|
config_class="test.Config1",
|
||||||
|
module="test",
|
||||||
|
)
|
||||||
|
spec2 = ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="unwanted",
|
||||||
|
config_class="test.Config2",
|
||||||
|
module="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||||
|
|
||||||
|
def import_side_effect(name):
|
||||||
|
if name == "test.provider":
|
||||||
|
return fake_module
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="wanted",
|
||||||
|
provider_type="wanted",
|
||||||
|
config={},
|
||||||
|
module="test",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
result = get_external_providers_from_module(registry, config, building=False)
|
||||||
|
|
||||||
|
# Only the matching provider_type should be added
|
||||||
|
assert "wanted" in result[Api.inference]
|
||||||
|
assert "unwanted" not in result[Api.inference]
|
||||||
|
|
||||||
|
def test_list_return_adds_multiple_provider_types(self, mock_providers):
|
||||||
|
"""Test that list return adds multiple different provider_types when config requests them."""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
from llama_stack.providers.datatypes import ProviderSpec
|
||||||
|
|
||||||
|
# Module returns both inline and remote variants
|
||||||
|
spec1 = ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config_class="test.RemoteConfig",
|
||||||
|
module="test",
|
||||||
|
)
|
||||||
|
spec2 = ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="inline::ollama",
|
||||||
|
config_class="test.InlineConfig",
|
||||||
|
module="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||||
|
|
||||||
|
def import_side_effect(name):
|
||||||
|
if name == "test.provider":
|
||||||
|
return fake_module
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="remote_ollama",
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config={},
|
||||||
|
module="test",
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
provider_id="inline_ollama",
|
||||||
|
provider_type="inline::ollama",
|
||||||
|
config={},
|
||||||
|
module="test",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
result = get_external_providers_from_module(registry, config, building=False)
|
||||||
|
|
||||||
|
# Both provider types should be added to registry
|
||||||
|
assert "remote::ollama" in result[Api.inference]
|
||||||
|
assert "inline::ollama" in result[Api.inference]
|
||||||
|
assert result[Api.inference]["remote::ollama"].config_class == "test.RemoteConfig"
|
||||||
|
assert result[Api.inference]["inline::ollama"].config_class == "test.InlineConfig"
|
||||||
|
|
||||||
|
def test_module_not_found_raises_value_error(self, mock_providers):
|
||||||
|
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
|
||||||
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
|
||||||
|
def import_side_effect(name):
|
||||||
|
if name == "missing_module.provider":
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="missing",
|
||||||
|
provider_type="missing",
|
||||||
|
config={},
|
||||||
|
module="missing_module",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
get_external_providers_from_module(registry, config, building=False)
|
||||||
|
|
||||||
|
assert "get_provider_spec not found" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_generic_exception_is_raised(self, mock_providers):
|
||||||
|
"""Test that generic exceptions are properly raised."""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
|
||||||
|
def bad_spec():
|
||||||
|
raise RuntimeError("Something went wrong")
|
||||||
|
|
||||||
|
fake_module = SimpleNamespace(get_provider_spec=bad_spec)
|
||||||
|
|
||||||
|
def import_side_effect(name):
|
||||||
|
if name == "error_module.provider":
|
||||||
|
return fake_module
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="error",
|
||||||
|
provider_type="error",
|
||||||
|
config={},
|
||||||
|
module="error_module",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
get_external_providers_from_module(registry, config, building=False)
|
||||||
|
|
||||||
|
assert "Something went wrong" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_empty_provider_list(self, mock_providers):
|
||||||
|
"""Test with empty provider list."""
|
||||||
|
from llama_stack.core.datatypes import StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={},
|
||||||
|
)
|
||||||
|
registry = {Api.inference: {}}
|
||||||
|
result = get_external_providers_from_module(registry, config, building=False)
|
||||||
|
|
||||||
|
# Should return registry unchanged
|
||||||
|
assert result == registry
|
||||||
|
assert len(result[Api.inference]) == 0
|
||||||
|
|
||||||
|
def test_multiple_apis_with_providers(self, mock_providers):
|
||||||
|
"""Test multiple APIs with providers."""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
from llama_stack.providers.datatypes import ProviderSpec
|
||||||
|
|
||||||
|
inference_spec = ProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="inf_test",
|
||||||
|
config_class="inf.Config",
|
||||||
|
module="inf_test",
|
||||||
|
)
|
||||||
|
safety_spec = ProviderSpec(
|
||||||
|
api=Api.safety,
|
||||||
|
provider_type="safe_test",
|
||||||
|
config_class="safe.Config",
|
||||||
|
module="safe_test",
|
||||||
|
)
|
||||||
|
|
||||||
|
def import_side_effect(name):
|
||||||
|
if name == "inf_test.provider":
|
||||||
|
return SimpleNamespace(get_provider_spec=lambda: inference_spec)
|
||||||
|
elif name == "safe_test.provider":
|
||||||
|
return SimpleNamespace(get_provider_spec=lambda: safety_spec)
|
||||||
|
raise ModuleNotFoundError(name)
|
||||||
|
|
||||||
|
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="inf",
|
||||||
|
provider_type="inf_test",
|
||||||
|
config={},
|
||||||
|
module="inf_test",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"safety": [
|
||||||
|
Provider(
|
||||||
|
provider_id="safe",
|
||||||
|
provider_type="safe_test",
|
||||||
|
config={},
|
||||||
|
module="safe_test",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
registry = {Api.inference: {}, Api.safety: {}}
|
||||||
|
result = get_external_providers_from_module(registry, config, building=False)
|
||||||
|
|
||||||
|
assert "inf_test" in result[Api.inference]
|
||||||
|
assert "safe_test" in result[Api.safety]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue