mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
# What does this PR do?
Today, external providers are installed via the `external_providers_dir`
in the config. This requires users to understand the `ProviderSpec`
and set up their directories accordingly. This process splits up the
config for the stack across multiple files, directories, and formats.
Most (if not all) external providers today have a
[get_provider_spec](559cb18fbb/src/ramalama_stack/provider.py#L9)
method that sits unused. Utilizing this method rather than the
providers.d route allows for a much easier installation process for
external providers and limits the amount of extra configuration a
regular user has to do to get their stack off the ground.
To accomplish this and wire it throughout the build process, introduce
the concept of a `module` that users can specify for an external provider
at build time. To facilitate this, align the build and run
specs to use the `Provider` class rather than the stringified provider_type
that build currently uses.
For example, say this is in your build config:
```
- provider_id: ramalama
provider_type: remote::ramalama
module: ramalama_stack
```
During build (in the various `build_...` scripts), in addition to
installing any pip dependencies, we will also install this module and use
the `get_provider_spec` method to retrieve the ProviderSpec that is
currently specified using `providers.d`.
In production so far, providing instructions for installing external
providers for users has been difficult: they need to install the module
as a pre-req, create the providers.d directory, copy in the provider
spec, and also copy in the necessary build/run yaml files. Accessing an
external provider should be as easy as possible, and pointing to its
installable module aligns more with the rest of our build and dependency
management process.
For now, `external_providers_dir` still exists as an alternate, more
declarative method of using external providers.
## Test Plan
Added an integration test installing an external provider from a module,
and more unit test coverage for `get_provider_registry`.
(The warning in yellow is expected: the module is installed inside of
the build env, not where we are running the command.)
<img width="1119" height="400" alt="Screenshot 2025-07-24 at 11 30
48 AM"
src="https://github.com/user-attachments/assets/1efbaf45-b9e8-451a-bd63-264ed664706d"
/>
<img width="1154" height="618" alt="Screenshot 2025-07-24 at 11 31
14 AM"
src="https://github.com/user-attachments/assets/feb2b3ea-c5dd-418e-9662-9a3bd5dd6bdc"
/>
---------
Signed-off-by: Charlie Doern <cdoern@redhat.com>
180 lines
6.5 KiB
Python
180 lines
6.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
import logging
|
|
import textwrap
|
|
from typing import Any
|
|
|
|
from llama_stack.distribution.datatypes import (
|
|
LLAMA_STACK_RUN_CONFIG_VERSION,
|
|
DistributionSpec,
|
|
Provider,
|
|
StackRunConfig,
|
|
)
|
|
from llama_stack.distribution.distribution import (
|
|
builtin_automatically_routed_apis,
|
|
get_provider_registry,
|
|
)
|
|
from llama_stack.distribution.stack import cast_image_name_to_string, replace_env_vars
|
|
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
|
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
|
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
|
from llama_stack.providers.datatypes import Api, ProviderSpec
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
    """Interactively (re)configure one provider and return the updated entry.

    Args:
        registry: Provider specs for a single API, keyed by provider type.
        provider: The provider to configure. A non-empty ``config`` is
            instantiated as the provider's config class and handed to
            ``prompt_for_config`` as the existing value.

    Returns:
        A new ``Provider`` with the same ``provider_id`` and ``provider_type``
        whose ``config`` is the ``model_dump()`` of the prompted config.

    Raises:
        KeyError: If ``provider.provider_type`` is not present in ``registry``.
    """
    provider_spec = registry[provider.provider_type]
    config_type = instantiate_class_type(provider_spec.config_class)

    existing = None
    if provider.config:
        try:
            existing = config_type(**provider.config)
        except Exception as exc:
            # Best effort: an unusable existing config must not abort the
            # interactive flow, but the user should know it was discarded.
            # Only the exception type is logged so that config values (which
            # may hold credentials) never end up in the log.
            logger.warning(
                "Ignoring existing config for provider `%s` (%s); prompting from scratch.",
                provider.provider_id,
                type(exc).__name__,
            )

    cfg = prompt_for_config(config_type, existing)
    return Provider(
        provider_id=provider.provider_id,
        provider_type=provider.provider_type,
        config=cfg.model_dump(),
    )
|
|
|
|
|
|
def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
    """Interactively configure the providers for every API that will be served.

    For an API that already has providers in ``config.providers`` each of them
    is re-configured. For any other API only the first provider listed in
    ``build_spec.providers`` is configured; the remaining ones are merely named
    in a log message. ``config`` is updated in place and returned.

    Raises:
        ValueError: If an API is missing from the provider registry, or if the
            build spec lists no provider for an API that has none configured.
    """
    # Nothing configured yet: explain what the upcoming prompts are about.
    if not config.providers:
        intro = textwrap.dedent(
            """
            Llama Stack is composed of several APIs working together. For each API served by the Stack,
            we need to configure the providers (implementations) you want to use for these APIs.
            """
        )
        logger.info(intro)

    registry = get_provider_registry()
    routed_apis = [routed.routing_table_api for routed in builtin_automatically_routed_apis()]

    # Without an explicit API list, serve every API except these three.
    excluded = (Api.telemetry, Api.inspect, Api.providers)
    apis_to_serve = config.apis or [member.value for member in Api if member not in excluded]

    for api_str in apis_to_serve:
        api = Api(api_str)
        if api in routed_apis:
            continue
        if api not in registry:
            raise ValueError(f"Unknown API `{api_str}`")

        api_registry = registry[api]
        configured = []
        current = config.providers.get(api_str, [])

        if current:
            logger.info(f"Re-configuring existing providers for API `{api_str}`...")
            for existing in current:
                logger.info(f"> Configuring provider `({existing.provider_type})`")
                configured.append(configure_single_provider(api_registry, existing))
                logger.info("")
        else:
            # This API has not been configured before; seed it from the build spec.
            candidates = build_spec.providers.get(api_str, [])
            if not isinstance(candidates, list):
                candidates = [candidates]

            if not candidates:
                raise ValueError(f"No provider configured for API {api_str}?")

            logger.info(f"Configuring API `{api_str}`...")
            needs_suffix = len(candidates) > 1
            for idx, candidate in enumerate(candidates):
                if idx >= 1:
                    # Only the first provider is prompted for; list the rest and stop.
                    others = ", ".join(rest.provider_type for rest in candidates[idx:])
                    logger.info(
                        f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
                    )
                    break

                logger.info(f"> Configuring provider `({candidate.provider_type})`")
                if needs_suffix:
                    new_id = f"{candidate.provider_id}-{idx:02d}"
                else:
                    new_id = candidate.provider_id
                blank = Provider(
                    provider_id=new_id,
                    provider_type=candidate.provider_type,
                    config={},
                )
                configured.append(configure_single_provider(api_registry, blank))
                logger.info("")

        config.providers[api_str] = configured

    return config
|
|
|
|
|
|
def upgrade_from_routing_table(
    config_dict: dict[str, Any],
) -> dict[str, Any]:
    """Rewrite a legacy run-config dict into the ``providers`` / ``apis`` layout.

    The legacy ``routing_table`` and ``api_providers`` (or ``provider_map``)
    sections are converted into a single ``providers`` mapping of API name to
    a list of ``Provider`` objects, and ``apis_to_serve`` is renamed to
    ``apis``. The legacy keys are removed.

    Args:
        config_dict: The raw configuration. It is modified in place.

    Returns:
        The same ``config_dict`` object, after the upgrade.
    """

    def get_providers(entries):
        # Ids get a numeric suffix only when an API has several entries.
        multiple = len(entries) > 1
        return [
            Provider(
                provider_id=(f"{entry['provider_type']}-{i:02d}" if multiple else entry["provider_type"]),
                provider_type=entry["provider_type"],
                # Tolerate legacy entries that carry no config section.
                config=entry.get("config", {}),
            )
            for i, entry in enumerate(entries)
        ]

    providers_by_api = {
        api_str: get_providers(entries) for api_str, entries in config_dict.get("routing_table", {}).items()
    }

    # ``api_providers`` wins over ``provider_map`` whenever the key is present.
    provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
    if provider_map:
        for api_str, provider in provider_map.items():
            # Values that are not provider dicts are skipped.
            if isinstance(provider, dict) and "provider_type" in provider:
                providers_by_api[api_str] = [
                    Provider(
                        provider_id=f"{provider['provider_type']}",
                        provider_type=provider["provider_type"],
                        config=provider.get("config", {}),
                    )
                ]

    config_dict["providers"] = providers_by_api

    for legacy_key in ("routing_table", "api_providers", "provider_map"):
        config_dict.pop(legacy_key, None)

    # Rename only when the legacy key exists, so a dict that lacks it (or
    # already uses ``apis``) no longer fails with a bare KeyError.
    if "apis_to_serve" in config_dict:
        config_dict["apis"] = config_dict.pop("apis_to_serve")

    return config_dict
|
|
|
|
|
|
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
    """Turn a raw config dict into a ``StackRunConfig``, upgrading it if needed.

    A dict whose ``version`` already equals ``LLAMA_STACK_RUN_CONFIG_VERSION``
    is used as is. Otherwise a ``routing_table``-style dict is first upgraded,
    the version is stamped, and ``external_providers_dir`` falls back to
    ``EXTERNAL_PROVIDERS_DIR`` when unset or empty. In both cases the dict goes
    through ``replace_env_vars`` and ``cast_image_name_to_string`` before the
    ``StackRunConfig`` is constructed.
    """
    if config_dict.get("version") != LLAMA_STACK_RUN_CONFIG_VERSION:
        if "routing_table" in config_dict:
            logger.info("Upgrading config...")
            config_dict = upgrade_from_routing_table(config_dict)

        config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION

        if not config_dict.get("external_providers_dir"):
            config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR

    expanded = replace_env_vars(config_dict)
    return StackRunConfig(**cast_image_name_to_string(expanded))
|