refactor: install external providers from module (#2637)

# What does this PR do?

Today, external providers are installed via the `external_providers_dir`
in the config. This requires users to understand the `ProviderSpec`
and set up their directories accordingly. This process splits up the
config for the stack across multiple files, directories, and formats.

Most (if not all) external providers today have a `get_provider_spec`
method (see `src/ramalama_stack/provider.py`, line 9, at commit 559cb18fbb)
that sits unused. Utilizing this method rather than the
providers.d route allows for a much easier installation process for
external providers and limits the amount of extra configuration a
regular user has to do to get their stack off the ground.

To accomplish this and wire it throughout the build process, introduce
the concept of a `module` that users can specify for an external provider
at build time. In order to facilitate this, align the build and run
specs to use the `Provider` class rather than the stringified provider_type
that build currently uses.

For example, say this is in your build config:

```
- provider_id: ramalama
  provider_type: remote::ramalama
  module: ramalama_stack
```

During build (in the various `build_...` scripts), in addition to
installing any pip dependencies, we will also install this module and use
its `get_provider_spec` method to retrieve the ProviderSpec that is
currently specified using `providers.d`.

In production so far, providing instructions for installing external
providers for users has been difficult: they need to install the module
as a pre-req, create the providers.d directory, copy in the provider
spec, and also copy in the necessary build/run yaml files. Accessing an
external provider should be as easy as possible, and pointing to its
installable module aligns more with the rest of our build and dependency
management process.

For now, `external_providers_dir` still exists as an alternative, more
declarative method of using external providers.

## Test Plan

Added an integration test that installs an external provider from a module,
and added more unit test coverage for `get_provider_registry`.


(The warning in yellow is expected: the module is installed inside of
the build env, not where we are running the command.)
<img width="1119" height="400" alt="Screenshot 2025-07-24 at 11 30 48 AM"
src="https://github.com/user-attachments/assets/1efbaf45-b9e8-451a-bd63-264ed664706d"
/>

<img width="1154" height="618" alt="Screenshot 2025-07-24 at 11 31 14 AM"
src="https://github.com/user-attachments/assets/feb2b3ea-c5dd-418e-9662-9a3bd5dd6bdc"
/>

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-07-25 09:41:26 -04:00 committed by GitHub
parent 85223ccc4d
commit de6919ecdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
44 changed files with 1687 additions and 595 deletions

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from pathlib import Path
from typing import Literal
from typing import Any, Literal
import jinja2
import rich
@ -35,6 +35,51 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
def filter_empty_values(obj: Any) -> Any:
    """Recursively strip selected empty values from nested dicts and lists.

    A dictionary entry is dropped when:

    - its key is ``module`` and its value is an empty string,
    - its key is ``config`` and its value is an empty dict, or
    - its key is ``container_image`` and its value is falsy
      (``None``, ``''``, ...).

    These rules apply at every nesting level. ``None`` items are removed
    from lists. Everything else is preserved, including ``None``, empty
    strings and empty dicts stored under any other dictionary key.

    Args:
        obj: The value to filter; typically the dict produced by dumping a
            build or run config before it is written out as YAML.

    Returns:
        A new, filtered dict or list when ``obj`` is a dict or list;
        otherwise ``obj`` itself, unchanged (``None`` stays ``None``).
    """
    if isinstance(obj, dict):
        kept: dict = {}
        for key, value in obj.items():
            # Only these three keys are pruned when empty; an empty value
            # under any other key is treated as meaningful and kept.
            if key == "module" and isinstance(value, str) and value == "":
                continue
            if key == "config" and isinstance(value, dict) and not value:
                continue
            if key == "container_image" and not value:
                continue
            # NOTE: a None value is deliberately kept here; callers that
            # want None entries removed must drop them before calling.
            kept[key] = filter_empty_values(value)
        return kept

    if isinstance(obj, list):
        # The recursive call yields None only for items that were None to
        # begin with, so this drops exactly the None items.
        cleaned = (filter_empty_values(item) for item in obj)
        return [item for item in cleaned if item is not None]

    # Scalars (and None) pass through untouched.
    return obj
def get_model_registry(
available_models: dict[str, list[ProviderModelEntry]],
) -> tuple[list[ModelInput], bool]:
@ -138,31 +183,26 @@ class RunConfigSettings(BaseModel):
def run_config(
self,
name: str,
providers: dict[str, list[str]],
providers: dict[str, list[Provider]],
container_image: str | None = None,
) -> dict:
provider_registry = get_provider_registry()
provider_configs = {}
for api_str, provider_types in providers.items():
for api_str, provider_objs in providers.items():
if api_providers := self.provider_overrides.get(api_str):
# Convert Provider objects to dicts for YAML serialization
provider_configs[api_str] = [
p.model_dump(exclude_none=True) if isinstance(p, Provider) else p for p in api_providers
]
provider_configs[api_str] = [p.model_dump(exclude_none=True) for p in api_providers]
continue
provider_configs[api_str] = []
for provider_type in provider_types:
provider_id = provider_type.split("::")[-1]
for provider in provider_objs:
api = Api(api_str)
if provider_type not in provider_registry[api]:
raise ValueError(f"Unknown provider type: {provider_type} for API: {api_str}")
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Unknown provider type: {provider.provider_type} for API: {api_str}")
config_class = provider_registry[api][provider_type].config_class
config_class = provider_registry[api][provider.provider_type].config_class
assert config_class is not None, (
f"No config class for provider type: {provider_type} for API: {api_str}"
f"No config class for provider type: {provider.provider_type} for API: {api_str}"
)
config_class = instantiate_class_type(config_class)
@ -171,14 +211,9 @@ class RunConfigSettings(BaseModel):
else:
config = {}
provider_configs[api_str].append(
Provider(
provider_id=provider_id,
provider_type=provider_type,
config=config,
).model_dump(exclude_none=True)
)
provider.config = config
# Convert Provider object to dict for YAML serialization
provider_configs[api_str].append(provider.model_dump(exclude_none=True))
# Get unique set of APIs from providers
apis = sorted(providers.keys())
@ -222,7 +257,7 @@ class DistributionTemplate(BaseModel):
description: str
distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]
providers: dict[str, list[str]]
providers: dict[str, list[Provider]]
run_configs: dict[str, RunConfigSettings]
template_path: Path | None = None
@ -255,13 +290,28 @@ class DistributionTemplate(BaseModel):
if self.additional_pip_packages:
additional_pip_packages.extend(self.additional_pip_packages)
# Create minimal providers for build config (without runtime configs)
build_providers = {}
for api, providers in self.providers.items():
build_providers[api] = []
for provider in providers:
# Create a minimal provider object with only essential build information
build_provider = Provider(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
config={}, # Empty config for build
module=provider.module,
)
build_providers[api].append(build_provider)
return BuildConfig(
distribution_spec=DistributionSpec(
description=self.description,
container_image=self.container_image,
providers=self.providers,
providers=build_providers,
),
image_type="conda", # default to conda, can be overridden
image_type="conda",
image_name=self.name,
additional_pip_packages=sorted(set(additional_pip_packages)),
)
@ -270,7 +320,7 @@ class DistributionTemplate(BaseModel):
providers_table += "|-----|-------------|\n"
for api, providers in sorted(self.providers.items()):
providers_str = ", ".join(f"`{p}`" for p in providers)
providers_str = ", ".join(f"`{p.provider_type}`" for p in providers)
providers_table += f"| {api} | {providers_str} |\n"
template = self.template_path.read_text()
@ -334,7 +384,7 @@ class DistributionTemplate(BaseModel):
build_config = self.build_config()
with open(yaml_output_dir / "build.yaml", "w") as f:
yaml.safe_dump(
build_config.model_dump(exclude_none=True),
filter_empty_values(build_config.model_dump(exclude_none=True)),
f,
sort_keys=False,
)
@ -343,7 +393,7 @@ class DistributionTemplate(BaseModel):
run_config = settings.run_config(self.name, self.providers, self.container_image)
with open(yaml_output_dir / yaml_pth, "w") as f:
yaml.safe_dump(
{k: v for k, v in run_config.items() if v is not None},
filter_empty_values(run_config),
f,
sort_keys=False,
)