mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
333 lines
13 KiB
Python
Executable file
333 lines
13 KiB
Python
Executable file
#!/usr/bin/env python
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
|
|
from llama_stack.distribution.distribution import get_provider_registry
|
|
|
|
# Parent of the directory that contains this file. Generated provider docs are
# written beneath it, and it is the working directory for the git commands
# this script runs.
REPO_ROOT = Path(__file__).parent.parent
|
|
|
|
|
|
class ChangedPathTracker:
    """Remember, in first-seen order, every path this script may have touched."""

    def __init__(self):
        # A list (not a set) so that the order in which paths were added is kept.
        self._changed_paths = []

    def add_paths(self, *paths):
        """Record each given path as a string, ignoring ones already recorded."""
        for candidate in map(str, paths):
            if candidate in self._changed_paths:
                continue
            self._changed_paths.append(candidate)

    def changed_paths(self):
        """Return the list of recorded path strings (the internal list itself)."""
        return self._changed_paths
|
|
|
|
|
|
def _model_accepts_extra(config_class: Any) -> bool:
    """Report whether the config class tolerates fields beyond the declared ones."""
    try:
        schema = config_class.model_json_schema()
        return schema.get("additionalProperties") is True
    except Exception:
        # Schema generation failed (or is unavailable); fall back to inspecting
        # ``model_config`` directly, which may be an object or a plain dict.
        if not hasattr(config_class, "model_config"):
            return False
        model_config = config_class.model_config
        if hasattr(model_config, "extra") and model_config.extra == "allow":
            return True
        return isinstance(model_config, dict) and model_config.get("extra") == "allow"


def _field_type_label(field: Any) -> str:
    """Turn a field annotation into the compact type string shown in the docs."""
    label = str(field.annotation) if field.annotation else "Any"
    # NOTE: these are plain substring removals, kept exactly as-is so that the
    # generated (and committed) documentation does not change.
    label = label.replace("typing.", "").replace("Optional[", "").replace("]", "")
    label = label.replace("Annotated[", "").replace("FieldInfo(", "").replace(")", "")
    label = label.replace("llama_stack.apis.inference.inference.", "")
    label = label.replace("llama_stack.providers.", "")
    return label


def _field_is_required(field: Any) -> bool:
    """Return True when the field must be supplied by the user.

    ``is_required`` is a method on pydantic v2 field objects, so it has to be
    called; merely referencing it is always truthy. A plain (non-callable)
    attribute is also accepted and interpreted by truthiness.
    """
    marker = getattr(field, "is_required", None)
    if callable(marker):
        return bool(marker())
    return bool(marker)


def _field_default(field: Any, required: bool) -> Any:
    """Compute the default value to display for a field."""
    if required:
        # A required field has no default; do not leak the library's
        # "undefined" sentinel into the rendered table.
        return ""
    if field.default_factory is not None:
        try:
            value = field.default_factory()
        except Exception:
            # Factories that cannot be called without context are shown as empty.
            return ""
        # HACK ALERT:
        # If the default value contains a path that looks like it came from
        # RUNTIME_BASE_DIR, replace it with a generic ~/.llama/ path for
        # documentation.
        if isinstance(value, str) and "/.llama/" in value:
            value = f"~/.llama/{value.split('.llama/')[-1]}"
        return value
    if field.default is None:
        return ""
    return field.default


def _extra_config_description(config_class: Any) -> str:
    """Describe the catch-all ``config`` entry of a class that allows extra fields.

    The comment lines directly above the ``model_config = ConfigDict(extra="allow")``
    line in the class source are used when present; otherwise a generic sentence
    is returned.
    """
    description = "Additional configuration options that will be forwarded to the underlying provider"
    try:
        import inspect

        lines = inspect.getsource(config_class).split("\n")
        for i, line in enumerate(lines):
            if "model_config" in line and "ConfigDict" in line and 'extra="allow"' in line:
                comments = []
                # Walk upwards collecting comment lines, skipping blank lines,
                # until the first line of real code.
                for j in range(i - 1, -1, -1):
                    stripped = lines[j].strip()
                    if stripped.startswith("#"):
                        comments.append(stripped[1:].strip())
                    elif stripped == "":
                        continue
                    else:
                        break
                if comments:
                    description = " ".join(reversed(comments))
                break
    except Exception:
        # Deliberately best-effort: source may be unavailable; keep the generic text.
        pass
    return description


def get_config_class_info(config_class_path: str) -> dict[str, Any]:
    """Extract configuration information from a config class.

    Args:
        config_class_path: Dotted ``package.module.ClassName`` path of the class.

    Returns:
        A dict with the keys ``docstring``, ``fields`` (field name -> dict with
        ``type``, ``description``, ``default`` and ``required``),
        ``sample_config`` (the class's ``sample_run_config`` attribute or
        ``None``) and ``accepts_extra_config``. If anything fails, the same
        keys are returned with empty values plus an ``error`` message; no
        exception is propagated.
    """
    try:
        module_path, class_name = config_class_path.rsplit(".", 1)
        module = __import__(module_path, fromlist=[class_name])
        config_class = getattr(module, class_name)

        docstring = config_class.__doc__ or ""
        accepts_extra_config = _model_accepts_extra(config_class)

        fields_info = {}
        if hasattr(config_class, "model_fields"):
            for field_name, field in config_class.model_fields.items():
                required = _field_is_required(field)
                fields_info[field_name] = {
                    "type": _field_type_label(field),
                    "description": field.description or "",
                    "default": _field_default(field, required),
                    "required": required,
                }

        if accepts_extra_config:
            fields_info["config"] = {
                "type": "dict",
                "description": _extra_config_description(config_class),
                "default": "{}",
                "required": False,
            }

        return {
            "docstring": docstring,
            "fields": fields_info,
            "sample_config": getattr(config_class, "sample_run_config", None),
            "accepts_extra_config": accepts_extra_config,
        }
    except Exception as e:
        return {
            "error": f"Failed to load config class {config_class_path}: {str(e)}",
            "docstring": "",
            "fields": {},
            "sample_config": None,
            "accepts_extra_config": False,
        }
|
|
|
|
|
|
def generate_provider_docs(provider_spec: Any, api_name: str) -> str:
    """Render the Markdown page for a single provider.

    The page consists of an ``orphan`` front-matter block, a title with the
    provider type, and then, each only when there is something to show: a
    description, a configuration table, a note about extra config, a YAML
    sample configuration, and deprecation notices.

    Args:
        provider_spec: Object exposing ``provider_type`` and ``config_class``;
            ``description``, ``adapter``, ``deprecation_warning`` and
            ``deprecation_error`` are read when present.
        api_name: Accepted for the caller's convenience; not used in the body.

    Returns:
        The complete Markdown text, ending with a newline.
    """
    title = provider_spec.provider_type
    info = get_config_class_info(provider_spec.config_class)

    out = ["---\norphan: true\n---\n", f"# {title}", ""]

    # Description precedence: the spec itself, then its adapter, then the
    # config class docstring.
    adapter = getattr(provider_spec, "adapter", None)
    summary = (
        getattr(provider_spec, "description", None)
        or getattr(adapter, "description", None)
        or info.get("docstring")
        or ""
    )
    if summary:
        out.extend(["## Description", "", summary, ""])

    fields = info.get("fields")
    if fields:
        out.extend(
            [
                "## Configuration",
                "",
                "| Field | Type | Required | Default | Description |",
                "|-------|------|----------|---------|-------------|",
            ]
        )
        for name, meta in fields.items():
            # Pipes inside the type would otherwise be read as column separators.
            type_cell = meta["type"].replace("|", "\\|")
            required_cell = "Yes" if meta["required"] else "No"
            default_cell = "" if meta["default"] is None else str(meta["default"])
            text_cell = meta["description"] or ""
            out.append(f"| `{name}` | `{type_cell}` | {required_cell} | {default_cell} | {text_cell} |")
        out.append("")

    if info.get("accepts_extra_config"):
        out.append(
            "> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider."
        )
        out.append("")

    def to_plain(value):
        """Recursively replace model-like objects with plain dicts and lists."""
        if hasattr(value, "model_dump"):
            return value.model_dump()
        if hasattr(value, "dict"):
            return value.dict()
        if isinstance(value, dict):
            return {key: to_plain(item) for key, item in value.items()}
        if isinstance(value, list):
            return [to_plain(item) for item in value]
        return value

    sample_fn = info.get("sample_config")
    if sample_fn:
        out.extend(["## Sample Configuration", "", "```yaml"])
        try:
            # Imported inside the try so that a missing dependency is reported
            # in the page instead of aborting generation.
            import inspect

            import yaml

            if sample_fn is None:
                out.append("# No sample configuration available.")
            else:
                params = inspect.signature(sample_fn).parameters
                if "__distro_dir__" in params:
                    sample = sample_fn(__distro_dir__="~/.llama/dummy")
                else:
                    sample = sample_fn()
                out.append(yaml.dump(to_plain(sample), default_flow_style=False, sort_keys=False))
        except Exception as e:
            out.append(f"# Error generating sample config: {str(e)}")
        out.extend(["```", ""])

    warning = getattr(provider_spec, "deprecation_warning", None)
    if warning:
        out.extend(["## Deprecation Notice", "", f"⚠️ **Warning**: {warning}", ""])

    failure = getattr(provider_spec, "deprecation_error", None)
    if failure:
        out.extend(["## Deprecation Error", "", f"❌ **Error**: {failure}"])

    return "\n".join(out) + "\n"
|
|
|
|
|
|
def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> None:
    """Process the complete provider registry.

    For every API in the registry, one Markdown page is written per provider
    under ``docs/source/providers/<api>/`` plus an ``index.md`` linking to
    them. Every directory and file touched is registered with
    ``change_tracker``.

    Args:
        progress: Object with a ``print`` method used for status and error output.
        change_tracker: Collects the paths that were written.

    Raises:
        Exception: Any failure is reported through ``progress`` and re-raised.
    """
    progress.print("Processing provider registry")

    try:
        provider_registry = get_provider_registry()

        for api, providers in provider_registry.items():
            api_name = api.value

            doc_output_dir = REPO_ROOT / "docs" / "source" / "providers" / api_name
            doc_output_dir.mkdir(parents=True, exist_ok=True)
            change_tracker.add_paths(doc_output_dir)

            index_content = [
                f"# {api_name.title()} Providers",
                "",
                f"This section contains documentation for all available providers for the **{api_name}** API.",
                "",
            ]

            # Sorted so that the index is stable from run to run.
            for provider_type, provider in sorted(providers.items()):
                # "::" and ":" in provider types are not used in file names.
                file_stem = provider_type.replace("::", "_").replace(":", "_")
                provider_doc_file = doc_output_dir / f"{file_stem}.md"

                provider_docs = generate_provider_docs(provider, api_name)

                # Explicit UTF-8: pages may contain non-ASCII characters, and the
                # default encoding of write_text depends on the machine's locale.
                provider_doc_file.write_text(provider_docs, encoding="utf-8")
                change_tracker.add_paths(provider_doc_file)

                index_content.append(f"- [{provider_type}]({provider_doc_file.name})")

            index_file = doc_output_dir / "index.md"
            index_file.write_text("\n".join(index_content), encoding="utf-8")
            change_tracker.add_paths(index_file)

    except Exception as e:
        progress.print(f"[red]Error processing provider registry: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise
|
|
|
|
|
|
def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
    """Check if there are any uncommitted changes, including new files.

    Each tracked path is checked twice with git, run from ``REPO_ROOT``:
    ``git diff --exit-code`` for modifications to tracked files and
    ``git status --porcelain`` for untracked entries. Findings are printed
    to stderr.

    Args:
        change_tracker: Supplies the paths to inspect.

    Returns:
        True if any path has a diff or contains untracked entries.
    """
    has_changes = False
    for path in change_tracker.changed_paths():
        # "--" makes git treat what follows strictly as a path.
        diff_result = subprocess.run(
            ["git", "diff", "--exit-code", "--", path],
            cwd=REPO_ROOT,
            capture_output=True,
        )
        if diff_result.returncode != 0:
            print(f"Change detected in '{path}'.", file=sys.stderr)
            has_changes = True

        status_result = subprocess.run(
            ["git", "status", "--porcelain", "--", path],
            cwd=REPO_ROOT,
            capture_output=True,
            text=True,
        )
        for line in status_result.stdout.splitlines():
            if line.startswith("??"):
                # Porcelain lines are "XY <path>"; report the untracked entry
                # itself rather than the (possibly directory) path being scanned.
                print(f"New file detected: '{line[3:]}'.", file=sys.stderr)
                has_changes = True
    return has_changes
|
|
|
|
|
|
def main():
    """Regenerate the provider docs and exit non-zero if that changed anything.

    Exit status is 1 when ``check_for_changes`` reports differences (after
    printing a reminder to stderr) and 0 otherwise.
    """
    tracker = ChangedPathTracker()

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    )
    with Progress(*columns) as progress:
        task_id = progress.add_task("Processing provider registry...", total=1)
        process_provider_registry(progress, tracker)
        progress.update(task_id, advance=1)

    if not check_for_changes(tracker):
        sys.exit(0)

    print(
        "Provider documentation changes detected. Please commit the changes.",
        file=sys.stderr,
    )
    sys.exit(1)
|
|
|
|
|
|
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|