diff --git a/scripts/provider_codegen.py b/scripts/provider_codegen.py index 17efa2138..2b2944d68 100755 --- a/scripts/provider_codegen.py +++ b/scripts/provider_codegen.py @@ -10,11 +10,11 @@ import sys from pathlib import Path from typing import Any +from llama_stack.core.distribution import get_provider_registry + from pydantic_core import PydanticUndefined from rich.progress import Progress, SpinnerColumn, TextColumn -from llama_stack.core.distribution import get_provider_registry - REPO_ROOT = Path(__file__).parent.parent @@ -22,7 +22,9 @@ def get_api_docstring(api_name: str) -> str | None: """Extract docstring from the API protocol class.""" try: # Import the API module dynamically - api_module = __import__(f"llama_stack.apis.{api_name}", fromlist=[api_name.title()]) + api_module = __import__( + f"llama_stack.apis.{api_name}", fromlist=[api_name.title()] + ) # Get the main protocol class (usually capitalized API name) protocol_class_name = api_name.title() @@ -70,7 +72,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: model_config = config_class.model_config if hasattr(model_config, "extra") and model_config.extra == "allow": accepts_extra_config = True - elif isinstance(model_config, dict) and model_config.get("extra") == "allow": + elif ( + isinstance(model_config, dict) + and model_config.get("extra") == "allow" + ): accepts_extra_config = True fields_info = {} @@ -79,9 +84,19 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: field_type = str(field.annotation) if field.annotation else "Any" # this string replace is ridiculous - field_type = field_type.replace("typing.", "").replace("Optional[", "").replace("]", "") - field_type = field_type.replace("Annotated[", "").replace("FieldInfo(", "").replace(")", "") - field_type = field_type.replace("llama_stack.apis.inference.inference.", "") + field_type = ( + field_type.replace("typing.", "") + .replace("Optional[", "") + .replace("]", "") + ) + field_type = ( + field_type.replace("Annotated[", "") + .replace("FieldInfo(", "") + .replace(")", "") + ) + field_type = field_type.replace( + "llama_stack.apis.inference.inference.", "" + ) field_type = field_type.replace("llama_stack.providers.", "") default_value = field.default @@ -91,7 +106,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: # HACK ALERT: # If the default value contains a path that looks like it came from RUNTIME_BASE_DIR, # replace it with a generic ~/.llama/ path for documentation - if isinstance(default_value, str) and "/.llama/" in default_value: + if ( + isinstance(default_value, str) + and "/.llama/" in default_value + ): if ".llama/" in default_value: path_part = default_value.split(".llama/")[-1] default_value = f"~/.llama/{path_part}" @@ -117,7 +135,11 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: lines = source.split("\n") for i, line in enumerate(lines): - if "model_config" in line and "ConfigDict" in line and 'extra="allow"' in line: + if ( + "model_config" in line + and "ConfigDict" in line + and 'extra="allow"' in line + ): comments = [] for j in range(i - 1, -1, -1): stripped = lines[j].strip() @@ -158,7 +180,7 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: - """Generate markdown documentation for a provider.""" + """Generate MDX documentation for a provider.""" provider_type = provider_spec.provider_type config_class = provider_spec.config_class @@ -166,10 +188,7 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: if "error" in config_info: progress.print(config_info["error"]) - md_lines = [] - md_lines.append(f"# {provider_type}") - md_lines.append("") - + # Extract description for frontmatter description = "" if hasattr(provider_spec, "description") and provider_spec.description: description = provider_spec.description @@ -182,6 +201,38 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: elif config_info.get("docstring"): description = config_info["docstring"] + # Create sidebar label (clean up provider_type for display) + sidebar_label = provider_type.replace("::", " - ").replace("_", " ") + if sidebar_label.startswith("inline - "): + sidebar_label = sidebar_label[ + 9: + ].title() # Remove "inline - " prefix and title case + else: + sidebar_label = sidebar_label.title() + + md_lines = [] + + # Add YAML frontmatter + md_lines.append("---") + if description: + # Handle multi-line descriptions in YAML - keep it simple for single line + if "\n" in description.strip(): + md_lines.append("description: |") + for line in description.strip().split("\n"): + md_lines.append(f" {line}") + else: + # For single line descriptions, format properly for YAML + clean_desc = description.strip().replace('"', '\\"') + md_lines.append(f'description: "{clean_desc}"') + md_lines.append(f"sidebar_label: {sidebar_label}") + md_lines.append(f"title: {provider_type}") + md_lines.append("---") + md_lines.append("") + + # Add main title + md_lines.append(f"# {provider_type}") + md_lines.append("") + if description: md_lines.append("## Description") md_lines.append("") @@ -197,23 +248,27 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: for field_name, field_info in config_info["fields"].items(): field_type = field_info["type"].replace("|", "\\|") required = "Yes" if field_info["required"] else "No" - default = str(field_info["default"]) if field_info["default"] is not None else "" + default = ( + str(field_info["default"]) if field_info["default"] is not None else "" + ) description = field_info["description"] or "" - md_lines.append(f"| `{field_name}` | `{field_type}` | {required} | {default} | {description} |") + md_lines.append( + f"| `{field_name}` | `{field_type}` | {required} | {default} | {description} |" + ) md_lines.append("") if config_info.get("accepts_extra_config"): md_lines.append( - "```{note}\n This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.\n ```\n" + "\`\`\`{note}\n This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.\n \`\`\`\n" ) md_lines.append("") if config_info.get("sample_config"): md_lines.append("## Sample Configuration") md_lines.append("") - md_lines.append("```yaml") + md_lines.append("\`\`\`yaml") try: sample_config_func = config_info["sample_config"] import inspect @@ -240,18 +295,27 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: return obj sample_config_dict = convert_pydantic_to_dict(sample_config) - md_lines.append(yaml.dump(sample_config_dict, default_flow_style=False, sort_keys=False)) + md_lines.append( + yaml.dump( + sample_config_dict, default_flow_style=False, sort_keys=False + ) + ) else: md_lines.append("# No sample configuration available.") except Exception as e: md_lines.append(f"# Error generating sample config: {str(e)}") - md_lines.append("```") + md_lines.append("\`\`\`") md_lines.append("") - if hasattr(provider_spec, "deprecation_warning") and provider_spec.deprecation_warning: + if ( + hasattr(provider_spec, "deprecation_warning") + and provider_spec.deprecation_warning + ): md_lines.append("## Deprecation Notice") md_lines.append("") - md_lines.append(f"```{{warning}}\n{provider_spec.deprecation_warning}\n```") + md_lines.append( + f"\`\`\`{{warning}}\n{provider_spec.deprecation_warning}\n\`\`\`" + ) md_lines.append("") if hasattr(provider_spec, "deprecation_error") and provider_spec.deprecation_error: @@ -262,6 +326,55 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: return "\n".join(md_lines) + "\n" +def generate_index_docs( + api_name: str, api_docstring: str, toctree_entries: list +) -> str: + """Generate MDX documentation for the index file.""" + # Create sidebar label for the API + sidebar_label = api_name.replace("_", " ").title() + + md_lines = [] + + # Add YAML frontmatter for index + md_lines.append("---") + if api_docstring: + clean_desc = api_docstring.strip().replace('"', '\\"') + md_lines.append(f'description: "{clean_desc}"') + md_lines.append(f"sidebar_label: {sidebar_label}") + md_lines.append(f"title: {api_name.title()}") + md_lines.append("---") + md_lines.append("") + + # Add main content + md_lines.append(f"# {api_name.title()}") + md_lines.append("") + md_lines.append("## Overview") + md_lines.append("") + + if api_docstring: + cleaned_docstring = api_docstring.strip() + md_lines.append(f"{cleaned_docstring}") + md_lines.append("") + + md_lines.append( + f"This section contains documentation for all available providers for the **{api_name}** API." + ) + md_lines.append("") + + md_lines.append("## Providers") + md_lines.append("") + + md_lines.append(f"\`\`\`{{toctree}}") + md_lines.append(":maxdepth: 1") + md_lines.append("") + for entry in toctree_entries: + md_lines.append(entry) + md_lines.append("\`\`\`") + md_lines.append("") + + return "\n".join(md_lines) + + def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> None: """Process the complete provider registry.""" progress.print("Processing provider registry") @@ -272,30 +385,16 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N for api, providers in provider_registry.items(): api_name = api.value - doc_output_dir = REPO_ROOT / "docs" / "source" / "providers" / api_name + doc_output_dir = REPO_ROOT / "docs" / "docs" / "providers" / api_name doc_output_dir.mkdir(parents=True, exist_ok=True) change_tracker.add_paths(doc_output_dir) - index_content = [] - index_content.append(f"# {api_name.title()}\n") - index_content.append("## Overview\n") - api_docstring = get_api_docstring(api_name) - if api_docstring: - cleaned_docstring = api_docstring.strip() - index_content.append(f"{cleaned_docstring}\n") - - index_content.append( - f"This section contains documentation for all available providers for the **{api_name}** API.\n" - ) - - index_content.append("## Providers\n") - toctree_entries = [] for provider_type, provider in sorted(providers.items()): filename = provider_type.replace("::", "_").replace(":", "_") - provider_doc_file = doc_output_dir / f"{filename}.md" + provider_doc_file = doc_output_dir / f"{filename}.mdx" provider_docs = generate_provider_docs(progress, provider, api_name) @@ -303,10 +402,12 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N change_tracker.add_paths(provider_doc_file) toctree_entries.append(f"{filename}") - index_content.append(f"```{{toctree}}\n:maxdepth: 1\n\n{'\n'.join(toctree_entries)}\n```\n") - - index_file = doc_output_dir / "index.md" - index_file.write_text("\n".join(index_content)) + # Generate index file with frontmatter + index_content = generate_index_docs( + api_name, api_docstring, toctree_entries + ) + index_file = doc_output_dir / "index.mdx" + index_file.write_text(index_content) change_tracker.add_paths(index_file) except Exception as e: