mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
feat: allow for multiple external provider specs
when using the providers.d method of installation users could hand craft their AdapterSpec's to use overlapping code meaning one repo could contain an inline and remote impl. Currently installing a provider via module does not allow for that as each repo is only allowed to have one `get_provider_spec` method with one Spec returned add an optional way for `get_provider_spec` to return a list of `ProviderSpec` where each can be either an inline or remote impl. Note: the `adapter_type` in `get_provider_spec` MUST match the `provider_type` in the build/run yaml for this to work. resolves #3226 Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
9f6c658f2a
commit
92d0470f74
2 changed files with 505 additions and 3 deletions
|
@ -243,6 +243,7 @@ def get_external_providers_from_module(
|
|||
spec = module.get_provider_spec()
|
||||
else:
|
||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
spec = ProviderSpec(
|
||||
api=Api(provider_api),
|
||||
provider_type=provider.provider_type,
|
||||
|
@ -251,8 +252,19 @@ def get_external_providers_from_module(
|
|||
config_class="",
|
||||
)
|
||||
provider_type = provider.provider_type
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
# return a partially filled out spec that the build script will populate.
|
||||
if isinstance(spec, list):
|
||||
# optionally allow people to pass inline and remote provider specs as a returned list.
|
||||
# with the old method, users could pass in directories of specs using overlapping code
|
||||
# we want to ensure we preserve that flexibility in this method.
|
||||
logger.info(
|
||||
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
|
||||
)
|
||||
for provider_spec in spec:
|
||||
if provider_spec.provider_type != provider.provider_type:
|
||||
continue
|
||||
logger.info(f"Adding {provider.provider_type} to registry")
|
||||
registry[Api(provider_api)][provider.provider_type] = provider_spec
|
||||
else:
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
|
|
|
@ -390,3 +390,493 @@ pip_packages:
|
|||
assert provider.is_external is True
|
||||
# config_class is empty string in partial spec
|
||||
assert provider.config_class == ""
|
||||
|
||||
|
||||
class TestGetExternalProvidersFromModule:
|
||||
"""Test suite for installing external providers from module."""
|
||||
|
||||
def test_stackrunconfig_provider_without_module(self, mock_providers):
|
||||
"""Test that providers without module attribute are skipped."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect()
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="no_module",
|
||||
provider_type="no_module",
|
||||
config={},
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
# Should not add anything to registry
|
||||
assert len(result[Api.inference]) == 0
|
||||
|
||||
def test_stackrunconfig_provider_with_none_module(self, mock_providers):
|
||||
"""Test that providers with None module value are skipped."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
import_module_side_effect = make_import_module_side_effect()
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_module_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="none_module",
|
||||
provider_type="none_module",
|
||||
config={},
|
||||
module=None,
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
# Should not add anything to registry
|
||||
assert len(result[Api.inference]) == 0
|
||||
|
||||
def test_stackrunconfig_with_version_spec(self, mock_providers):
|
||||
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
fake_spec = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="versioned_test",
|
||||
config_class="versioned_test.config.VersionedTestConfig",
|
||||
module="versioned_test==1.0.0",
|
||||
)
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "versioned_test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="versioned",
|
||||
provider_type="versioned_test",
|
||||
config={},
|
||||
module="versioned_test==1.0.0",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
assert "versioned_test" in result[Api.inference]
|
||||
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
|
||||
|
||||
def test_buildconfig_does_not_import_module(self, mock_providers):
|
||||
"""Test that BuildConfig does not import the module (building=True)."""
|
||||
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
build_config = BuildConfig(
|
||||
version=2,
|
||||
image_type="container",
|
||||
image_name="test_image",
|
||||
distribution_spec=DistributionSpec(
|
||||
description="test",
|
||||
providers={
|
||||
"inference": [
|
||||
BuildProvider(
|
||||
provider_type="build_test",
|
||||
module="build_test==1.0.0",
|
||||
)
|
||||
]
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Should not call import_module at all when building
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, build_config, building=True)
|
||||
|
||||
# Verify module was NOT imported
|
||||
mock_import.assert_not_called()
|
||||
|
||||
# Verify partial spec was created
|
||||
assert "build_test" in result[Api.inference]
|
||||
provider = result[Api.inference]["build_test"]
|
||||
assert provider.module == "build_test==1.0.0"
|
||||
assert provider.is_external is True
|
||||
assert provider.config_class == ""
|
||||
assert provider.api == Api.inference
|
||||
|
||||
def test_buildconfig_multiple_providers(self, mock_providers):
|
||||
"""Test BuildConfig with multiple providers for the same API."""
|
||||
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
build_config = BuildConfig(
|
||||
version=2,
|
||||
image_type="container",
|
||||
image_name="test_image",
|
||||
distribution_spec=DistributionSpec(
|
||||
description="test",
|
||||
providers={
|
||||
"inference": [
|
||||
BuildProvider(provider_type="provider1", module="provider1"),
|
||||
BuildProvider(provider_type="provider2", module="provider2"),
|
||||
]
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, build_config, building=True)
|
||||
|
||||
mock_import.assert_not_called()
|
||||
assert "provider1" in result[Api.inference]
|
||||
assert "provider2" in result[Api.inference]
|
||||
|
||||
def test_distributionspec_does_not_import_module(self, mock_providers):
|
||||
"""Test that DistributionSpec does not import the module (building=True)."""
|
||||
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
dist_spec = DistributionSpec(
|
||||
description="test distribution",
|
||||
providers={
|
||||
"inference": [
|
||||
BuildProvider(
|
||||
provider_type="dist_test",
|
||||
module="dist_test==2.0.0",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
# Should not call import_module at all when building
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, dist_spec, building=True)
|
||||
|
||||
# Verify module was NOT imported
|
||||
mock_import.assert_not_called()
|
||||
|
||||
# Verify partial spec was created
|
||||
assert "dist_test" in result[Api.inference]
|
||||
provider = result[Api.inference]["dist_test"]
|
||||
assert provider.module == "dist_test==2.0.0"
|
||||
assert provider.is_external is True
|
||||
assert provider.config_class == ""
|
||||
|
||||
def test_list_return_from_get_provider_spec(self, mock_providers):
|
||||
"""Test when get_provider_spec returns a list of specs."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
spec1 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="list_test",
|
||||
config_class="list_test.config.Config1",
|
||||
module="list_test",
|
||||
)
|
||||
spec2 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="list_test_remote",
|
||||
config_class="list_test.config.Config2",
|
||||
module="list_test",
|
||||
)
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "list_test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="list_test",
|
||||
provider_type="list_test",
|
||||
config={},
|
||||
module="list_test",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Only the matching provider_type should be added
|
||||
assert "list_test" in result[Api.inference]
|
||||
assert result[Api.inference]["list_test"].config_class == "list_test.config.Config1"
|
||||
|
||||
def test_list_return_filters_by_provider_type(self, mock_providers):
|
||||
"""Test that list return filters specs by provider_type."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
spec1 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="wanted",
|
||||
config_class="test.Config1",
|
||||
module="test",
|
||||
)
|
||||
spec2 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="unwanted",
|
||||
config_class="test.Config2",
|
||||
module="test",
|
||||
)
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="wanted",
|
||||
provider_type="wanted",
|
||||
config={},
|
||||
module="test",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Only the matching provider_type should be added
|
||||
assert "wanted" in result[Api.inference]
|
||||
assert "unwanted" not in result[Api.inference]
|
||||
|
||||
def test_list_return_adds_multiple_provider_types(self, mock_providers):
|
||||
"""Test that list return adds multiple different provider_types when config requests them."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
# Module returns both inline and remote variants
|
||||
spec1 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="remote::ollama",
|
||||
config_class="test.RemoteConfig",
|
||||
module="test",
|
||||
)
|
||||
spec2 = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::ollama",
|
||||
config_class="test.InlineConfig",
|
||||
module="test",
|
||||
)
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "test.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="remote_ollama",
|
||||
provider_type="remote::ollama",
|
||||
config={},
|
||||
module="test",
|
||||
),
|
||||
Provider(
|
||||
provider_id="inline_ollama",
|
||||
provider_type="inline::ollama",
|
||||
config={},
|
||||
module="test",
|
||||
),
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Both provider types should be added to registry
|
||||
assert "remote::ollama" in result[Api.inference]
|
||||
assert "inline::ollama" in result[Api.inference]
|
||||
assert result[Api.inference]["remote::ollama"].config_class == "test.RemoteConfig"
|
||||
assert result[Api.inference]["inline::ollama"].config_class == "test.InlineConfig"
|
||||
|
||||
def test_module_not_found_raises_value_error(self, mock_providers):
|
||||
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "missing_module.provider":
|
||||
raise ModuleNotFoundError(name)
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="missing",
|
||||
provider_type="missing",
|
||||
config={},
|
||||
module="missing_module",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
assert "get_provider_spec not found" in str(exc_info.value)
|
||||
|
||||
def test_generic_exception_is_raised(self, mock_providers):
|
||||
"""Test that generic exceptions are properly raised."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
def bad_spec():
|
||||
raise RuntimeError("Something went wrong")
|
||||
|
||||
fake_module = SimpleNamespace(get_provider_spec=bad_spec)
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "error_module.provider":
|
||||
return fake_module
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="error",
|
||||
provider_type="error",
|
||||
config={},
|
||||
module="error_module",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
assert "Something went wrong" in str(exc_info.value)
|
||||
|
||||
def test_empty_provider_list(self, mock_providers):
|
||||
"""Test with empty provider list."""
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
# Should return registry unchanged
|
||||
assert result == registry
|
||||
assert len(result[Api.inference]) == 0
|
||||
|
||||
def test_multiple_apis_with_providers(self, mock_providers):
|
||||
"""Test multiple APIs with providers."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
inference_spec = ProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inf_test",
|
||||
config_class="inf.Config",
|
||||
module="inf_test",
|
||||
)
|
||||
safety_spec = ProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="safe_test",
|
||||
config_class="safe.Config",
|
||||
module="safe_test",
|
||||
)
|
||||
|
||||
def import_side_effect(name):
|
||||
if name == "inf_test.provider":
|
||||
return SimpleNamespace(get_provider_spec=lambda: inference_spec)
|
||||
elif name == "safe_test.provider":
|
||||
return SimpleNamespace(get_provider_spec=lambda: safety_spec)
|
||||
raise ModuleNotFoundError(name)
|
||||
|
||||
with patch("importlib.import_module", side_effect=import_side_effect):
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="inf",
|
||||
provider_type="inf_test",
|
||||
config={},
|
||||
module="inf_test",
|
||||
)
|
||||
],
|
||||
"safety": [
|
||||
Provider(
|
||||
provider_id="safe",
|
||||
provider_type="safe_test",
|
||||
config={},
|
||||
module="safe_test",
|
||||
)
|
||||
],
|
||||
},
|
||||
)
|
||||
registry = {Api.inference: {}, Api.safety: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
|
||||
assert "inf_test" in result[Api.inference]
|
||||
assert "safe_test" in result[Api.safety]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue