refactor: install external provider via module

using `module` in the provider class and the fact that `build` and `run` configs BOTH use the `class Provider` now, enables us to point to an external provider via a `module`.

For example, say this is in your build config:

```
- provider_id: ramalama
  provider_type: remote::ramalama
  module: ramalama_stack
```

during build (in the various scripts), additionally to installing any pip dependencies we will also install this module
and use the `get_provider_spec` method to retreive the ProviderSpec that is currently specified using `providers.d`.

Most (if not all) external providers today have a `get_provider_spec` method that sits unused. Utilizing this method rather than the providers.d route allows for a much easier installation process for external providers and limits the amount of extra configuration
a regular user has to do to get their stack off the ground.

In production so far, providing instructions for installing external providers for users has been difficult: they need to install the module as a pre-req, create the providers.d directory, copy in the provider spec, and also copy in the necessary build/run yaml files.

Using the module is a more seamless discovery method

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-07-06 20:00:58 -04:00
parent 233f8c81bf
commit dcc6b1eee9
6 changed files with 508 additions and 232 deletions

View file

@ -12,6 +12,7 @@ from typing import Any
import yaml
from pydantic import BaseModel
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
from llama_stack.distribution.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
@ -97,12 +98,10 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam
return spec
def get_provider_registry(
config=None,
) -> dict[Api, dict[str, ProviderSpec]]:
def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files.
This function loads both built-in providers and external providers from YAML files or from their provided modules.
External providers are loaded from a directory structure like:
providers.d/
@ -123,8 +122,13 @@ def get_provider_registry(
safety/
llama-guard.yaml
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
There is special handling for all of the potential cases this method can be called from.
Args:
config: Optional object containing the external providers directory path
building: Optional bool delineating whether or not this is being called from a build process
Returns:
A dictionary mapping APIs to their available providers
@ -162,46 +166,112 @@ def get_provider_registry(
"Install the API package to load any in-tree providers for this API."
)
# Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
# Check if config has external providers
if config:
if hasattr(config, "external_providers_dir") and config.external_providers_dir:
registry = get_external_providers_from_dir(registry, config)
# else lets check for modules in each provider
registry = get_external_providers_from_module(
registry=registry,
config=config,
building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
)
for api in providable_apis():
api_name = api.name.lower()
# Process both remote and inline providers
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
continue
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
try:
with open(spec_path) as f:
spec_data = yaml.safe_load(f)
if provider_type == "remote":
spec = _load_remote_provider_spec(spec_data, api)
provider_type_key = f"remote::{provider_name}"
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
if provider_type_key in registry[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
registry[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return registry
def get_external_providers_from_dir(
registry: dict[Api, dict[str, ProviderSpec]], config
) -> dict[Api, dict[str, ProviderSpec]]:
logger.warning(
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
for api in providable_apis():
api_name = api.name.lower()
# Process both remote and inline providers
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
continue
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
try:
with open(spec_path) as f:
spec_data = yaml.safe_load(f)
if provider_type == "remote":
spec = _load_remote_provider_spec(spec_data, api)
provider_type_key = f"remote::{provider_name}"
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in registry[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
registry[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return registry
def get_external_providers_from_module(
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
) -> dict[Api, dict[str, ProviderSpec]]:
provider_list = None
if isinstance(config, BuildConfig):
provider_list = config.distribution_spec.providers.items()
else:
provider_list = config.providers.items()
if provider_list is None:
logger.warning("Could not get list of providers from config")
return registry
for provider_api, providers in provider_list:
for provider in providers:
if not hasattr(provider, "module") or provider.module is None:
continue
# get provider using module
try:
if not building:
package_name = provider.module.split("==")[0]
module = importlib.import_module(f"{package_name}.provider")
# if config class is wrong you will get an error saying module could not be imported
spec = module.get_provider_spec()
else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
spec = ProviderSpec(
api=Api(provider_api),
provider_type=provider.provider_type,
is_external=True,
module=provider.module,
config_class="",
)
provider_type = provider.provider_type
# in the case we are building we CANNOT import this module of course because it has not been installed.
# return a partially filled out spec that the build script will populate.
registry[Api(provider_api)][provider_type] = spec
except ModuleNotFoundError as exc:
raise ValueError(
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
) from exc
except Exception as e:
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
raise e
return registry