feat: re-work distro-codegen

each *.py file in the various templates now has to use `Provider`s rather than the stringified provider_types in the DistributionTemplate. Adjust that, regenerate all templates, docs, etc.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-07-06 20:06:27 -04:00
parent dcc6b1eee9
commit 776fabed9e
28 changed files with 809 additions and 328 deletions

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from pathlib import Path
from typing import Literal
from typing import Any, Literal
import jinja2
import rich
@ -35,6 +35,51 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
def filter_empty_values(obj: Any) -> Any:
"""Recursively filter out specific empty values from a dictionary or list.
This function removes:
- Empty strings ('') only when they are the 'module' field
- Empty dictionaries ({}) only when they are the 'config' field
- None values (always excluded)
"""
if obj is None:
return None
if isinstance(obj, dict):
filtered = {}
for key, value in obj.items():
# Special handling for specific fields
if key == "module" and isinstance(value, str) and value == "":
# Skip empty module strings
continue
elif key == "config" and isinstance(value, dict) and not value:
# Skip empty config dictionaries
continue
elif key == "container_image" and not value:
# Skip empty container_image names
continue
else:
# For all other fields, recursively filter but preserve empty values
filtered_value = filter_empty_values(value)
# if filtered_value is not None:
filtered[key] = filtered_value
return filtered
elif isinstance(obj, list):
filtered = []
for item in obj:
filtered_item = filter_empty_values(item)
if filtered_item is not None:
filtered.append(filtered_item)
return filtered
else:
# For all other types (including empty strings and dicts that aren't module/config),
# preserve them as-is
return obj
def get_model_registry(
available_models: dict[str, list[ProviderModelEntry]],
) -> tuple[list[ModelInput], bool]:
@ -138,31 +183,26 @@ class RunConfigSettings(BaseModel):
def run_config(
self,
name: str,
providers: dict[str, list[str]],
providers: dict[str, list[Provider]],
container_image: str | None = None,
) -> dict:
provider_registry = get_provider_registry()
provider_configs = {}
for api_str, provider_types in providers.items():
for api_str, provider_objs in providers.items():
if api_providers := self.provider_overrides.get(api_str):
# Convert Provider objects to dicts for YAML serialization
provider_configs[api_str] = [
p.model_dump(exclude_none=True) if isinstance(p, Provider) else p for p in api_providers
]
provider_configs[api_str] = [p.model_dump(exclude_none=True) for p in api_providers]
continue
provider_configs[api_str] = []
for provider_type in provider_types:
provider_id = provider_type.split("::")[-1]
for provider in provider_objs:
api = Api(api_str)
if provider_type not in provider_registry[api]:
raise ValueError(f"Unknown provider type: {provider_type} for API: {api_str}")
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Unknown provider type: {provider.provider_type} for API: {api_str}")
config_class = provider_registry[api][provider_type].config_class
config_class = provider_registry[api][provider.provider_type].config_class
assert config_class is not None, (
f"No config class for provider type: {provider_type} for API: {api_str}"
f"No config class for provider type: {provider.provider_type} for API: {api_str}"
)
config_class = instantiate_class_type(config_class)
@ -171,14 +211,9 @@ class RunConfigSettings(BaseModel):
else:
config = {}
provider_configs[api_str].append(
Provider(
provider_id=provider_id,
provider_type=provider_type,
config=config,
).model_dump(exclude_none=True)
)
provider.config = config
# Convert Provider object to dict for YAML serialization
provider_configs[api_str].append(provider.model_dump(exclude_none=True))
# Get unique set of APIs from providers
apis = sorted(providers.keys())
@ -222,7 +257,7 @@ class DistributionTemplate(BaseModel):
description: str
distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]
providers: dict[str, list[str]]
providers: dict[str, list[Provider]]
run_configs: dict[str, RunConfigSettings]
template_path: Path | None = None
@ -255,13 +290,28 @@ class DistributionTemplate(BaseModel):
if self.additional_pip_packages:
additional_pip_packages.extend(self.additional_pip_packages)
# Create minimal providers for build config (without runtime configs)
build_providers = {}
for api, providers in self.providers.items():
build_providers[api] = []
for provider in providers:
# Create a minimal provider object with only essential build information
build_provider = Provider(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
config={}, # Empty config for build
module=provider.module,
)
build_providers[api].append(build_provider)
return BuildConfig(
distribution_spec=DistributionSpec(
description=self.description,
container_image=self.container_image,
providers=self.providers,
providers=build_providers,
),
image_type="conda", # default to conda, can be overridden
image_type="conda",
image_name=self.name,
additional_pip_packages=sorted(set(additional_pip_packages)),
)
@ -270,7 +320,7 @@ class DistributionTemplate(BaseModel):
providers_table += "|-----|-------------|\n"
for api, providers in sorted(self.providers.items()):
providers_str = ", ".join(f"`{p}`" for p in providers)
providers_str = ", ".join(f"`{p.provider_type}`" for p in providers)
providers_table += f"| {api} | {providers_str} |\n"
template = self.template_path.read_text()
@ -334,7 +384,7 @@ class DistributionTemplate(BaseModel):
build_config = self.build_config()
with open(yaml_output_dir / "build.yaml", "w") as f:
yaml.safe_dump(
build_config.model_dump(exclude_none=True),
filter_empty_values(build_config.model_dump(exclude_none=True)),
f,
sort_keys=False,
)
@ -343,7 +393,7 @@ class DistributionTemplate(BaseModel):
run_config = settings.run_config(self.name, self.providers, self.container_image)
with open(yaml_output_dir / yaml_pth, "w") as f:
yaml.safe_dump(
{k: v for k, v in run_config.items() if v is not None},
filter_empty_values(run_config),
f,
sort_keys=False,
)