diff --git a/scripts/provider_codegen.py b/scripts/provider_codegen.py index d62d626ad..fcab50925 100755 --- a/scripts/provider_codegen.py +++ b/scripts/provider_codegen.py @@ -5,6 +5,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import re import subprocess import sys from pathlib import Path @@ -22,7 +23,9 @@ def get_api_docstring(api_name: str) -> str | None: """Extract docstring from the API protocol class.""" try: # Import the API module dynamically - api_module = __import__(f"llama_stack_api.{api_name}", fromlist=[api_name.title()]) + api_module = __import__( + f"llama_stack_api.{api_name}", fromlist=[api_name.title()] + ) # Get the main protocol class (usually capitalized API name) protocol_class_name = api_name.title() @@ -70,7 +73,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: model_config = config_class.model_config if hasattr(model_config, "extra") and model_config.extra == "allow": accepts_extra_config = True - elif isinstance(model_config, dict) and model_config.get("extra") == "allow": + elif ( + isinstance(model_config, dict) + and model_config.get("extra") == "allow" + ): accepts_extra_config = True fields_info = {} @@ -81,12 +87,25 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: field_type = str(field.annotation) if field.annotation else "Any" # this string replace is ridiculous - field_type = field_type.replace("typing.", "").replace("Optional[", "").replace("]", "") - field_type = field_type.replace("Annotated[", "").replace("FieldInfo(", "").replace(")", "") + field_type = ( + field_type.replace("typing.", "") + .replace("Optional[", "") + .replace("]", "") + ) + field_type = ( + field_type.replace("Annotated[", "") + .replace("FieldInfo(", "") + .replace(")", "") + ) field_type = field_type.replace("llama_stack_api.inference.", "") field_type = field_type.replace("llama_stack.providers.", "") field_type = field_type.replace("llama_stack_api.datatypes.", "") + field_type = re.sub(r"Optional\[([^\]]+)\]", r"\1 | None", field_type) + field_type = re.sub(r"Annotated\[([^,]+),.*?\]", r"\1", field_type) + field_type = re.sub(r"FieldInfo\([^)]*\)", "", field_type) + field_type = re.sub(r"", r"\1", field_type) + default_value = field.default if field.default_factory is not None: try: @@ -94,7 +113,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: # HACK ALERT: # If the default value contains a path that looks like it came from RUNTIME_BASE_DIR, # replace it with a generic ~/.llama/ path for documentation - if isinstance(default_value, str) and "/.llama/" in default_value: + if ( + isinstance(default_value, str) + and "/.llama/" in default_value + ): if ".llama/" in default_value: path_part = default_value.split(".llama/")[-1] default_value = f"~/.llama/{path_part}" @@ -123,7 +145,11 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: lines = source.split("\n") for i, line in enumerate(lines): - if "model_config" in line and "ConfigDict" in line and 'extra="allow"' in line: + if ( + "model_config" in line + and "ConfigDict" in line + and 'extra="allow"' in line + ): comments = [] for j in range(i - 1, -1, -1): stripped = lines[j].strip() @@ -188,7 +214,9 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: # Create sidebar label (clean up provider_type for display) sidebar_label = provider_type.replace("::", " - ").replace("_", " ") if sidebar_label.startswith("inline - "): - sidebar_label = sidebar_label[9:].title() # Remove "inline - " prefix and title case + sidebar_label = sidebar_label[ + 9: + ].title() # Remove "inline - " prefix and title case else: sidebar_label = sidebar_label.title() @@ -231,7 +259,9 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: for field_name, field_info in config_info["fields"].items(): field_type = field_info["type"].replace("|", "\\|") required = "Yes" if field_info["required"] else "No" - default = str(field_info["default"]) if field_info["default"] is not None else "" + default = ( + str(field_info["default"]) if field_info["default"] is not None else "" + ) # Handle multiline default values and escape problematic characters for MDX if "\n" in default: @@ -241,7 +271,9 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: for line in lines: if line.strip(): # Escape angle brackets and wrap template tokens in backticks - escaped_line = line.strip().replace("<", "<").replace(">", ">") + escaped_line = ( + line.strip().replace("<", "<").replace(">", ">") + ) if ("{" in escaped_line and "}" in escaped_line) or ( "<|" in escaped_line and "|>" in escaped_line ): @@ -261,13 +293,19 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: default = f"`{escaped_default}`" else: # Apply additional escaping for curly braces - default = escaped_default.replace("{", "{").replace("}", "}") + default = escaped_default.replace("{", "{").replace( + "}", "}" + ) description_text = field_info["description"] or "" # Escape curly braces in description text for MDX compatibility - description_text = description_text.replace("{", "{").replace("}", "}") + description_text = description_text.replace("{", "{").replace( + "}", "}" + ) - md_lines.append(f"| `{field_name}` | `{field_type}` | {required} | {default} | {description_text} |") + md_lines.append( + f"| `{field_name}` | `{field_type}` | {required} | {default} | {description_text} |" + ) md_lines.append("") @@ -310,7 +348,9 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: sample_config_dict = convert_pydantic_to_dict(sample_config) # Strip trailing newlines from yaml.dump to prevent extra blank lines - yaml_output = yaml.dump(sample_config_dict, default_flow_style=False, sort_keys=False).rstrip() + yaml_output = yaml.dump( + sample_config_dict, default_flow_style=False, sort_keys=False + ).rstrip() md_lines.append(yaml_output) else: md_lines.append("# No sample configuration available.") @@ -318,7 +358,10 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: md_lines.append(f"# Error generating sample config: {str(e)}") md_lines.append("```") - if hasattr(provider_spec, "deprecation_warning") and provider_spec.deprecation_warning: + if ( + hasattr(provider_spec, "deprecation_warning") + and provider_spec.deprecation_warning + ): md_lines.append("## Deprecation Notice") md_lines.append("") md_lines.append(":::warning") @@ -335,7 +378,9 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: return "\n".join(md_lines) + "\n" -def generate_index_docs(api_name: str, api_docstring: str | None, provider_entries: list) -> str: +def generate_index_docs( + api_name: str, api_docstring: str | None, provider_entries: list +) -> str: """Generate MDX documentation for the index file.""" # Create sidebar label for the API sidebar_label = api_name.replace("_", " ").title() @@ -363,7 +408,9 @@ def generate_index_docs(api_name: str, api_docstring: str | None, provider_entri md_lines.append(f"{cleaned_docstring}") md_lines.append("") - md_lines.append(f"This section contains documentation for all available providers for the **{api_name}** API.") + md_lines.append( + f"This section contains documentation for all available providers for the **{api_name}** API." + ) return "\n".join(md_lines) + "\n" @@ -401,10 +448,14 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N else: display_name = display_name.title() - provider_entries.append({"filename": filename, "display_name": display_name}) + provider_entries.append( + {"filename": filename, "display_name": display_name} + ) # Generate index file with frontmatter - index_content = generate_index_docs(api_name, api_docstring, provider_entries) + index_content = generate_index_docs( + api_name, api_docstring, provider_entries + ) index_file = doc_output_dir / "index.mdx" index_file.write_text(index_content) change_tracker.add_paths(index_file)