llama-stack-mirror/scripts/provider_codegen.py
Sébastien Han c9a49a80e8
docs: auto generated documentation for providers (#2543)
# What does this PR do?

Simple approach to get some provider pages in the docs.

Add or update description fields in the provider configuration class
using Pydantic’s Field, ensuring these descriptions are clear and
complete, as they will be used to auto-generate provider documentation
via ./scripts/distro_codegen.py instead of editing the docs manually.

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-06-30 15:13:20 +02:00

332 lines
13 KiB
Python
Executable file

#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import subprocess
import sys
from pathlib import Path
from typing import Any
from rich.progress import Progress, SpinnerColumn, TextColumn
from llama_stack.distribution.distribution import get_provider_registry
# Repository root, taken as two directory levels above this script's own path.
REPO_ROOT = Path(__file__).parent.parent
class ChangedPathTracker:
    """Remember, in first-seen order, every path this script may have touched."""

    def __init__(self):
        # A list (not a set) so that reporting order matches insertion order.
        self._changed_paths = []

    def add_paths(self, *paths):
        """Record each given path (stringified), skipping ones already recorded."""
        for candidate in map(str, paths):
            if candidate in self._changed_paths:
                continue
            self._changed_paths.append(candidate)

    def changed_paths(self):
        """Return the internal list of recorded path strings."""
        return self._changed_paths
def _accepts_extra_config(config_class: Any) -> bool:
    """Return True when the config class allows fields beyond the declared ones."""
    try:
        return config_class.model_json_schema().get("additionalProperties") is True
    except Exception:
        # Schema generation failed (or is unavailable); fall back to inspecting
        # model_config, which may be an object with attributes or a plain dict.
        model_config = getattr(config_class, "model_config", None)
        if getattr(model_config, "extra", None) == "allow":
            return True
        return isinstance(model_config, dict) and model_config.get("extra") == "allow"


def _format_field_type(annotation: Any) -> str:
    """Render a field annotation as a short string for the docs table.

    This is plain textual clean-up of ``str(annotation)``: known wrapper and
    module-prefix substrings are removed in a fixed order. It is lossy (closing
    brackets and parentheses are dropped everywhere).
    """
    text = str(annotation) if annotation else "Any"
    for noise in (
        "typing.",
        "Optional[",
        "]",
        "Annotated[",
        "FieldInfo(",
        ")",
        "llama_stack.apis.inference.inference.",
        "llama_stack.providers.",
    ):
        text = text.replace(noise, "")
    return text


def _field_is_required(field: Any) -> bool:
    """Return whether a field must be supplied by the user.

    ``is_required`` is invoked when it is callable (a method); a plain
    attribute value is accepted as well.
    """
    marker = getattr(field, "is_required", None)
    if callable(marker):
        return bool(marker())
    return bool(marker)


def _documented_default(field: Any) -> Any:
    """Return the default value to show in the docs for a non-required field."""
    if field.default_factory is not None:
        try:
            value = field.default_factory()
        except Exception:
            # The factory could not be evaluated here; document no default.
            return ""
        # HACK ALERT:
        # If the default value contains a path that looks like it came from RUNTIME_BASE_DIR,
        # replace it with a generic ~/.llama/ path for documentation
        if isinstance(value, str) and "/.llama/" in value:
            return f"~/.llama/{value.split('.llama/')[-1]}"
        return value
    if field.default is None:
        return ""
    return field.default


def _extra_config_description(config_class: Any) -> str:
    """Describe the catch-all ``config`` entry of a class that allows extra fields.

    The class source is searched for the first line that mentions
    ``model_config``, ``ConfigDict`` and ``extra="allow"``; the ``#`` comment
    lines directly above it (blank lines are skipped) are joined into the
    description. A generic sentence is used when nothing is found.
    """
    fallback = "Additional configuration options that will be forwarded to the underlying provider"
    try:
        import inspect

        source_lines = inspect.getsource(config_class).split("\n")
    except Exception:
        # Best effort only: the source may simply not be retrievable.
        return fallback

    for index, line in enumerate(source_lines):
        if "model_config" in line and "ConfigDict" in line and 'extra="allow"' in line:
            comments = []
            # Walk upwards from the matching line, collecting the comment block.
            for previous in reversed(source_lines[:index]):
                stripped = previous.strip()
                if stripped.startswith("#"):
                    comments.append(stripped[1:].strip())
                elif stripped:
                    break
            return " ".join(reversed(comments)) if comments else fallback
    return fallback


def get_config_class_info(config_class_path: str) -> dict[str, Any]:
    """Extract configuration information from a config class.

    Args:
        config_class_path: Dotted ``module.ClassName`` path of the config class.

    Returns:
        A dict with the keys ``docstring`` (class docstring or ""), ``fields``
        (field name -> dict with ``type``, ``description``, ``default`` and
        ``required``), ``sample_config`` (the class's ``sample_run_config``
        attribute, or None) and ``accepts_extra_config``. If anything fails
        while loading or inspecting the class, the same keys are returned with
        empty values plus an ``error`` message; no exception is raised.
    """
    try:
        module_path, class_name = config_class_path.rsplit(".", 1)
        module = __import__(module_path, fromlist=[class_name])
        config_class = getattr(module, class_name)

        docstring = config_class.__doc__ or ""
        accepts_extra_config = _accepts_extra_config(config_class)

        fields_info = {}
        if hasattr(config_class, "model_fields"):
            for field_name, field in config_class.model_fields.items():
                # is_required must be *called*: testing the bound method itself
                # is always truthy and marked every field as optional.
                required = _field_is_required(field)
                fields_info[field_name] = {
                    "type": _format_field_type(field.annotation),
                    "description": field.description or "",
                    # A required field has no default to document.
                    "default": "" if required else _documented_default(field),
                    "required": required,
                }

        if accepts_extra_config:
            fields_info["config"] = {
                "type": "dict",
                "description": _extra_config_description(config_class),
                "default": "{}",
                "required": False,
            }

        return {
            "docstring": docstring,
            "fields": fields_info,
            "sample_config": getattr(config_class, "sample_run_config", None),
            "accepts_extra_config": accepts_extra_config,
        }
    except Exception as e:
        return {
            "error": f"Failed to load config class {config_class_path}: {str(e)}",
            "docstring": "",
            "fields": {},
            "sample_config": None,
            "accepts_extra_config": False,
        }
def generate_provider_docs(provider_spec: Any, api_name: str) -> str:
    """Generate markdown documentation for a provider.

    Args:
        provider_spec: Provider specification; ``provider_type`` and
            ``config_class`` are read directly, while ``description``,
            ``adapter.description``, ``deprecation_warning`` and
            ``deprecation_error`` are used when present and non-empty.
        api_name: Name of the API the provider belongs to (currently unused).

    Returns:
        The Markdown page as a single string ending with a newline.
    """

    def escape_cell(text: str) -> str:
        # A raw "|" inside a table cell would be read as a column separator.
        return text.replace("|", "\\|")

    def convert_pydantic_to_dict(obj):
        # Recursively turn model-like objects into plain containers for YAML.
        if hasattr(obj, "model_dump"):
            return obj.model_dump()
        elif hasattr(obj, "dict"):
            return obj.dict()
        elif isinstance(obj, dict):
            return {k: convert_pydantic_to_dict(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [convert_pydantic_to_dict(item) for item in obj]
        else:
            return obj

    provider_type = provider_spec.provider_type
    config_info = get_config_class_info(provider_spec.config_class)

    md_lines = [f"# {provider_type}", ""]

    # Description precedence: the spec itself, then its adapter, then the
    # config class docstring.
    adapter = getattr(provider_spec, "adapter", None)
    description = (
        getattr(provider_spec, "description", None)
        or getattr(adapter, "description", None)
        or config_info.get("docstring")
        or ""
    )
    if description:
        md_lines.extend(["## Description", "", description, ""])

    if config_info.get("fields"):
        md_lines.extend(
            [
                "## Configuration",
                "",
                "| Field | Type | Required | Default | Description |",
                "|-------|------|----------|---------|-------------|",
            ]
        )
        for field_name, field_info in config_info["fields"].items():
            field_type = escape_cell(field_info["type"])
            required = "Yes" if field_info["required"] else "No"
            default = str(field_info["default"]) if field_info["default"] is not None else ""
            # Default and description get the same pipe escaping as the type,
            # otherwise a "|" in either one splits the row into extra columns.
            default = escape_cell(default)
            field_description = escape_cell(field_info["description"] or "")
            md_lines.append(f"| `{field_name}` | `{field_type}` | {required} | {default} | {field_description} |")
        md_lines.append("")

        if config_info.get("accepts_extra_config"):
            md_lines.append(
                "> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider."
            )
            md_lines.append("")

    sample_config_func = config_info.get("sample_config")
    if sample_config_func:
        md_lines.extend(["## Sample Configuration", "", "```yaml"])
        try:
            # Imported inside the try so that any failure here ends up as a
            # comment line in the generated block instead of aborting the run.
            import inspect

            import yaml

            sig = inspect.signature(sample_config_func)
            if "__distro_dir__" in sig.parameters:
                sample_config = sample_config_func(__distro_dir__="~/.llama/dummy")
            else:
                sample_config = sample_config_func()

            sample_config_dict = convert_pydantic_to_dict(sample_config)
            md_lines.append(yaml.dump(sample_config_dict, default_flow_style=False, sort_keys=False))
        except Exception as e:
            md_lines.append(f"# Error generating sample config: {str(e)}")
        md_lines.append("```")
        md_lines.append("")

    if getattr(provider_spec, "deprecation_warning", None):
        md_lines.extend(
            [
                "## Deprecation Notice",
                "",
                f"⚠️ **Warning**: {provider_spec.deprecation_warning}",
                "",
            ]
        )

    if getattr(provider_spec, "deprecation_error", None):
        md_lines.extend(
            [
                "## Deprecation Error",
                "",
                f"❌ **Error**: {provider_spec.deprecation_error}",
            ]
        )

    return "\n".join(md_lines) + "\n"
def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> None:
    """Process the complete provider registry.

    For every API in the registry, writes one Markdown page per provider plus
    an ``index.md`` listing them under ``docs/source/providers/<api>`` in the
    repository, and records every written path (and the directory itself) in
    ``change_tracker``.

    Args:
        progress: Object whose ``print`` method is used for status output.
        change_tracker: Collector for the paths that were (re)generated.

    Raises:
        Exception: Any error raised while processing is reported through
            ``progress`` and then re-raised unchanged.
    """
    progress.print("Processing provider registry")

    try:
        provider_registry = get_provider_registry()
        for api, providers in provider_registry.items():
            api_name = api.value

            doc_output_dir = REPO_ROOT / "docs" / "source" / "providers" / api_name
            doc_output_dir.mkdir(parents=True, exist_ok=True)
            change_tracker.add_paths(doc_output_dir)

            index_content = [
                f"# {api_name.title()} Providers",
                "",
                f"This section contains documentation for all available providers for the **{api_name}** API.",
                "",
            ]

            for provider_type, provider in sorted(providers.items()):
                # "::" and ":" are replaced so the provider type is usable as a file name.
                provider_doc_file = doc_output_dir / f"{provider_type.replace('::', '_').replace(':', '_')}.md"

                provider_docs = generate_provider_docs(provider, api_name)

                # The pages can contain non-ASCII characters, so the encoding
                # is pinned instead of depending on the platform default.
                provider_doc_file.write_text(provider_docs, encoding="utf-8")
                change_tracker.add_paths(provider_doc_file)
                index_content.append(f"- [{provider_type}]({provider_doc_file.name})")

            index_file = doc_output_dir / "index.md"
            index_file.write_text("\n".join(index_content), encoding="utf-8")
            change_tracker.add_paths(index_file)

    except Exception as e:
        progress.print(f"[red]Error processing provider registry: {str(e)}")
        raise
def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
    """Report whether any tracked path differs from what git has recorded.

    Each tracked path is checked twice from the repository root: once with
    ``git diff --exit-code`` (modified content) and once with
    ``git status --porcelain`` (entries flagged ``??``, i.e. untracked).
    Every finding is printed to stderr.

    Returns:
        True if at least one modification or untracked entry was found.
    """
    dirty = False

    for tracked in change_tracker.changed_paths():
        diff_proc = subprocess.run(
            ["git", "diff", "--exit-code", tracked],
            cwd=REPO_ROOT,
            capture_output=True,
        )
        # Any non-zero exit status is treated as "this path changed".
        if diff_proc.returncode:
            print(f"Change detected in '{tracked}'.", file=sys.stderr)
            dirty = True

        status_proc = subprocess.run(
            ["git", "status", "--porcelain", tracked],
            cwd=REPO_ROOT,
            capture_output=True,
            text=True,
        )
        untracked_entries = [entry for entry in status_proc.stdout.splitlines() if entry.startswith("??")]
        # One message per untracked entry, each naming the tracked path.
        for _ in untracked_entries:
            print(f"New file detected: '{tracked}'.", file=sys.stderr)
        if untracked_entries:
            dirty = True

    return dirty
def main():
    """Regenerate the provider docs and exit 1 if that left uncommitted changes.

    The exit status is 0 when no tracked path shows a change.
    """
    tracker = ChangedPathTracker()

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    )
    with Progress(*columns) as progress:
        task_id = progress.add_task("Processing provider registry...", total=1)
        process_provider_registry(progress, tracker)
        progress.update(task_id, advance=1)

    if not check_for_changes(tracker):
        sys.exit(0)

    print(
        "Provider documentation changes detected. Please commit the changes.",
        file=sys.stderr,
    )
    sys.exit(1)


# Only run the generator when this file is executed as a script.
if __name__ == "__main__":
    main()