Use regex for Optional, Annotated, FieldInfo and class cleanup

Signed-off-by: thepetk <thepetk@gmail.com>
This commit is contained in:
thepetk 2025-11-02 19:05:11 +00:00
parent 5ea1be69fe
commit a1ff4984e6

View file

@ -5,6 +5,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import subprocess
import sys
from pathlib import Path
@ -22,7 +23,9 @@ def get_api_docstring(api_name: str) -> str | None:
"""Extract docstring from the API protocol class."""
try:
# Import the API module dynamically
api_module = __import__(f"llama_stack_api.{api_name}", fromlist=[api_name.title()])
api_module = __import__(
f"llama_stack_api.{api_name}", fromlist=[api_name.title()]
)
# Get the main protocol class (usually capitalized API name)
protocol_class_name = api_name.title()
@ -70,7 +73,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
model_config = config_class.model_config
if hasattr(model_config, "extra") and model_config.extra == "allow":
accepts_extra_config = True
elif isinstance(model_config, dict) and model_config.get("extra") == "allow":
elif (
isinstance(model_config, dict)
and model_config.get("extra") == "allow"
):
accepts_extra_config = True
fields_info = {}
@ -81,12 +87,25 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
field_type = str(field.annotation) if field.annotation else "Any"
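# Crude string-based cleanup of the annotation repr so the docs show a short type name.
# NOTE(review): the chained replaces below already delete "Optional[", "Annotated[",
# "FieldInfo(" and every "]" / ")", so the later re.sub calls for those patterns
# appear unable to match - confirm, then keep only one of the two approaches.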
field_type = field_type.replace("typing.", "").replace("Optional[", "").replace("]", "")
field_type = field_type.replace("Annotated[", "").replace("FieldInfo(", "").replace(")", "")
field_type = (
field_type.replace("typing.", "")
.replace("Optional[", "")
.replace("]", "")
)
field_type = (
field_type.replace("Annotated[", "")
.replace("FieldInfo(", "")
.replace(")", "")
)
field_type = field_type.replace("llama_stack_api.inference.", "")
field_type = field_type.replace("llama_stack.providers.", "")
field_type = field_type.replace("llama_stack_api.datatypes.", "")
field_type = re.sub(r"Optional\[([^\]]+)\]", r"\1 | None", field_type)
field_type = re.sub(r"Annotated\[([^,]+),.*?\]", r"\1", field_type)
field_type = re.sub(r"FieldInfo\([^)]*\)", "", field_type)
field_type = re.sub(r"<class '([^']+)'>", r"\1", field_type)
default_value = field.default
if field.default_factory is not None:
try:
@ -94,7 +113,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
# HACK ALERT:
# If the default value contains a path that looks like it came from RUNTIME_BASE_DIR,
# replace it with a generic ~/.llama/ path for documentation
if isinstance(default_value, str) and "/.llama/" in default_value:
if (
isinstance(default_value, str)
and "/.llama/" in default_value
):
if ".llama/" in default_value:
path_part = default_value.split(".llama/")[-1]
default_value = f"~/.llama/{path_part}"
@ -123,7 +145,11 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
lines = source.split("\n")
for i, line in enumerate(lines):
if "model_config" in line and "ConfigDict" in line and 'extra="allow"' in line:
if (
"model_config" in line
and "ConfigDict" in line
and 'extra="allow"' in line
):
comments = []
for j in range(i - 1, -1, -1):
stripped = lines[j].strip()
@ -188,7 +214,9 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
# Create sidebar label (clean up provider_type for display)
sidebar_label = provider_type.replace("::", " - ").replace("_", " ")
if sidebar_label.startswith("inline - "):
sidebar_label = sidebar_label[9:].title() # Remove "inline - " prefix and title case
sidebar_label = sidebar_label[
9:
].title() # Remove "inline - " prefix and title case
else:
sidebar_label = sidebar_label.title()
@ -231,7 +259,9 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
for field_name, field_info in config_info["fields"].items():
field_type = field_info["type"].replace("|", "\\|")
required = "Yes" if field_info["required"] else "No"
default = str(field_info["default"]) if field_info["default"] is not None else ""
default = (
str(field_info["default"]) if field_info["default"] is not None else ""
)
# Handle multiline default values and escape problematic characters for MDX
if "\n" in default:
@ -241,7 +271,9 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
for line in lines:
if line.strip():
# Escape angle brackets and wrap template tokens in backticks
escaped_line = line.strip().replace("<", "&lt;").replace(">", "&gt;")
escaped_line = (
line.strip().replace("<", "&lt;").replace(">", "&gt;")
)
if ("{" in escaped_line and "}" in escaped_line) or (
"&lt;|" in escaped_line and "|&gt;" in escaped_line
):
@ -261,13 +293,19 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
default = f"`{escaped_default}`"
else:
# Apply additional escaping for curly braces
default = escaped_default.replace("{", "&#123;").replace("}", "&#125;")
default = escaped_default.replace("{", "&#123;").replace(
"}", "&#125;"
)
description_text = field_info["description"] or ""
# Escape curly braces in description text for MDX compatibility
description_text = description_text.replace("{", "&#123;").replace("}", "&#125;")
description_text = description_text.replace("{", "&#123;").replace(
"}", "&#125;"
)
md_lines.append(f"| `{field_name}` | `{field_type}` | {required} | {default} | {description_text} |")
md_lines.append(
f"| `{field_name}` | `{field_type}` | {required} | {default} | {description_text} |"
)
md_lines.append("")
@ -310,7 +348,9 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
sample_config_dict = convert_pydantic_to_dict(sample_config)
# Strip trailing newlines from yaml.dump to prevent extra blank lines
yaml_output = yaml.dump(sample_config_dict, default_flow_style=False, sort_keys=False).rstrip()
yaml_output = yaml.dump(
sample_config_dict, default_flow_style=False, sort_keys=False
).rstrip()
md_lines.append(yaml_output)
else:
md_lines.append("# No sample configuration available.")
@ -318,7 +358,10 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
md_lines.append(f"# Error generating sample config: {str(e)}")
md_lines.append("```")
if hasattr(provider_spec, "deprecation_warning") and provider_spec.deprecation_warning:
if (
hasattr(provider_spec, "deprecation_warning")
and provider_spec.deprecation_warning
):
md_lines.append("## Deprecation Notice")
md_lines.append("")
md_lines.append(":::warning")
@ -335,7 +378,9 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
return "\n".join(md_lines) + "\n"
def generate_index_docs(api_name: str, api_docstring: str | None, provider_entries: list) -> str:
def generate_index_docs(
api_name: str, api_docstring: str | None, provider_entries: list
) -> str:
"""Generate MDX documentation for the index file."""
# Create sidebar label for the API
sidebar_label = api_name.replace("_", " ").title()
@ -363,7 +408,9 @@ def generate_index_docs(api_name: str, api_docstring: str | None, provider_entri
md_lines.append(f"{cleaned_docstring}")
md_lines.append("")
md_lines.append(f"This section contains documentation for all available providers for the **{api_name}** API.")
md_lines.append(
f"This section contains documentation for all available providers for the **{api_name}** API."
)
return "\n".join(md_lines) + "\n"
@ -401,10 +448,14 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N
else:
display_name = display_name.title()
provider_entries.append({"filename": filename, "display_name": display_name})
provider_entries.append(
{"filename": filename, "display_name": display_name}
)
# Generate index file with frontmatter
index_content = generate_index_docs(api_name, api_docstring, provider_entries)
index_content = generate_index_docs(
api_name, api_docstring, provider_entries
)
index_file = doc_output_dir / "index.mdx"
index_file.write_text(index_content)
change_tracker.add_paths(index_file)