mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-24 08:47:26 +00:00
We would like to rename the term `template` to `distribution`. To prepare for that, this is a precursor. cc @leseb
277 lines
11 KiB
Python
277 lines
11 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import glob
|
|
import importlib
|
|
import os
|
|
from typing import Any
|
|
|
|
import yaml
|
|
from pydantic import BaseModel
|
|
|
|
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
|
|
from llama_stack.core.external import load_external_apis
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.providers.datatypes import (
|
|
AdapterSpec,
|
|
Api,
|
|
InlineProviderSpec,
|
|
ProviderSpec,
|
|
remote_provider_spec,
|
|
)
|
|
|
|
logger = get_logger(name=__name__, category="core")
|
|
|
|
|
|
def stack_apis() -> list[Api]:
|
|
return list(Api)
|
|
|
|
|
|
class AutoRoutedApiInfo(BaseModel):
|
|
routing_table_api: Api
|
|
router_api: Api
|
|
|
|
|
|
def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
|
return [
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.models,
|
|
router_api=Api.inference,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.shields,
|
|
router_api=Api.safety,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.vector_dbs,
|
|
router_api=Api.vector_io,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.datasets,
|
|
router_api=Api.datasetio,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.scoring_functions,
|
|
router_api=Api.scoring,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.benchmarks,
|
|
router_api=Api.eval,
|
|
),
|
|
AutoRoutedApiInfo(
|
|
routing_table_api=Api.tool_groups,
|
|
router_api=Api.tool_runtime,
|
|
),
|
|
]
|
|
|
|
|
|
def providable_apis() -> list[Api]:
|
|
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
|
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
|
|
|
|
|
|
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
|
|
adapter = AdapterSpec(**spec_data["adapter"])
|
|
spec = remote_provider_spec(
|
|
api=api,
|
|
adapter=adapter,
|
|
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
|
)
|
|
return spec
|
|
|
|
|
|
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
|
|
spec = InlineProviderSpec(
|
|
api=api,
|
|
provider_type=f"inline::{provider_name}",
|
|
pip_packages=spec_data.get("pip_packages", []),
|
|
module=spec_data["module"],
|
|
config_class=spec_data["config_class"],
|
|
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
|
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
|
|
provider_data_validator=spec_data.get("provider_data_validator"),
|
|
container_image=spec_data.get("container_image"),
|
|
)
|
|
return spec
|
|
|
|
|
|
def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
|
|
"""Get the provider registry, optionally including external providers.
|
|
|
|
This function loads both built-in providers and external providers from YAML files or from their provided modules.
|
|
External providers are loaded from a directory structure like:
|
|
|
|
providers.d/
|
|
remote/
|
|
inference/
|
|
custom_ollama.yaml
|
|
vllm.yaml
|
|
vector_io/
|
|
qdrant.yaml
|
|
safety/
|
|
llama-guard.yaml
|
|
inline/
|
|
inference/
|
|
custom_ollama.yaml
|
|
vllm.yaml
|
|
vector_io/
|
|
qdrant.yaml
|
|
safety/
|
|
llama-guard.yaml
|
|
|
|
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
|
|
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
|
|
There is special handling for all of the potential cases this method can be called from.
|
|
|
|
Args:
|
|
config: Optional object containing the external providers directory path
|
|
building: Optional bool delineating whether or not this is being called from a build process
|
|
|
|
Returns:
|
|
A dictionary mapping APIs to their available providers
|
|
|
|
Raises:
|
|
FileNotFoundError: If the external providers directory doesn't exist
|
|
ValueError: If any provider spec is invalid
|
|
"""
|
|
|
|
registry: dict[Api, dict[str, ProviderSpec]] = {}
|
|
for api in providable_apis():
|
|
name = api.name.lower()
|
|
logger.debug(f"Importing module {name}")
|
|
try:
|
|
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
|
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
|
except ImportError as e:
|
|
logger.warning(f"Failed to import module {name}: {e}")
|
|
|
|
# Refresh providable APIs with external APIs if any
|
|
external_apis = load_external_apis(config)
|
|
for api, api_spec in external_apis.items():
|
|
name = api_spec.name.lower()
|
|
logger.info(f"Importing external API {name} module {api_spec.module}")
|
|
try:
|
|
module = importlib.import_module(api_spec.module)
|
|
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
|
except (ImportError, AttributeError) as e:
|
|
# Populate the registry with an empty dict to avoid breaking the provider registry
|
|
# This assume that the in-tree provider(s) are not available for this API which means
|
|
# that users will need to use external providers for this API.
|
|
registry[api] = {}
|
|
logger.error(
|
|
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
|
|
"Install the API package to load any in-tree providers for this API."
|
|
)
|
|
|
|
# Check if config has external providers
|
|
if config:
|
|
if hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
|
registry = get_external_providers_from_dir(registry, config)
|
|
# else lets check for modules in each provider
|
|
registry = get_external_providers_from_module(
|
|
registry=registry,
|
|
config=config,
|
|
building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
|
|
)
|
|
|
|
return registry
|
|
|
|
|
|
def get_external_providers_from_dir(
|
|
registry: dict[Api, dict[str, ProviderSpec]], config
|
|
) -> dict[Api, dict[str, ProviderSpec]]:
|
|
logger.warning(
|
|
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
|
|
)
|
|
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
|
if not os.path.exists(external_providers_dir):
|
|
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
|
logger.info(f"Loading external providers from {external_providers_dir}")
|
|
|
|
for api in providable_apis():
|
|
api_name = api.name.lower()
|
|
|
|
# Process both remote and inline providers
|
|
for provider_type in ["remote", "inline"]:
|
|
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
|
|
if not os.path.exists(api_dir):
|
|
logger.debug(f"No {provider_type} provider directory found for {api_name}")
|
|
continue
|
|
|
|
# Look for provider spec files in the API directory
|
|
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
|
|
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
|
|
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
|
|
|
|
try:
|
|
with open(spec_path) as f:
|
|
spec_data = yaml.safe_load(f)
|
|
|
|
if provider_type == "remote":
|
|
spec = _load_remote_provider_spec(spec_data, api)
|
|
provider_type_key = f"remote::{provider_name}"
|
|
else:
|
|
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
|
provider_type_key = f"inline::{provider_name}"
|
|
|
|
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
|
if provider_type_key in registry[api]:
|
|
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
|
registry[api][provider_type_key] = spec
|
|
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
|
except yaml.YAMLError as yaml_err:
|
|
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
|
raise yaml_err
|
|
except Exception as e:
|
|
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
|
raise e
|
|
|
|
return registry
|
|
|
|
|
|
def get_external_providers_from_module(
|
|
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
|
|
) -> dict[Api, dict[str, ProviderSpec]]:
|
|
provider_list = None
|
|
if isinstance(config, BuildConfig):
|
|
provider_list = config.distribution_spec.providers.items()
|
|
else:
|
|
provider_list = config.providers.items()
|
|
if provider_list is None:
|
|
logger.warning("Could not get list of providers from config")
|
|
return registry
|
|
for provider_api, providers in provider_list:
|
|
for provider in providers:
|
|
if not hasattr(provider, "module") or provider.module is None:
|
|
continue
|
|
# get provider using module
|
|
try:
|
|
if not building:
|
|
package_name = provider.module.split("==")[0]
|
|
module = importlib.import_module(f"{package_name}.provider")
|
|
# if config class is wrong you will get an error saying module could not be imported
|
|
spec = module.get_provider_spec()
|
|
else:
|
|
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
|
spec = ProviderSpec(
|
|
api=Api(provider_api),
|
|
provider_type=provider.provider_type,
|
|
is_external=True,
|
|
module=provider.module,
|
|
config_class="",
|
|
)
|
|
provider_type = provider.provider_type
|
|
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
|
# return a partially filled out spec that the build script will populate.
|
|
registry[Api(provider_api)][provider_type] = spec
|
|
except ModuleNotFoundError as exc:
|
|
raise ValueError(
|
|
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
|
) from exc
|
|
except Exception as e:
|
|
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
|
|
raise e
|
|
return registry
|