mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Merge 92d0470f74
into 188a56af5c
This commit is contained in:
commit
5c644927df
2 changed files with 505 additions and 3 deletions
|
@ -243,6 +243,7 @@ def get_external_providers_from_module(
|
|||
spec = module.get_provider_spec()
|
||||
else:
|
||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
spec = ProviderSpec(
|
||||
api=Api(provider_api),
|
||||
provider_type=provider.provider_type,
|
||||
|
@ -251,9 +252,20 @@ def get_external_providers_from_module(
|
|||
config_class="",
|
||||
)
|
||||
provider_type = provider.provider_type
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
# return a partially filled out spec that the build script will populate.
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
if isinstance(spec, list):
|
||||
# optionally allow people to pass inline and remote provider specs as a returned list.
|
||||
# with the old method, users could pass in directories of specs using overlapping code
|
||||
# we want to ensure we preserve that flexibility in this method.
|
||||
logger.info(
|
||||
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
|
||||
)
|
||||
for provider_spec in spec:
|
||||
if provider_spec.provider_type != provider.provider_type:
|
||||
continue
|
||||
logger.info(f"Adding {provider.provider_type} to registry")
|
||||
registry[Api(provider_api)][provider.provider_type] = provider_spec
|
||||
else:
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
||||
|
|
|
@ -390,3 +390,493 @@ pip_packages:
|
|||
assert provider.is_external is True
|
||||
# config_class is empty string in partial spec
|
||||
assert provider.config_class == ""
|
||||
|
||||
|
||||
class TestGetExternalProvidersFromModule:
|
||||
"""Test suite for installing external providers from module."""
|
||||
|
||||
def test_stackrunconfig_provider_without_module(self, mock_providers):
|
||||
"""Test that providers without module attribute are skipped."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect()
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="no_module",
|
||||
provider_type="no_module",
|
||||
config={},
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
# Should not add anything to registry
|
||||
assert len(result[Api.inference]) == 0
|
||||
|
||||
def test_stackrunconfig_provider_with_none_module(self, mock_providers):
|
||||
"""Test that providers with None module value are skipped."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect()
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="none_module",
|
||||
provider_type="none_module",
|
||||
config={},
|
||||
module=None,
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
# Should not add anything to registry
|
||||
assert len(result[Api.inference]) == 0
|
||||
|
||||
def test_stackrunconfig_with_version_spec(self, mock_providers):
|
||||
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
fake_spec = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="versioned_test",
|
||||
config_class="versioned_test.config.VersionedTestConfig",
|
||||
module="versioned_test==1.0.0",
|
||||
)
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "versioned_test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="versioned",
|
||||
provider_type="versioned_test",
|
||||
config={},
|
||||
module="versioned_test==1.0.0",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
assert "versioned_test" in result[Api.inference]
|
||||
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
|
||||
|
||||
def test_buildconfig_does_not_import_module(self, mock_providers):
|
||||
"""Test that BuildConfig does not import the module (building=True)."""
|
||||
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
build_config = BuildConfig(
|
||||
version=2,
|
||||
image_type="container",
|
||||
image_name="test_image",
|
||||
distribution_spec=DistributionSpec(
|
||||
description="test",
|
||||
providers={
|
||||
"inference": [
|
||||
BuildProvider(
|
||||
provider_type="build_test",
|
||||
module="build_test==1.0.0",
|
||||
)
|
||||
]
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Should not call import_module at all when building
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, build_config, building=True)
|
||||
|
||||
# Verify module was NOT imported
|
||||
mock_import.assert_not_called()
|
||||
|
||||
# Verify partial spec was created
|
||||
assert "build_test" in result[Api.inference]
|
||||
provider = result[Api.inference]["build_test"]
|
||||
assert provider.module == "build_test==1.0.0"
|
||||
assert provider.is_external is True
|
||||
assert provider.config_class == ""
|
||||
assert provider.api == Api.inference
|
||||
|
||||
def test_buildconfig_multiple_providers(self, mock_providers):
|
||||
"""Test BuildConfig with multiple providers for the same API."""
|
||||
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
build_config = BuildConfig(
|
||||
version=2,
|
||||
image_type="container",
|
||||
image_name="test_image",
|
||||
distribution_spec=DistributionSpec(
|
||||
description="test",
|
||||
providers={
|
||||
"inference": [
|
||||
BuildProvider(provider_type="provider1", module="provider1"),
|
||||
BuildProvider(provider_type="provider2", module="provider2"),
|
||||
]
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, build_config, building=True)
|
||||
|
||||
mock_import.assert_not_called()
|
||||
assert "provider1" in result[Api.inference]
|
||||
assert "provider2" in result[Api.inference]
|
||||
|
||||
def test_distributionspec_does_not_import_module(self, mock_providers):
|
||||
"""Test that DistributionSpec does not import the module (building=True)."""
|
||||
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
dist_spec = DistributionSpec(
|
||||
description="test distribution",
|
||||
providers={
|
||||
"inference": [
|
||||
BuildProvider(
|
||||
provider_type="dist_test",
|
||||
module="dist_test==2.0.0",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
# Should not call import_module at all when building
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, dist_spec, building=True)
|
||||
|
||||
# Verify module was NOT imported
|
||||
mock_import.assert_not_called()
|
||||
|
||||
# Verify partial spec was created
|
||||
assert "dist_test" in result[Api.inference]
|
||||
provider = result[Api.inference]["dist_test"]
|
||||
assert provider.module == "dist_test==2.0.0"
|
||||
assert provider.is_external is True
|
||||
assert provider.config_class == ""
|
||||
|
||||
def test_list_return_from_get_provider_spec(self, mock_providers):
|
||||
"""Test when get_provider_spec returns a list of specs."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
spec1 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="list_test",
|
||||
config_class="list_test.config.Config1",
|
||||
module="list_test",
|
||||
)
|
||||
spec2 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="list_test_remote",
|
||||
config_class="list_test.config.Config2",
|
||||
module="list_test",
|
||||
)
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "list_test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="list_test",
|
||||
provider_type="list_test",
|
||||
config={},
|
||||
module="list_test",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Only the matching provider_type should be added
|
||||
assert "list_test" in result[Api.inference]
|
||||
assert result[Api.inference]["list_test"].config_class == "list_test.config.Config1"
|
||||
|
||||
def test_list_return_filters_by_provider_type(self, mock_providers):
|
||||
"""Test that list return filters specs by provider_type."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
spec1 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="wanted",
|
||||
config_class="test.Config1",
|
||||
module="test",
|
||||
)
|
||||
spec2 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="unwanted",
|
||||
config_class="test.Config2",
|
||||
module="test",
|
||||
)
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="wanted",
|
||||
provider_type="wanted",
|
||||
config={},
|
||||
module="test",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Only the matching provider_type should be added
|
||||
assert "wanted" in result[Api.inference]
|
||||
assert "unwanted" not in result[Api.inference]
|
||||
|
||||
def test_list_return_adds_multiple_provider_types(self, mock_providers):
|
||||
"""Test that list return adds multiple different provider_types when config requests them."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
# Module returns both inline and remote variants
|
||||
spec1 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="remote::ollama",
|
||||
config_class="test.RemoteConfig",
|
||||
module="test",
|
||||
)
|
||||
spec2 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::ollama",
|
||||
config_class="test.InlineConfig",
|
||||
module="test",
|
||||
)
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="remote_ollama",
|
||||
provider_type="remote::ollama",
|
||||
config={},
|
||||
module="test",
|
||||
),
|
||||
Provider(
|
||||
provider_id="inline_ollama",
|
||||
provider_type="inline::ollama",
|
||||
config={},
|
||||
module="test",
|
||||
),
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Both provider types should be added to registry
|
||||
assert "remote::ollama" in result[Api.inference]
|
||||
assert "inline::ollama" in result[Api.inference]
|
||||
assert result[Api.inference]["remote::ollama"].config_class == "test.RemoteConfig"
|
||||
assert result[Api.inference]["inline::ollama"].config_class == "test.InlineConfig"
|
||||
|
||||
def test_module_not_found_raises_value_error(self, mock_providers):
|
||||
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "missing_module.provider":
|
||||
raise ModuleNotFoundError(name)
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="missing",
|
||||
provider_type="missing",
|
||||
config={},
|
||||
module="missing_module",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
assert "get_provider_spec not found" in str(exc_info.value)
|
||||
|
||||
def test_generic_exception_is_raised(self, mock_providers):
|
||||
"""Test that generic exceptions are properly raised."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def bad_spec():
|
||||
raise RuntimeError("Something went wrong")
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=bad_spec)
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "error_module.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="error",
|
||||
provider_type="error",
|
||||
config={},
|
||||
module="error_module",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
assert "Something went wrong" in str(exc_info.value)
|
||||
|
||||
def test_empty_provider_list(self, mock_providers):
|
||||
"""Test with empty provider list."""
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Should return registry unchanged
|
||||
assert result == registry
|
||||
assert len(result[Api.inference]) == 0
|
||||
|
||||
def test_multiple_apis_with_providers(self, mock_providers):
|
||||
"""Test multiple APIs with providers."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
inference_spec = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inf_test",
|
||||
config_class="inf.Config",
|
||||
module="inf_test",
|
||||
)
|
||||
safety_spec = ProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="safe_test",
|
||||
config_class="safe.Config",
|
||||
module="safe_test",
|
||||
)
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "inf_test.provider":
|
||||
return SimpleNamespace(get_provider_spec=lambda: inference_spec)
|
||||
elif name == "safe_test.provider":
|
||||
return SimpleNamespace(get_provider_spec=lambda: safety_spec)
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="inf",
|
||||
provider_type="inf_test",
|
||||
config={},
|
||||
module="inf_test",
|
||||
)
|
||||
],
|
||||
"safety": [
|
||||
Provider(
|
||||
provider_id="safe",
|
||||
provider_type="safe_test",
|
||||
config={},
|
||||
module="safe_test",
|
||||
)
|
||||
],
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}, Api.safety: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
assert "inf_test" in result[Api.inference]
|
||||
assert "safe_test" in result[Api.safety]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue