forked from phoenix/litellm-mirror
* feat(customer_endpoints.py): support passing budget duration via `/customer/new` endpoint Closes https://github.com/BerriAI/litellm/issues/5651 * docs: add missing params to swagger + api documentation test * docs: add documentation for all key endpoints documents all params on swagger * docs(internal_user_endpoints.py): document all /user/new params Ensures all params are documented * docs(team_endpoints.py): add missing documentation for team endpoints Ensures 100% param documentation on swagger * docs(organization_endpoints.py): document all org params Adds documentation for all params in org endpoint * docs(customer_endpoints.py): add coverage for all params on /customer endpoints ensures all /customer/* params are documented * ci(config.yml): add endpoint doc testing to ci/cd * fix: fix internal_user_endpoints.py * fix(internal_user_endpoints.py): support 'duration' param * fix(partner_models/main.py): fix anthropic re-raise exception on vertex * fix: fix pydantic obj
206 lines
6.5 KiB
Python
206 lines
6.5 KiB
Python
import argparse
import ast
import os
import re
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Union

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm


@dataclass
class FunctionInfo:
    """Information collected about one endpoint function.

    Attributes:
        name: Function name.
        docstring: The function's docstring, or None if it has none.
        parameters: ``(parameter_name, type_name)`` pairs for the annotated
            parameters, where ``type_name`` is the outermost name of the
            annotation (e.g. ``"Optional"`` for ``Optional[X]``).
        file_path: Path of the file the function was found in.
        line_number: ``lineno`` of the function's definition node.
    """

    name: str
    docstring: Optional[str]
    # Pairs, not plain strings: analyze_function unpacks (name, type_name).
    parameters: Set[Tuple[str, str]]
    file_path: str
    line_number: int


class FastAPIDocVisitor(ast.NodeVisitor):
    """AST visitor that records FastAPI endpoint functions by name.

    Only functions whose name is in ``target_functions`` are recorded, in
    ``self.functions`` keyed by function name. Methods inside classes are
    reached through the default ``generic_visit``; functions nested inside
    other functions are not, because the function visitors never recurse.
    """

    def __init__(self, target_functions: Set[str]):
        self.target_functions = target_functions
        # name -> collected info; a later match with the same name overwrites.
        self.functions: Dict[str, FunctionInfo] = {}
        # Set by the caller before visiting so records carry their source path.
        self.current_file = ""

    @staticmethod
    def _annotation_type_name(annotation: ast.expr) -> Optional[str]:
        """Return the outermost type name of an annotation node, or None.

        ``Foo`` -> ``"Foo"``; ``mod.Foo`` -> ``"Foo"``; ``Optional[Foo]`` ->
        ``"Optional"`` (the subscripted container, not its argument). Other
        forms, such as string forward references or ``X | Y``, yield None.
        """
        if isinstance(annotation, ast.Subscript):
            annotation = annotation.value
        if isinstance(annotation, ast.Name):
            return annotation.id
        if isinstance(annotation, ast.Attribute):
            return annotation.attr
        return None

    @classmethod
    def _collect_parameters(cls, args: ast.arguments) -> Set[Tuple[str, str]]:
        """Return ``(param_name, type_name)`` pairs for annotated parameters.

        Positional-only, regular and keyword-only parameters are included;
        ``*args``/``**kwargs`` and unannotated parameters are skipped.
        """
        collected: Set[Tuple[str, str]] = set()
        for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs):
            if arg.annotation is None:
                continue
            type_name = cls._annotation_type_name(arg.annotation)
            if type_name is not None:
                collected.add((arg.arg, type_name))
        return collected

    def visit_FunctionDef(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Record ``node`` if its name is one of the target functions.

        Serves both sync and async definitions (``visit_AsyncFunctionDef``
        delegates here).
        """
        if node.name not in self.target_functions:
            return
        self.functions[node.name] = FunctionInfo(
            name=node.name,
            docstring=ast.get_docstring(node),
            parameters=self._collect_parameters(node.args),
            file_path=self.current_file,
            line_number=node.lineno,
        )

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Handle ``async def`` endpoints with the same logic as sync ones."""
        return self.visit_FunctionDef(node)


def find_functions_in_file(
    file_path: str, target_functions: Set[str]
) -> Dict[str, FunctionInfo]:
    """Parse ``file_path`` and return the target functions defined in it.

    Args:
        file_path: Path of the Python source file to inspect.
        target_functions: Function names to look for.

    Returns:
        Mapping of function name -> FunctionInfo for every target found.

        If the file cannot be read, decoded as UTF-8, or parsed, an error is
        printed to stderr and an empty dict is returned, so the remaining
        files can still be scanned. Any other exception (i.e. a bug in this
        script) propagates.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        # filename= makes SyntaxError messages point at the offending file.
        tree = ast.parse(content, filename=file_path)
    except (OSError, SyntaxError, ValueError) as e:
        # OSError: unreadable file. UnicodeDecodeError (a ValueError): not
        # UTF-8. SyntaxError/ValueError: source that ast.parse rejects.
        print(f"Error parsing {file_path}: {str(e)}", file=sys.stderr)
        return {}

    visitor = FastAPIDocVisitor(target_functions)
    visitor.current_file = file_path
    visitor.visit(tree)
    return visitor.functions


def extract_docstring_params(docstring: Optional[str]) -> Set[str]:
    """Collect identifiers in ``docstring`` that are followed by a colon.

    Meant to pick up entries such as ``- name: text``, ``name: text`` or
    ``name (type): text``. Whitespace and a parenthesised part may sit
    between the identifier and the colon. The pattern is not anchored to the
    start of a line, so any ``word:`` in the text (e.g. ``http:``) is
    collected too. A missing or empty docstring yields an empty set.
    """
    if docstring:
        # optional "-" bullet, identifier, optional "(...)" hint, then ":"
        entry = r"-?\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:\([^)]*\))?\s*:"
        return {hit.group(1) for hit in re.finditer(entry, docstring)}
    return set()


def analyze_function(func_info: FunctionInfo) -> Dict:
    """Compare a function's Pydantic request/response fields with its docstring.

    For every parameter whose annotated type name ends in ``Request`` or
    ``Response``, that name is looked up in ``litellm.proxy._types``. If it
    resolves to a Pydantic model (anything exposing ``model_fields``), each of
    its field names must appear among the names ``extract_docstring_params``
    finds in the docstring.

    Returns:
        Dict with keys ``function``, ``file_path``, ``line_number``,
        ``has_docstring``, ``pydantic_params``, ``documented_params`` and
        ``missing_params`` (sorted lists), and ``is_valid`` (True when no
        field is missing).
    """
    # Import the submodule explicitly: importing the ``litellm`` package alone
    # does not guarantee that ``litellm.proxy._types`` has been loaded.
    from litellm.proxy import _types as proxy_types

    docstring_params = extract_docstring_params(func_info.docstring)

    pydantic_params = set()
    for _, type_name in func_info.parameters:
        if not type_name.endswith(("Request", "Response")):
            continue
        model = getattr(proxy_types, type_name, None)
        # A matching name may resolve to a class that is not a Pydantic
        # model (no ``model_fields``); it contributes no fields instead of
        # raising AttributeError.
        model_fields = getattr(model, "model_fields", None)
        if model_fields:
            pydantic_params.update(model_fields.keys())

    missing_params = pydantic_params - docstring_params

    return {
        "function": func_info.name,
        "file_path": func_info.file_path,
        "line_number": func_info.line_number,
        "has_docstring": bool(func_info.docstring),
        # Sorted so reports are deterministic across runs.
        "pydantic_params": sorted(pydantic_params),
        "documented_params": sorted(docstring_params),
        "missing_params": sorted(missing_params),
        "is_valid": len(missing_params) == 0,
    }


def print_validation_results(results: Dict) -> None:
    """Write a human-readable report for one ``analyze_function`` result.

    Output goes to stdout. It consists of a header (function name,
    ``file:line`` location and a separator rule), then exactly one verdict:
    - no docstring,
    - no Pydantic models in the signature,
    - the sorted list of undocumented fields, or
    - all fields documented.
    """
    print(f"\nChecking function: {results['function']}")
    print(f"File: {results['file_path']}:{results['line_number']}")
    print("-" * 50)

    if not results["has_docstring"]:
        print("❌ No docstring found!")
    elif not results["pydantic_params"]:
        print("ℹ️ No Pydantic input models found.")
    elif not results["is_valid"]:
        print("❌ Missing documentation for parameters:")
        for undocumented in sorted(results["missing_params"]):
            print(f" - {undocumented}")
    else:
        print("✅ All Pydantic parameters are documented!")


def _resolve_endpoints_directory(directory: Optional[str]) -> str:
    """Return the folder to scan, raising FileNotFoundError if none exists.

    Without an explicit ``directory``, two working-directory-relative
    candidates are tried in order: the one used when running from this
    script's folder ("LOCAL"), then the one used when running from the
    repository root.
    """
    if directory is not None:
        candidates = [directory]
    else:
        candidates = [
            "../../litellm/proxy/management_endpoints",  # LOCAL
            "./litellm/proxy/management_endpoints",  # from the repository root
        ]
    for candidate in candidates:
        if os.path.isdir(candidate):
            return candidate
    # os.walk on a missing folder yields nothing, which would make the whole
    # check pass vacuously - fail loudly instead.
    raise FileNotFoundError(
        "Management endpoints directory not found; tried: " + ", ".join(candidates)
    )


def _collect_functions(
    directory: str, target_functions: Set[str]
) -> Dict[str, FunctionInfo]:
    """Scan every ``.py`` file under ``directory`` for the target functions."""
    found_functions: Dict[str, FunctionInfo] = {}
    for root, dirs, files in os.walk(directory):
        # Sort in place so traversal order (and which duplicate name wins)
        # is deterministic across filesystems.
        dirs.sort()
        for file in sorted(files):
            if file.endswith(".py"):
                file_path = os.path.join(root, file)
                found_functions.update(
                    find_functions_in_file(file_path, target_functions)
                )
    return found_functions


def main(directory: Optional[str] = None) -> None:
    """Check that the listed management endpoints document all request fields.

    Args:
        directory: Folder to scan recursively for ``.py`` files. When omitted,
            ``../../litellm/proxy/management_endpoints`` is used if it exists,
            otherwise ``./litellm/proxy/management_endpoints``.

    Raises:
        FileNotFoundError: If the folder to scan does not exist.
        Exception: If any endpoint lacks documentation for a Pydantic field.
            Every failing endpoint is reported before raising, and the
            message names them all.
    """
    function_names = [
        "new_end_user",
        "end_user_info",
        "update_end_user",
        "delete_end_user",
        "generate_key_fn",
        "info_key_fn",
        "update_key_fn",
        "delete_key_fn",
        "new_user",
        "new_team",
        "team_info",
        "update_team",
        "delete_team",
        "new_organization",
        "update_organization",
        "delete_organization",
        "list_organization",
        "user_update",
    ]

    directory = _resolve_endpoints_directory(directory)
    found_functions = _collect_functions(directory, set(function_names))

    failed: List[str] = []
    for func_name in function_names:
        if func_name not in found_functions:
            # Not fatal, but surfaced rather than silently skipped so a
            # renamed endpoint does not quietly drop out of the check.
            print(
                f"WARNING: function {func_name} not found under {directory}",
                file=sys.stderr,
            )
            continue
        result = analyze_function(found_functions[func_name])
        if not result["is_valid"]:
            print_validation_results(result)
            failed.append(func_name)

    if failed:
        raise Exception(
            "Missing parameter documentation in: " + ", ".join(failed)
        )


if __name__ == "__main__":
    # Run the documentation check only when this file is executed directly,
    # not when it is imported.
    main()