mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-04 13:15:24 +00:00
# What does this PR do? Simple approach to get some provider pages in the docs. Add or update description fields in the provider configuration class using Pydantic’s Field, ensuring these descriptions are clear and complete, as they will be used to auto-generate provider documentation via ./scripts/distro_codegen.py instead of editing the docs manually. Signed-off-by: Sébastien Han <seb@redhat.com>
332 lines
13 KiB
Python
Executable file
332 lines
13 KiB
Python
Executable file
#!/usr/bin/env python
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
|
|
from llama_stack.distribution.distribution import get_provider_registry
|
|
|
|
# Repository root; assumes this script lives one directory below it
# (e.g. <repo>/scripts/provider_codegen.py) — TODO confirm location.
REPO_ROOT = Path(__file__).parent.parent
|
|
|
|
|
|
class ChangedPathTracker:
    """Remember every filesystem path this run may have touched.

    Paths are stored as strings, deduplicated, in first-seen order, so the
    caller can later ask git whether any of them actually changed.
    """

    def __init__(self):
        # Ordered, duplicate-free record of touched paths (as strings).
        self._changed_paths = []

    def add_paths(self, *paths):
        """Record each given path once, coercing to ``str`` first."""
        for candidate in paths:
            candidate = str(candidate)
            if candidate not in self._changed_paths:
                self._changed_paths.append(candidate)

    def changed_paths(self):
        """Return the recorded paths (the live internal list, not a copy)."""
        return self._changed_paths
|
|
|
|
|
|
def get_config_class_info(config_class_path: str) -> dict[str, Any]:
    """Extract documentation-relevant information from a provider config class.

    Imports the class named by the dotted path ``config_class_path`` and pulls
    out its docstring, per-field metadata (type, description, default,
    required-ness), whether it accepts arbitrary extra keys, and its
    ``sample_run_config`` hook if present.

    Returns a dict with keys ``docstring``, ``fields``, ``sample_config`` and
    ``accepts_extra_config``.  On any failure an ``error`` key is added and
    the other values are empty placeholders, so callers never need to handle
    an exception.
    """
    try:
        module_path, class_name = config_class_path.rsplit(".", 1)
        module = __import__(module_path, fromlist=[class_name])
        config_class = getattr(module, class_name)

        docstring = config_class.__doc__ or ""

        # Does the model accept arbitrary extra keys (pydantic extra="allow")?
        accepts_extra_config = False
        try:
            schema = config_class.model_json_schema()
            if schema.get("additionalProperties") is True:
                accepts_extra_config = True
        except Exception:
            # Fall back to inspecting model_config directly; handles both a
            # ConfigDict-like object and a plain dict.
            if hasattr(config_class, "model_config"):
                model_config = config_class.model_config
                if hasattr(model_config, "extra") and model_config.extra == "allow":
                    accepts_extra_config = True
                elif isinstance(model_config, dict) and model_config.get("extra") == "allow":
                    accepts_extra_config = True

        fields_info = {}
        if hasattr(config_class, "model_fields"):
            for field_name, field in config_class.model_fields.items():
                field_type = str(field.annotation) if field.annotation else "Any"
                # Strip noisy typing/module prefixes so the docs table stays readable.
                field_type = field_type.replace("typing.", "").replace("Optional[", "").replace("]", "")
                field_type = field_type.replace("Annotated[", "").replace("FieldInfo(", "").replace(")", "")
                field_type = field_type.replace("llama_stack.apis.inference.inference.", "")
                field_type = field_type.replace("llama_stack.providers.", "")

                default_value = field.default
                if field.default_factory is not None:
                    try:
                        default_value = field.default_factory()
                        # HACK ALERT:
                        # If the default value contains a path that looks like it
                        # came from RUNTIME_BASE_DIR, replace it with a generic
                        # ~/.llama/ path for documentation.
                        # (The old code tested "/.llama/" and then ".llama/"
                        # again — one check suffices.)
                        if isinstance(default_value, str) and ".llama/" in default_value:
                            path_part = default_value.split(".llama/")[-1]
                            default_value = f"~/.llama/{path_part}"
                    except Exception:
                        default_value = ""
                elif field.default is None:
                    default_value = ""

                fields_info[field_name] = {
                    "type": field_type,
                    "description": field.description or "",
                    "default": default_value,
                    # BUG FIX: FieldInfo.is_required is a *method* in pydantic
                    # v2; the previous `not field.is_required` tested the bound
                    # method object (always truthy), so "required" was always
                    # False.  Call it to get the real answer.
                    "required": field.is_required(),
                }

        if accepts_extra_config:
            # Default wording, unless we can recover a better description from
            # comments written directly above the model_config line.
            config_description = "Additional configuration options that will be forwarded to the underlying provider"
            try:
                import inspect

                source = inspect.getsource(config_class)
                lines = source.split("\n")

                for i, line in enumerate(lines):
                    if "model_config" in line and "ConfigDict" in line and 'extra="allow"' in line:
                        # Collect the contiguous run of comment lines just above.
                        comments = []
                        for j in range(i - 1, -1, -1):
                            stripped = lines[j].strip()
                            if stripped.startswith("#"):
                                comments.append(stripped[1:].strip())
                            elif stripped == "":
                                continue
                            else:
                                break

                        if comments:
                            # Comments were gathered bottom-up; restore order.
                            config_description = " ".join(reversed(comments))
                        break
            except Exception:
                pass

            fields_info["config"] = {
                "type": "dict",
                "description": config_description,
                "default": "{}",
                "required": False,
            }

        return {
            "docstring": docstring,
            "fields": fields_info,
            "sample_config": getattr(config_class, "sample_run_config", None),
            "accepts_extra_config": accepts_extra_config,
        }
    except Exception as e:
        return {
            "error": f"Failed to load config class {config_class_path}: {str(e)}",
            "docstring": "",
            "fields": {},
            "sample_config": None,
            "accepts_extra_config": False,
        }
|
|
|
|
|
|
def generate_provider_docs(provider_spec: Any, api_name: str) -> str:
    """Generate markdown documentation for a provider.

    Renders one markdown page: provider type as the title, then (when
    available) a description, a configuration-field table, a YAML sample
    configuration, and deprecation notices.  ``api_name`` is accepted but not
    used in this function — presumably kept for interface symmetry with the
    caller; TODO confirm before removing.
    """
    provider_type = provider_spec.provider_type
    config_class = provider_spec.config_class

    config_info = get_config_class_info(config_class)

    md_lines = []
    md_lines.append(f"# {provider_type}")
    md_lines.append("")

    # Description precedence: provider spec itself, then its adapter (for
    # adapter-based providers), then the config class docstring.
    description = ""
    if hasattr(provider_spec, "description") and provider_spec.description:
        description = provider_spec.description
    elif (
        hasattr(provider_spec, "adapter")
        and hasattr(provider_spec.adapter, "description")
        and provider_spec.adapter.description
    ):
        description = provider_spec.adapter.description
    elif config_info.get("docstring"):
        description = config_info["docstring"]

    if description:
        md_lines.append("## Description")
        md_lines.append("")
        md_lines.append(description)
        md_lines.append("")

    if config_info.get("fields"):
        md_lines.append("## Configuration")
        md_lines.append("")
        md_lines.append("| Field | Type | Required | Default | Description |")
        md_lines.append("|-------|------|----------|---------|-------------|")

        for field_name, field_info in config_info["fields"].items():
            # Escape pipes so union types like "int | str" don't break the table.
            field_type = field_info["type"].replace("|", "\\|")
            required = "Yes" if field_info["required"] else "No"
            default = str(field_info["default"]) if field_info["default"] is not None else ""
            description = field_info["description"] or ""

            md_lines.append(f"| `{field_name}` | `{field_type}` | {required} | {default} | {description} |")

        md_lines.append("")

    if config_info.get("accepts_extra_config"):
        md_lines.append(
            "> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider."
        )
        md_lines.append("")

    if config_info.get("sample_config"):
        md_lines.append("## Sample Configuration")
        md_lines.append("")
        md_lines.append("```yaml")
        try:
            sample_config_func = config_info["sample_config"]
            import inspect

            import yaml

            if sample_config_func is not None:
                # Some sample_run_config hooks take the distro dir; pass a
                # generic placeholder only when the signature asks for it.
                sig = inspect.signature(sample_config_func)
                if "__distro_dir__" in sig.parameters:
                    sample_config = sample_config_func(__distro_dir__="~/.llama/dummy")
                else:
                    sample_config = sample_config_func()

                def convert_pydantic_to_dict(obj):
                    # Recursively convert pydantic models (v2 model_dump, v1
                    # dict) and containers into plain data for yaml.dump.
                    if hasattr(obj, "model_dump"):
                        return obj.model_dump()
                    elif hasattr(obj, "dict"):
                        return obj.dict()
                    elif isinstance(obj, dict):
                        return {k: convert_pydantic_to_dict(v) for k, v in obj.items()}
                    elif isinstance(obj, list):
                        return [convert_pydantic_to_dict(item) for item in obj]
                    else:
                        return obj

                sample_config_dict = convert_pydantic_to_dict(sample_config)
                md_lines.append(yaml.dump(sample_config_dict, default_flow_style=False, sort_keys=False))
            else:
                md_lines.append("# No sample configuration available.")
        except Exception as e:
            # Keep the page renderable even if the sample hook blows up.
            md_lines.append(f"# Error generating sample config: {str(e)}")
        md_lines.append("```")
        md_lines.append("")

    if hasattr(provider_spec, "deprecation_warning") and provider_spec.deprecation_warning:
        md_lines.append("## Deprecation Notice")
        md_lines.append("")
        md_lines.append(f"⚠️ **Warning**: {provider_spec.deprecation_warning}")
        md_lines.append("")

    if hasattr(provider_spec, "deprecation_error") and provider_spec.deprecation_error:
        md_lines.append("## Deprecation Error")
        md_lines.append("")
        md_lines.append(f"❌ **Error**: {provider_spec.deprecation_error}")

    return "\n".join(md_lines) + "\n"
|
|
|
|
|
|
def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> None:
    """Process the complete provider registry.

    For every API in the registry, writes one markdown page per provider under
    ``docs/source/providers/<api>/`` plus an ``index.md`` linking them, and
    records every written path with ``change_tracker``.

    Raises whatever ``get_provider_registry`` or the filesystem raises, after
    printing the error to the progress console.
    """
    progress.print("Processing provider registry")

    try:
        provider_registry = get_provider_registry()

        for api, providers in provider_registry.items():
            api_name = api.value

            doc_output_dir = REPO_ROOT / "docs" / "source" / "providers" / api_name
            doc_output_dir.mkdir(parents=True, exist_ok=True)
            change_tracker.add_paths(doc_output_dir)

            index_content = []
            index_content.append(f"# {api_name.title()} Providers")
            index_content.append("")
            index_content.append(
                f"This section contains documentation for all available providers for the **{api_name}** API."
            )
            index_content.append("")

            # Sort so page and index ordering is deterministic across runs
            # (the check_for_changes gate depends on stable output).
            for provider_type, provider in sorted(providers.items()):
                # Provider types like "remote::ollama" are not valid filenames.
                provider_doc_file = doc_output_dir / f"{provider_type.replace('::', '_').replace(':', '_')}.md"

                provider_docs = generate_provider_docs(provider, api_name)

                provider_doc_file.write_text(provider_docs)
                change_tracker.add_paths(provider_doc_file)

                index_content.append(f"- [{provider_type}]({provider_doc_file.name})")

            index_file = doc_output_dir / "index.md"
            index_file.write_text("\n".join(index_content))
            change_tracker.add_paths(index_file)

    except Exception as e:
        progress.print(f"[red]Error processing provider registry: {str(e)}")
        # Bare `raise` re-raises the active exception without adding a new
        # frame at this line, unlike the previous `raise e`.
        raise
|
|
|
|
|
|
def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
    """Check if there are any uncommitted changes, including new files.

    For every tracked path, asks git twice: ``git diff --exit-code`` to catch
    modifications to tracked files, and ``git status --porcelain`` to catch
    untracked ("??") entries.  Returns True when anything changed.
    """
    has_changes = False
    for path in change_tracker.changed_paths():
        diff = subprocess.run(
            ["git", "diff", "--exit-code", path],
            cwd=REPO_ROOT,
            capture_output=True,
        )
        if diff.returncode != 0:
            print(f"Change detected in '{path}'.", file=sys.stderr)
            has_changes = True

        status = subprocess.run(
            ["git", "status", "--porcelain", path],
            cwd=REPO_ROOT,
            capture_output=True,
            text=True,
        )
        # One message per untracked entry, mirroring the diff message above.
        for status_line in status.stdout.splitlines():
            if status_line.startswith("??"):
                print(f"New file detected: '{path}'.", file=sys.stderr)
                has_changes = True
    return has_changes
|
|
|
|
|
|
def main():
    """Regenerate provider docs, then exit 1 if git reports any changes."""
    change_tracker = ChangedPathTracker()

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    )
    with Progress(*columns) as progress:
        task = progress.add_task("Processing provider registry...", total=1)
        process_provider_registry(progress, change_tracker)
        progress.update(task, advance=1)

    if check_for_changes(change_tracker):
        print(
            "Provider documentation changes detected. Please commit the changes.",
            file=sys.stderr,
        )
        sys.exit(1)

    sys.exit(0)
|
|
|
|
|
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
|