provider codegen update

This commit is contained in:
Alexey Rybak 2025-09-22 14:18:31 -07:00
parent 1617b83b3d
commit 08f0024797

View file

@ -10,11 +10,11 @@ import sys
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from llama_stack.core.distribution import get_provider_registry
from pydantic_core import PydanticUndefined from pydantic_core import PydanticUndefined
from rich.progress import Progress, SpinnerColumn, TextColumn from rich.progress import Progress, SpinnerColumn, TextColumn
from llama_stack.core.distribution import get_provider_registry
REPO_ROOT = Path(__file__).parent.parent REPO_ROOT = Path(__file__).parent.parent
@ -22,7 +22,9 @@ def get_api_docstring(api_name: str) -> str | None:
"""Extract docstring from the API protocol class.""" """Extract docstring from the API protocol class."""
try: try:
# Import the API module dynamically # Import the API module dynamically
api_module = __import__(f"llama_stack.apis.{api_name}", fromlist=[api_name.title()]) api_module = __import__(
f"llama_stack.apis.{api_name}", fromlist=[api_name.title()]
)
# Get the main protocol class (usually capitalized API name) # Get the main protocol class (usually capitalized API name)
protocol_class_name = api_name.title() protocol_class_name = api_name.title()
@ -70,7 +72,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
model_config = config_class.model_config model_config = config_class.model_config
if hasattr(model_config, "extra") and model_config.extra == "allow": if hasattr(model_config, "extra") and model_config.extra == "allow":
accepts_extra_config = True accepts_extra_config = True
elif isinstance(model_config, dict) and model_config.get("extra") == "allow": elif (
isinstance(model_config, dict)
and model_config.get("extra") == "allow"
):
accepts_extra_config = True accepts_extra_config = True
fields_info = {} fields_info = {}
@ -79,9 +84,19 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
field_type = str(field.annotation) if field.annotation else "Any" field_type = str(field.annotation) if field.annotation else "Any"
# this string replace is ridiculous # this string replace is ridiculous
field_type = field_type.replace("typing.", "").replace("Optional[", "").replace("]", "") field_type = (
field_type = field_type.replace("Annotated[", "").replace("FieldInfo(", "").replace(")", "") field_type.replace("typing.", "")
field_type = field_type.replace("llama_stack.apis.inference.inference.", "") .replace("Optional[", "")
.replace("]", "")
)
field_type = (
field_type.replace("Annotated[", "")
.replace("FieldInfo(", "")
.replace(")", "")
)
field_type = field_type.replace(
"llama_stack.apis.inference.inference.", ""
)
field_type = field_type.replace("llama_stack.providers.", "") field_type = field_type.replace("llama_stack.providers.", "")
default_value = field.default default_value = field.default
@ -91,7 +106,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
# HACK ALERT: # HACK ALERT:
# If the default value contains a path that looks like it came from RUNTIME_BASE_DIR, # If the default value contains a path that looks like it came from RUNTIME_BASE_DIR,
# replace it with a generic ~/.llama/ path for documentation # replace it with a generic ~/.llama/ path for documentation
if isinstance(default_value, str) and "/.llama/" in default_value: if (
isinstance(default_value, str)
and "/.llama/" in default_value
):
if ".llama/" in default_value: if ".llama/" in default_value:
path_part = default_value.split(".llama/")[-1] path_part = default_value.split(".llama/")[-1]
default_value = f"~/.llama/{path_part}" default_value = f"~/.llama/{path_part}"
@ -117,7 +135,11 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
lines = source.split("\n") lines = source.split("\n")
for i, line in enumerate(lines): for i, line in enumerate(lines):
if "model_config" in line and "ConfigDict" in line and 'extra="allow"' in line: if (
"model_config" in line
and "ConfigDict" in line
and 'extra="allow"' in line
):
comments = [] comments = []
for j in range(i - 1, -1, -1): for j in range(i - 1, -1, -1):
stripped = lines[j].strip() stripped = lines[j].strip()
@ -158,7 +180,7 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str: def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
"""Generate markdown documentation for a provider.""" """Generate MDX documentation for a provider."""
provider_type = provider_spec.provider_type provider_type = provider_spec.provider_type
config_class = provider_spec.config_class config_class = provider_spec.config_class
@ -166,10 +188,7 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
if "error" in config_info: if "error" in config_info:
progress.print(config_info["error"]) progress.print(config_info["error"])
md_lines = [] # Extract description for frontmatter
md_lines.append(f"# {provider_type}")
md_lines.append("")
description = "" description = ""
if hasattr(provider_spec, "description") and provider_spec.description: if hasattr(provider_spec, "description") and provider_spec.description:
description = provider_spec.description description = provider_spec.description
@ -182,6 +201,38 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
elif config_info.get("docstring"): elif config_info.get("docstring"):
description = config_info["docstring"] description = config_info["docstring"]
# Create sidebar label (clean up provider_type for display)
sidebar_label = provider_type.replace("::", " - ").replace("_", " ")
if sidebar_label.startswith("inline - "):
sidebar_label = sidebar_label[
9:
].title() # Remove "inline - " prefix and title case
else:
sidebar_label = sidebar_label.title()
md_lines = []
# Add YAML frontmatter
md_lines.append("---")
if description:
# Handle multi-line descriptions in YAML - keep it simple for single line
if "\n" in description.strip():
md_lines.append("description: |")
for line in description.strip().split("\n"):
md_lines.append(f" {line}")
else:
# For single line descriptions, format properly for YAML
clean_desc = description.strip().replace('"', '\\"')
md_lines.append(f'description: "{clean_desc}"')
md_lines.append(f"sidebar_label: {sidebar_label}")
md_lines.append(f"title: {provider_type}")
md_lines.append("---")
md_lines.append("")
# Add main title
md_lines.append(f"# {provider_type}")
md_lines.append("")
if description: if description:
md_lines.append("## Description") md_lines.append("## Description")
md_lines.append("") md_lines.append("")
@ -197,23 +248,27 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
for field_name, field_info in config_info["fields"].items(): for field_name, field_info in config_info["fields"].items():
field_type = field_info["type"].replace("|", "\\|") field_type = field_info["type"].replace("|", "\\|")
required = "Yes" if field_info["required"] else "No" required = "Yes" if field_info["required"] else "No"
default = str(field_info["default"]) if field_info["default"] is not None else "" default = (
str(field_info["default"]) if field_info["default"] is not None else ""
)
description = field_info["description"] or "" description = field_info["description"] or ""
md_lines.append(f"| `{field_name}` | `{field_type}` | {required} | {default} | {description} |") md_lines.append(
f"| `{field_name}` | `{field_type}` | {required} | {default} | {description} |"
)
md_lines.append("") md_lines.append("")
if config_info.get("accepts_extra_config"): if config_info.get("accepts_extra_config"):
md_lines.append( md_lines.append(
"```{note}\n This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.\n ```\n" "\`\`\`{note}\n This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.\n \`\`\`\n"
) )
md_lines.append("") md_lines.append("")
if config_info.get("sample_config"): if config_info.get("sample_config"):
md_lines.append("## Sample Configuration") md_lines.append("## Sample Configuration")
md_lines.append("") md_lines.append("")
md_lines.append("```yaml") md_lines.append("\`\`\`yaml")
try: try:
sample_config_func = config_info["sample_config"] sample_config_func = config_info["sample_config"]
import inspect import inspect
@ -240,18 +295,27 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
return obj return obj
sample_config_dict = convert_pydantic_to_dict(sample_config) sample_config_dict = convert_pydantic_to_dict(sample_config)
md_lines.append(yaml.dump(sample_config_dict, default_flow_style=False, sort_keys=False)) md_lines.append(
yaml.dump(
sample_config_dict, default_flow_style=False, sort_keys=False
)
)
else: else:
md_lines.append("# No sample configuration available.") md_lines.append("# No sample configuration available.")
except Exception as e: except Exception as e:
md_lines.append(f"# Error generating sample config: {str(e)}") md_lines.append(f"# Error generating sample config: {str(e)}")
md_lines.append("```") md_lines.append("\`\`\`")
md_lines.append("") md_lines.append("")
if hasattr(provider_spec, "deprecation_warning") and provider_spec.deprecation_warning: if (
hasattr(provider_spec, "deprecation_warning")
and provider_spec.deprecation_warning
):
md_lines.append("## Deprecation Notice") md_lines.append("## Deprecation Notice")
md_lines.append("") md_lines.append("")
md_lines.append(f"```{{warning}}\n{provider_spec.deprecation_warning}\n```") md_lines.append(
f"\`\`\`{{warning}}\n{provider_spec.deprecation_warning}\n\`\`\`"
)
md_lines.append("") md_lines.append("")
if hasattr(provider_spec, "deprecation_error") and provider_spec.deprecation_error: if hasattr(provider_spec, "deprecation_error") and provider_spec.deprecation_error:
@ -262,6 +326,55 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
return "\n".join(md_lines) + "\n" return "\n".join(md_lines) + "\n"
def generate_index_docs(
api_name: str, api_docstring: str, toctree_entries: list
) -> str:
"""Generate MDX documentation for the index file."""
# Create sidebar label for the API
sidebar_label = api_name.replace("_", " ").title()
md_lines = []
# Add YAML frontmatter for index
md_lines.append("---")
if api_docstring:
clean_desc = api_docstring.strip().replace('"', '\\"')
md_lines.append(f'description: "{clean_desc}"')
md_lines.append(f"sidebar_label: {sidebar_label}")
md_lines.append(f"title: {api_name.title()}")
md_lines.append("---")
md_lines.append("")
# Add main content
md_lines.append(f"# {api_name.title()}")
md_lines.append("")
md_lines.append("## Overview")
md_lines.append("")
if api_docstring:
cleaned_docstring = api_docstring.strip()
md_lines.append(f"{cleaned_docstring}")
md_lines.append("")
md_lines.append(
f"This section contains documentation for all available providers for the **{api_name}** API."
)
md_lines.append("")
md_lines.append("## Providers")
md_lines.append("")
md_lines.append(f"\`\`\`{{toctree}}")
md_lines.append(":maxdepth: 1")
md_lines.append("")
for entry in toctree_entries:
md_lines.append(entry)
md_lines.append("\`\`\`")
md_lines.append("")
return "\n".join(md_lines)
def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> None: def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> None:
"""Process the complete provider registry.""" """Process the complete provider registry."""
progress.print("Processing provider registry") progress.print("Processing provider registry")
@ -272,30 +385,16 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N
for api, providers in provider_registry.items(): for api, providers in provider_registry.items():
api_name = api.value api_name = api.value
doc_output_dir = REPO_ROOT / "docs" / "source" / "providers" / api_name doc_output_dir = REPO_ROOT / "docs" / "docs" / "providers" / api_name
doc_output_dir.mkdir(parents=True, exist_ok=True) doc_output_dir.mkdir(parents=True, exist_ok=True)
change_tracker.add_paths(doc_output_dir) change_tracker.add_paths(doc_output_dir)
index_content = []
index_content.append(f"# {api_name.title()}\n")
index_content.append("## Overview\n")
api_docstring = get_api_docstring(api_name) api_docstring = get_api_docstring(api_name)
if api_docstring:
cleaned_docstring = api_docstring.strip()
index_content.append(f"{cleaned_docstring}\n")
index_content.append(
f"This section contains documentation for all available providers for the **{api_name}** API.\n"
)
index_content.append("## Providers\n")
toctree_entries = [] toctree_entries = []
for provider_type, provider in sorted(providers.items()): for provider_type, provider in sorted(providers.items()):
filename = provider_type.replace("::", "_").replace(":", "_") filename = provider_type.replace("::", "_").replace(":", "_")
provider_doc_file = doc_output_dir / f"{filename}.md" provider_doc_file = doc_output_dir / f"{filename}.mdx"
provider_docs = generate_provider_docs(progress, provider, api_name) provider_docs = generate_provider_docs(progress, provider, api_name)
@ -303,10 +402,12 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N
change_tracker.add_paths(provider_doc_file) change_tracker.add_paths(provider_doc_file)
toctree_entries.append(f"{filename}") toctree_entries.append(f"{filename}")
index_content.append(f"```{{toctree}}\n:maxdepth: 1\n\n{'\n'.join(toctree_entries)}\n```\n") # Generate index file with frontmatter
index_content = generate_index_docs(
index_file = doc_output_dir / "index.md" api_name, api_docstring, toctree_entries
index_file.write_text("\n".join(index_content)) )
index_file = doc_output_dir / "index.mdx"
index_file.write_text(index_content)
change_tracker.add_paths(index_file) change_tracker.add_paths(index_file)
except Exception as e: except Exception as e: