mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
lint
This commit is contained in:
parent
a68b44d978
commit
a5940efe4f
1 changed files with 18 additions and 59 deletions
|
@ -10,11 +10,11 @@ import sys
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent
|
||||
|
||||
|
||||
|
@ -22,9 +22,7 @@ def get_api_docstring(api_name: str) -> str | None:
|
|||
"""Extract docstring from the API protocol class."""
|
||||
try:
|
||||
# Import the API module dynamically
|
||||
api_module = __import__(
|
||||
f"llama_stack.apis.{api_name}", fromlist=[api_name.title()]
|
||||
)
|
||||
api_module = __import__(f"llama_stack.apis.{api_name}", fromlist=[api_name.title()])
|
||||
|
||||
# Get the main protocol class (usually capitalized API name)
|
||||
protocol_class_name = api_name.title()
|
||||
|
@ -72,10 +70,7 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
|
|||
model_config = config_class.model_config
|
||||
if hasattr(model_config, "extra") and model_config.extra == "allow":
|
||||
accepts_extra_config = True
|
||||
elif (
|
||||
isinstance(model_config, dict)
|
||||
and model_config.get("extra") == "allow"
|
||||
):
|
||||
elif isinstance(model_config, dict) and model_config.get("extra") == "allow":
|
||||
accepts_extra_config = True
|
||||
|
||||
fields_info = {}
|
||||
|
@ -84,19 +79,9 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
|
|||
field_type = str(field.annotation) if field.annotation else "Any"
|
||||
|
||||
# this string replace is ridiculous
|
||||
field_type = (
|
||||
field_type.replace("typing.", "")
|
||||
.replace("Optional[", "")
|
||||
.replace("]", "")
|
||||
)
|
||||
field_type = (
|
||||
field_type.replace("Annotated[", "")
|
||||
.replace("FieldInfo(", "")
|
||||
.replace(")", "")
|
||||
)
|
||||
field_type = field_type.replace(
|
||||
"llama_stack.apis.inference.inference.", ""
|
||||
)
|
||||
field_type = field_type.replace("typing.", "").replace("Optional[", "").replace("]", "")
|
||||
field_type = field_type.replace("Annotated[", "").replace("FieldInfo(", "").replace(")", "")
|
||||
field_type = field_type.replace("llama_stack.apis.inference.inference.", "")
|
||||
field_type = field_type.replace("llama_stack.providers.", "")
|
||||
|
||||
default_value = field.default
|
||||
|
@ -106,10 +91,7 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
|
|||
# HACK ALERT:
|
||||
# If the default value contains a path that looks like it came from RUNTIME_BASE_DIR,
|
||||
# replace it with a generic ~/.llama/ path for documentation
|
||||
if (
|
||||
isinstance(default_value, str)
|
||||
and "/.llama/" in default_value
|
||||
):
|
||||
if isinstance(default_value, str) and "/.llama/" in default_value:
|
||||
if ".llama/" in default_value:
|
||||
path_part = default_value.split(".llama/")[-1]
|
||||
default_value = f"~/.llama/{path_part}"
|
||||
|
@ -135,11 +117,7 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
|
|||
lines = source.split("\n")
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if (
|
||||
"model_config" in line
|
||||
and "ConfigDict" in line
|
||||
and 'extra="allow"' in line
|
||||
):
|
||||
if "model_config" in line and "ConfigDict" in line and 'extra="allow"' in line:
|
||||
comments = []
|
||||
for j in range(i - 1, -1, -1):
|
||||
stripped = lines[j].strip()
|
||||
|
@ -204,9 +182,7 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
|
|||
# Create sidebar label (clean up provider_type for display)
|
||||
sidebar_label = provider_type.replace("::", " - ").replace("_", " ")
|
||||
if sidebar_label.startswith("inline - "):
|
||||
sidebar_label = sidebar_label[
|
||||
9:
|
||||
].title() # Remove "inline - " prefix and title case
|
||||
sidebar_label = sidebar_label[9:].title() # Remove "inline - " prefix and title case
|
||||
else:
|
||||
sidebar_label = sidebar_label.title()
|
||||
|
||||
|
@ -249,14 +225,10 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
|
|||
for field_name, field_info in config_info["fields"].items():
|
||||
field_type = field_info["type"].replace("|", "\\|")
|
||||
required = "Yes" if field_info["required"] else "No"
|
||||
default = (
|
||||
str(field_info["default"]) if field_info["default"] is not None else ""
|
||||
)
|
||||
default = str(field_info["default"]) if field_info["default"] is not None else ""
|
||||
description_text = field_info["description"] or ""
|
||||
|
||||
md_lines.append(
|
||||
f"| `{field_name}` | `{field_type}` | {required} | {default} | {description_text} |"
|
||||
)
|
||||
md_lines.append(f"| `{field_name}` | `{field_type}` | {required} | {default} | {description_text} |")
|
||||
|
||||
md_lines.append("")
|
||||
|
||||
|
@ -299,9 +271,7 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
|
|||
|
||||
sample_config_dict = convert_pydantic_to_dict(sample_config)
|
||||
# Strip trailing newlines from yaml.dump to prevent extra blank lines
|
||||
yaml_output = yaml.dump(
|
||||
sample_config_dict, default_flow_style=False, sort_keys=False
|
||||
).rstrip()
|
||||
yaml_output = yaml.dump(sample_config_dict, default_flow_style=False, sort_keys=False).rstrip()
|
||||
md_lines.append(yaml_output)
|
||||
else:
|
||||
md_lines.append("# No sample configuration available.")
|
||||
|
@ -309,10 +279,7 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
|
|||
md_lines.append(f"# Error generating sample config: {str(e)}")
|
||||
md_lines.append("```")
|
||||
|
||||
if (
|
||||
hasattr(provider_spec, "deprecation_warning")
|
||||
and provider_spec.deprecation_warning
|
||||
):
|
||||
if hasattr(provider_spec, "deprecation_warning") and provider_spec.deprecation_warning:
|
||||
md_lines.append("## Deprecation Notice")
|
||||
md_lines.append("")
|
||||
md_lines.append(":::warning")
|
||||
|
@ -329,9 +296,7 @@ def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
|
|||
return "\n".join(md_lines) + "\n"
|
||||
|
||||
|
||||
def generate_index_docs(
|
||||
api_name: str, api_docstring: str | None, provider_entries: list
|
||||
) -> str:
|
||||
def generate_index_docs(api_name: str, api_docstring: str | None, provider_entries: list) -> str:
|
||||
"""Generate MDX documentation for the index file."""
|
||||
# Create sidebar label for the API
|
||||
sidebar_label = api_name.replace("_", " ").title()
|
||||
|
@ -359,9 +324,7 @@ def generate_index_docs(
|
|||
md_lines.append(f"{cleaned_docstring}")
|
||||
md_lines.append("")
|
||||
|
||||
md_lines.append(
|
||||
f"This section contains documentation for all available providers for the **{api_name}** API."
|
||||
)
|
||||
md_lines.append(f"This section contains documentation for all available providers for the **{api_name}** API.")
|
||||
md_lines.append("")
|
||||
|
||||
md_lines.append("## Providers")
|
||||
|
@ -409,14 +372,10 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N
|
|||
else:
|
||||
display_name = display_name.title()
|
||||
|
||||
provider_entries.append(
|
||||
{"filename": filename, "display_name": display_name}
|
||||
)
|
||||
provider_entries.append({"filename": filename, "display_name": display_name})
|
||||
|
||||
# Generate index file with frontmatter
|
||||
index_content = generate_index_docs(
|
||||
api_name, api_docstring, provider_entries
|
||||
)
|
||||
index_content = generate_index_docs(api_name, api_docstring, provider_entries)
|
||||
index_file = doc_output_dir / "index.mdx"
|
||||
index_file.write_text(index_content)
|
||||
change_tracker.add_paths(index_file)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue