mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
# What does this PR do?
Today, external providers are installed via the `external_providers_dir`
in the config. This necessitates users to understand the `ProviderSpec`
and set up their directories accordingly. This process splits up the
config for the stack across multiple files, directories, and formats.
Most (if not all) external providers today have a
[get_provider_spec](559cb18fbb/src/ramalama_stack/provider.py (L9)
)
method that sits unused. Utilizing this method rather than the
providers.d route allows for a much easier installation process for
external providers and limits the amount of extra configuration a
regular user has to do to get their stack off the ground.
To accomplish this and wire it throughout the build process, Introduce
the concept of a `module` for users to specify for an external provider
upon build time. In order to facilitate this, align the build and run
spec to use `Provider` class rather than the stringified provider_type
that build currently uses.
For example, say this is in your build config:
```
- provider_id: ramalama
provider_type: remote::ramalama
module: ramalama_stack
```
during build (in the various `build_...` scripts), additionally to
installing any pip dependencies we will also install this module and use
the `get_provider_spec` method to retrieve the ProviderSpec that is
currently specified using `providers.d`.
In production so far, providing instructions for installing external
providers for users has been difficult: they need to install the module
as a pre-req, create the providers.d directory, copy in the provider
spec, and also copy in the necessary build/run yaml files. Accessing an
external provider should be as easy as possible, and pointing to its
installable module aligns more with the rest of our build and dependency
management process.
For now, `external_providers_dir` still exists as an alternate more
declarative method of using external providers.
## Test Plan
added an integration test installing an external provider from module
and more unit test coverage for `get_provider_registry`
( the warning in yellow is expected, the module is installed inside of
the build env, not where we are running the command)
<img width="1119" height="400" alt="Screenshot 2025-07-24 at 11 30
48 AM"
src="https://github.com/user-attachments/assets/1efbaf45-b9e8-451a-bd63-264ed664706d"
/>
<img width="1154" height="618" alt="Screenshot 2025-07-24 at 11 31
14 AM"
src="https://github.com/user-attachments/assets/feb2b3ea-c5dd-418e-9662-9a3bd5dd6bdc"
/>
---------
Signed-off-by: Charlie Doern <cdoern@redhat.com>
277 lines
11 KiB
Python
277 lines
11 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import glob
|
|
import importlib
|
|
import os
|
|
from typing import Any
|
|
|
|
import yaml
|
|
from pydantic import BaseModel
|
|
|
|
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
|
|
from llama_stack.distribution.external import load_external_apis
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.providers.datatypes import (
|
|
AdapterSpec,
|
|
Api,
|
|
InlineProviderSpec,
|
|
ProviderSpec,
|
|
remote_provider_spec,
|
|
)
|
|
|
|
logger = get_logger(name=__name__, category="core")
|
|
|
|
|
|
def stack_apis() -> list[Api]:
|
|
return list(Api)
|
|
|
|
|
|
class AutoRoutedApiInfo(BaseModel):
|
|
routing_table_api: Api
|
|
router_api: Api
|
|
|
|
|
|
def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
|
return [
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.models,
|
|
router_api=Api.inference,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.shields,
|
|
router_api=Api.safety,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.vector_dbs,
|
|
router_api=Api.vector_io,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.datasets,
|
|
router_api=Api.datasetio,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.scoring_functions,
|
|
router_api=Api.scoring,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.benchmarks,
|
|
router_api=Api.eval,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.tool_groups,
|
|
router_api=Api.tool_runtime,
|
|
),
|
|
]
|
|
|
|
|
|
def providable_apis() -> list[Api]:
|
|
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
|
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
|
|
|
|
|
|
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
|
|
adapter = AdapterSpec(**spec_data["adapter"])
|
|
spec = remote_provider_spec(
|
|
api=api,
|
|
adapter=adapter,
|
|
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
|
)
|
|
return spec
|
|
|
|
|
|
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
|
|
spec = InlineProviderSpec(
|
|
api=api,
|
|
provider_type=f"inline::{provider_name}",
|
|
pip_packages=spec_data.get("pip_packages", []),
|
|
module=spec_data["module"],
|
|
config_class=spec_data["config_class"],
|
|
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
|
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
|
|
provider_data_validator=spec_data.get("provider_data_validator"),
|
|
container_image=spec_data.get("container_image"),
|
|
)
|
|
return spec
|
|
|
|
|
|
def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
|
|
"""Get the provider registry, optionally including external providers.
|
|
|
|
This function loads both built-in providers and external providers from YAML files or from their provided modules.
|
|
External providers are loaded from a directory structure like:
|
|
|
|
providers.d/
|
|
remote/
|
|
inference/
|
|
custom_ollama.yaml
|
|
vllm.yaml
|
|
vector_io/
|
|
qdrant.yaml
|
|
safety/
|
|
llama-guard.yaml
|
|
inline/
|
|
inference/
|
|
custom_ollama.yaml
|
|
vllm.yaml
|
|
vector_io/
|
|
qdrant.yaml
|
|
safety/
|
|
llama-guard.yaml
|
|
|
|
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
|
|
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
|
|
There is special handling for all of the potential cases this method can be called from.
|
|
|
|
Args:
|
|
config: Optional object containing the external providers directory path
|
|
building: Optional bool delineating whether or not this is being called from a build process
|
|
|
|
Returns:
|
|
A dictionary mapping APIs to their available providers
|
|
|
|
Raises:
|
|
FileNotFoundError: If the external providers directory doesn't exist
|
|
ValueError: If any provider spec is invalid
|
|
"""
|
|
|
|
registry: dict[Api, dict[str, ProviderSpec]] = {}
|
|
for api in providable_apis():
|
|
name = api.name.lower()
|
|
logger.debug(f"Importing module {name}")
|
|
try:
|
|
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
|
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
|
except ImportError as e:
|
|
logger.warning(f"Failed to import module {name}: {e}")
|
|
|
|
# Refresh providable APIs with external APIs if any
|
|
external_apis = load_external_apis(config)
|
|
for api, api_spec in external_apis.items():
|
|
name = api_spec.name.lower()
|
|
logger.info(f"Importing external API {name} module {api_spec.module}")
|
|
try:
|
|
module = importlib.import_module(api_spec.module)
|
|
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
|
except (ImportError, AttributeError) as e:
|
|
# Populate the registry with an empty dict to avoid breaking the provider registry
|
|
# This assume that the in-tree provider(s) are not available for this API which means
|
|
# that users will need to use external providers for this API.
|
|
registry[api] = {}
|
|
logger.error(
|
|
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
|
|
"Install the API package to load any in-tree providers for this API."
|
|
)
|
|
|
|
# Check if config has external providers
|
|
if config:
|
|
if hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
|
registry = get_external_providers_from_dir(registry, config)
|
|
# else lets check for modules in each provider
|
|
registry = get_external_providers_from_module(
|
|
registry=registry,
|
|
config=config,
|
|
building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
|
|
)
|
|
|
|
return registry
|
|
|
|
|
|
def get_external_providers_from_dir(
|
|
registry: dict[Api, dict[str, ProviderSpec]], config
|
|
) -> dict[Api, dict[str, ProviderSpec]]:
|
|
logger.warning(
|
|
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
|
|
)
|
|
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
|
if not os.path.exists(external_providers_dir):
|
|
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
|
logger.info(f"Loading external providers from {external_providers_dir}")
|
|
|
|
for api in providable_apis():
|
|
api_name = api.name.lower()
|
|
|
|
# Process both remote and inline providers
|
|
for provider_type in ["remote", "inline"]:
|
|
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
|
|
if not os.path.exists(api_dir):
|
|
logger.debug(f"No {provider_type} provider directory found for {api_name}")
|
|
continue
|
|
|
|
# Look for provider spec files in the API directory
|
|
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
|
|
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
|
|
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
|
|
|
|
try:
|
|
with open(spec_path) as f:
|
|
spec_data = yaml.safe_load(f)
|
|
|
|
if provider_type == "remote":
|
|
spec = _load_remote_provider_spec(spec_data, api)
|
|
provider_type_key = f"remote::{provider_name}"
|
|
else:
|
|
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
|
provider_type_key = f"inline::{provider_name}"
|
|
|
|
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
|
if provider_type_key in registry[api]:
|
|
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
|
registry[api][provider_type_key] = spec
|
|
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
|
except yaml.YAMLError as yaml_err:
|
|
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
|
raise yaml_err
|
|
except Exception as e:
|
|
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
|
raise e
|
|
|
|
return registry
|
|
|
|
|
|
def get_external_providers_from_module(
|
|
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
|
|
) -> dict[Api, dict[str, ProviderSpec]]:
|
|
provider_list = None
|
|
if isinstance(config, BuildConfig):
|
|
provider_list = config.distribution_spec.providers.items()
|
|
else:
|
|
provider_list = config.providers.items()
|
|
if provider_list is None:
|
|
logger.warning("Could not get list of providers from config")
|
|
return registry
|
|
for provider_api, providers in provider_list:
|
|
for provider in providers:
|
|
if not hasattr(provider, "module") or provider.module is None:
|
|
continue
|
|
# get provider using module
|
|
try:
|
|
if not building:
|
|
package_name = provider.module.split("==")[0]
|
|
module = importlib.import_module(f"{package_name}.provider")
|
|
# if config class is wrong you will get an error saying module could not be imported
|
|
spec = module.get_provider_spec()
|
|
else:
|
|
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
|
spec = ProviderSpec(
|
|
api=Api(provider_api),
|
|
provider_type=provider.provider_type,
|
|
is_external=True,
|
|
module=provider.module,
|
|
config_class="",
|
|
)
|
|
provider_type = provider.provider_type
|
|
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
|
# return a partially filled out spec that the build script will populate.
|
|
registry[Api(provider_api)][provider_type] = spec
|
|
except ModuleNotFoundError as exc:
|
|
raise ValueError(
|
|
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
|
) from exc
|
|
except Exception as e:
|
|
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
|
|
raise e
|
|
return registry
|