llama-stack-mirror/scripts/provider_codegen.py
2025-07-23 10:26:13 -04:00

333 lines
13 KiB
Python
Executable file

#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import subprocess
import sys
from pathlib import Path
from typing import Any
from rich.progress import Progress, SpinnerColumn, TextColumn
from llama_stack.distribution.distribution import get_provider_registry
# Repository root, resolved as the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).parent.parent
class ChangedPathTracker:
    """Track the paths this script may have created or modified.

    Paths are stored as strings, de-duplicated, and reported in the order in
    which they were first added.
    """

    def __init__(self):
        # A dict doubles as an insertion-ordered set, so duplicate checks are
        # O(1) instead of a linear scan of a list on every add.
        self._changed_paths: dict[str, None] = {}

    def add_paths(self, *paths: object) -> None:
        """Record each of ``paths`` (converted with ``str``), ignoring duplicates."""
        for path in paths:
            self._changed_paths.setdefault(str(path), None)

    def changed_paths(self) -> list[str]:
        """Return the recorded paths, in first-seen order, as a new list."""
        return list(self._changed_paths)
def _accepts_extra_fields(config_class: Any) -> bool:
    """Return True when ``config_class`` allows fields beyond the declared ones."""
    try:
        return config_class.model_json_schema().get("additionalProperties") is True
    except Exception:
        # Schema generation failed; fall back to the model's ``extra`` setting,
        # which may be exposed as an attribute or as a dict entry.
        model_config = getattr(config_class, "model_config", None)
        if getattr(model_config, "extra", None) == "allow":
            return True
        return isinstance(model_config, dict) and model_config.get("extra") == "allow"


def _format_field_type(annotation: Any) -> str:
    """Render a field annotation as a compact string for the docs table."""
    type_text = str(annotation) if annotation else "Any"
    # NOTE: these are plain substring removals applied in order, so every "]"
    # and ")" is dropped, not only the ones closing a removed wrapper.
    for noise in (
        "typing.",
        "Optional[",
        "]",
        "Annotated[",
        "FieldInfo(",
        ")",
        "llama_stack.apis.inference.inference.",
        "llama_stack.providers.",
    ):
        type_text = type_text.replace(noise, "")
    return type_text


def _field_default(field: Any) -> Any:
    """Return the default to document for ``field`` ("" when there is none)."""
    if field.default_factory is not None:
        try:
            value = field.default_factory()
        except Exception:
            # The factory could not be called without arguments; document no default.
            return ""
        # HACK: a string default containing "/.llama/" is rewritten to a generic
        # ~/.llama/ path so machine-specific directories do not end up in the docs.
        if isinstance(value, str) and "/.llama/" in value:
            tail = value.split(".llama/")[-1]
            value = f"~/.llama/{tail}"
        return value
    return "" if field.default is None else field.default


def _is_field_required(field: Any) -> bool:
    """Return whether ``field`` must be supplied by the user.

    ``is_required`` is a method on Pydantic's ``FieldInfo``; testing the bound
    method itself is always truthy, so it has to be called. A plain boolean
    attribute is accepted as well.
    """
    required = getattr(field, "is_required", False)
    return bool(required() if callable(required) else required)


def _describe_field(field: Any) -> dict[str, Any]:
    """Build the documentation record for a single model field."""
    required = _is_field_required(field)
    return {
        "type": _format_field_type(field.annotation),
        "description": field.description or "",
        # A required field has no default; avoid leaking the "undefined" sentinel into the docs.
        "default": "" if required else _field_default(field),
        "required": required,
    }


def _extra_config_description(config_class: Any) -> str:
    """Describe the catch-all ``config`` entry of a class that allows extra fields.

    If the class source has a ``model_config = ConfigDict(... extra="allow" ...)``
    line, the ``#`` comments directly above it (blank lines skipped) are joined
    and used as the description; otherwise a generic sentence is returned.
    """
    import inspect

    description = "Additional configuration options that will be forwarded to the underlying provider"
    try:
        source_lines = inspect.getsource(config_class).split("\n")
    except Exception:
        # Best effort only: source may be unavailable (e.g. dynamically created classes).
        return description

    for index, line in enumerate(source_lines):
        if "model_config" in line and "ConfigDict" in line and 'extra="allow"' in line:
            comments = []
            # Walk upwards collecting the comment block that precedes the line.
            for previous in reversed(source_lines[:index]):
                stripped = previous.strip()
                if stripped.startswith("#"):
                    comments.append(stripped[1:].strip())
                elif stripped:
                    break
            if comments:
                description = " ".join(reversed(comments))
            break
    return description


def get_config_class_info(config_class_path: str) -> dict[str, Any]:
    """Extract configuration information from a config class.

    Args:
        config_class_path: Dotted path ``package.module.ClassName`` of the class.

    Returns:
        A dict with the keys ``docstring`` (class docstring or ""), ``fields``
        (field name -> dict with ``type``, ``description``, ``default`` and
        ``required``), ``sample_config`` (the class's ``sample_run_config``
        attribute or None) and ``accepts_extra_config``. If anything fails while
        loading or inspecting the class, the same keys are returned with empty
        values plus an ``error`` message; no exception is raised.
    """
    import importlib

    try:
        module_path, class_name = config_class_path.rsplit(".", 1)
        config_class = getattr(importlib.import_module(module_path), class_name)

        accepts_extra_config = _accepts_extra_fields(config_class)
        fields_info = {
            field_name: _describe_field(field)
            for field_name, field in getattr(config_class, "model_fields", {}).items()
        }
        if accepts_extra_config:
            # Document the free-form extras as a synthetic "config" field.
            fields_info["config"] = {
                "type": "dict",
                "description": _extra_config_description(config_class),
                "default": "{}",
                "required": False,
            }

        return {
            "docstring": config_class.__doc__ or "",
            "fields": fields_info,
            "sample_config": getattr(config_class, "sample_run_config", None),
            "accepts_extra_config": accepts_extra_config,
        }
    except Exception as e:
        return {
            "error": f"Failed to load config class {config_class_path}: {str(e)}",
            "docstring": "",
            "fields": {},
            "sample_config": None,
            "accepts_extra_config": False,
        }
def _escape_table_cell(text: str) -> str:
    """Escape ``text`` so it stays within a single Markdown table cell."""
    # An unescaped "|" starts a new column and a newline ends the row.
    return text.replace("|", "\\|").replace("\n", " ")


def _to_plain_data(obj: Any) -> Any:
    """Recursively turn objects exposing ``model_dump``/``dict`` into plain data for YAML."""
    if hasattr(obj, "model_dump"):
        return obj.model_dump()
    if hasattr(obj, "dict"):
        return obj.dict()
    if isinstance(obj, dict):
        return {key: _to_plain_data(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_to_plain_data(item) for item in obj]
    return obj


def _render_sample_config(sample_config_func: Any) -> str:
    """Call ``sample_config_func`` and dump its result as block-style YAML."""
    import inspect

    import yaml

    # Pass a placeholder distro directory only if the callable accepts one.
    if "__distro_dir__" in inspect.signature(sample_config_func).parameters:
        sample_config = sample_config_func(__distro_dir__="~/.llama/dummy")
    else:
        sample_config = sample_config_func()
    return yaml.dump(_to_plain_data(sample_config), default_flow_style=False, sort_keys=False)


def generate_provider_docs(provider_spec: Any, api_name: str) -> str:
    """Generate markdown documentation for a provider.

    Args:
        provider_spec: Provider specification; ``provider_type`` and
            ``config_class`` are read, and ``description``,
            ``adapter.description``, ``deprecation_warning`` and
            ``deprecation_error`` are used when present.
        api_name: Name of the API the provider belongs to (currently unused).

    Returns:
        The markdown page as a single string ending with a newline. Sections
        (description, configuration table, sample configuration, deprecation
        notices) are emitted only when there is content for them.
    """
    provider_type = provider_spec.provider_type
    config_info = get_config_class_info(provider_spec.config_class)

    md_lines = ["---\norphan: true\n---\n", f"# {provider_type}", ""]

    # Description precedence: provider spec, then its adapter, then the config class docstring.
    description = (
        getattr(provider_spec, "description", None)
        or getattr(getattr(provider_spec, "adapter", None), "description", None)
        or config_info.get("docstring")
        or ""
    )
    if description:
        md_lines += ["## Description", "", description, ""]

    fields = config_info.get("fields")
    if fields:
        md_lines += [
            "## Configuration",
            "",
            "| Field | Type | Required | Default | Description |",
            "|-------|------|----------|---------|-------------|",
        ]
        for field_name, field_info in fields.items():
            field_type = _escape_table_cell(field_info["type"])
            required = "Yes" if field_info["required"] else "No"
            default = _escape_table_cell(str(field_info["default"]) if field_info["default"] is not None else "")
            field_description = _escape_table_cell(field_info["description"] or "")
            md_lines.append(f"| `{field_name}` | `{field_type}` | {required} | {default} | {field_description} |")
        md_lines.append("")

        if config_info.get("accepts_extra_config"):
            md_lines.append(
                "> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider."
            )
            md_lines.append("")

    sample_config_func = config_info.get("sample_config")
    if sample_config_func:
        md_lines += ["## Sample Configuration", "", "```yaml"]
        try:
            md_lines.append(_render_sample_config(sample_config_func))
        except Exception as e:
            # Keep generating the page; record the failure inside the YAML block instead.
            md_lines.append(f"# Error generating sample config: {str(e)}")
        md_lines += ["```", ""]

    deprecation_warning = getattr(provider_spec, "deprecation_warning", None)
    if deprecation_warning:
        md_lines += ["## Deprecation Notice", "", f"⚠️ **Warning**: {deprecation_warning}", ""]

    deprecation_error = getattr(provider_spec, "deprecation_error", None)
    if deprecation_error:
        md_lines += ["## Deprecation Error", "", f"❌ **Error**: {deprecation_error}"]

    return "\n".join(md_lines) + "\n"
def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> None:
    """Process the complete provider registry.

    For every API in the registry, writes one markdown page per provider plus an
    ``index.md`` under ``docs/source/providers/<api>`` (relative to ``REPO_ROOT``)
    and records the directory and every written file in ``change_tracker``.

    Args:
        progress: Object whose ``print`` method is used for status messages.
        change_tracker: Collector for the paths that were written.

    Raises:
        Exception: Any error is reported through ``progress`` and re-raised.
    """
    progress.print("Processing provider registry")

    try:
        for api, providers in get_provider_registry().items():
            api_name = api.value
            doc_output_dir = REPO_ROOT / "docs" / "source" / "providers" / api_name
            doc_output_dir.mkdir(parents=True, exist_ok=True)
            change_tracker.add_paths(doc_output_dir)

            index_content = [
                f"# {api_name.title()} Providers",
                "",
                f"This section contains documentation for all available providers for the **{api_name}** API.",
                "",
            ]

            for provider_type, provider in sorted(providers.items()):
                # "::" and ":" in the provider type are replaced to form the file name.
                file_stem = provider_type.replace("::", "_").replace(":", "_")
                provider_doc_file = doc_output_dir / f"{file_stem}.md"
                # The generated pages can contain non-ASCII characters, so write
                # UTF-8 explicitly instead of relying on the locale's default encoding.
                provider_doc_file.write_text(generate_provider_docs(provider, api_name), encoding="utf-8")
                change_tracker.add_paths(provider_doc_file)
                index_content.append(f"- [{provider_type}]({provider_doc_file.name})")

            index_file = doc_output_dir / "index.md"
            index_file.write_text("\n".join(index_content), encoding="utf-8")
            change_tracker.add_paths(index_file)

    except Exception as e:
        progress.print(f"[red]Error processing provider registry: {str(e)}")
        raise
def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
    """Check if there are any uncommitted changes, including new files.

    Each tracked path is checked with ``git diff --exit-code`` (any non-zero
    exit status counts as a change) and with ``git status --porcelain`` (lines
    starting with ``??`` are untracked entries). Findings are printed to stderr.

    Args:
        change_tracker: Provides the paths to check via ``changed_paths()``.

    Returns:
        True if at least one modified or untracked entry was found.
    """
    has_changes = False
    # The tracker holds directories as well as the files inside them, so the
    # same untracked entry can show up for several paths; report it only once.
    reported_untracked: set[str] = set()

    for path in change_tracker.changed_paths():
        # "--" ensures the path is never interpreted as a git option or revision.
        diff_result = subprocess.run(
            ["git", "diff", "--exit-code", "--", path],
            cwd=REPO_ROOT,
            capture_output=True,
        )
        if diff_result.returncode != 0:
            print(f"Change detected in '{path}'.", file=sys.stderr)
            has_changes = True

        status_result = subprocess.run(
            ["git", "status", "--porcelain", "--", path],
            cwd=REPO_ROOT,
            capture_output=True,
            text=True,
        )
        for line in status_result.stdout.splitlines():
            if not line.startswith("??"):
                continue
            has_changes = True
            # Porcelain lines are "XY <path>"; name the untracked entry itself
            # rather than the (possibly parent) path that was queried.
            untracked = line[3:]
            if untracked not in reported_untracked:
                reported_untracked.add(untracked)
                print(f"New file detected: '{untracked}'.", file=sys.stderr)

    return has_changes
def main():
    """Regenerate the provider docs and exit non-zero if that changed any file.

    Exits with status 1 (after printing a hint to stderr) when
    ``check_for_changes`` reports changes, and with status 0 otherwise.
    """
    tracker = ChangedPathTracker()

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    )
    with Progress(*columns) as progress:
        task_id = progress.add_task("Processing provider registry...", total=1)
        process_provider_registry(progress, tracker)
        progress.update(task_id, advance=1)

    # Guard clause: nothing changed, so report success straight away.
    if not check_for_changes(tracker):
        sys.exit(0)

    print(
        "Provider documentation changes detected. Please commit the changes.",
        file=sys.stderr,
    )
    sys.exit(1)
# Run the generator only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()