litellm-mirror/tests/documentation_tests/test_api_docs.py
Krish Dholakia 516c2a6a70
Litellm remove circular imports (#7232)
* fix(utils.py): initial commit to remove circular imports - moves llmproviders to utils.py

* fix(router.py): fix 'litellm.EmbeddingResponse' import from router.py

'

* refactor: fix litellm.ModelResponse import on pass through endpoints

* refactor(litellm_logging.py): fix circular import for custom callbacks literal

* fix(factory.py): fix circular imports inside prompt factory

* fix(cost_calculator.py): fix circular import for 'litellm.Usage'

* fix(proxy_server.py): fix potential circular import with `litellm.Router`

* fix(proxy/utils.py): fix potential circular import in `litellm.Router`

* fix: remove circular imports in 'auth_checks' and 'guardrails/'

* fix(prompt_injection_detection.py): fix router import

* fix(vertex_passthrough_logging_handler.py): fix potential circular imports in vertex pass through

* fix(anthropic_pass_through_logging_handler.py): fix potential circular imports

* fix(slack_alerting.py-+-ollama_chat.py): fix modelresponse import

* fix(base.py): fix potential circular import

* fix(handler.py): fix potential circular reference in codestral + cohere handlers

* fix(azure.py): fix potential circular imports

* fix(gpt_transformation.py): fix modelresponse import

* fix(litellm_logging.py): add logging base class - simplify typing

makes it easy for other files to type check the logging obj without introducing circular imports

* fix(azure_ai/embed): fix potential circular import on handler.py

* fix(databricks/): fix potential circular imports in databricks/

* fix(vertex_ai/): fix potential circular imports on vertex ai embeddings

* fix(vertex_ai/image_gen): fix import

* fix(watsonx-+-bedrock): cleanup imports

* refactor(anthropic-pass-through-+-petals): cleanup imports

* refactor(huggingface/): cleanup imports

* fix(ollama-+-clarifai): cleanup circular imports

* fix(openai_like/): fix import

* fix(openai_like/): fix embedding handler

cleanup imports

* refactor(openai.py): cleanup imports

* fix(sagemaker/transformation.py): fix import

* ci(config.yml): add circular import test to ci/cd
2024-12-14 16:28:34 -08:00

206 lines
6.5 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import ast
from typing import List, Dict, Set, Optional
import os
from dataclasses import dataclass
import argparse
import re
import sys
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
@dataclass
class FunctionInfo:
    """Metadata collected about one endpoint function definition."""

    # Function name as written in the source.
    name: str
    # Docstring as returned by ``ast.get_docstring``; None when absent.
    docstring: Optional[str]
    # (parameter name, annotation base type name) pairs, e.g.
    # ("data", "NewTeamRequest"). Only parameters whose annotation is a plain
    # name or a subscripted name are included.
    parameters: Set[tuple[str, str]]
    # Path of the file the function was found in.
    file_path: str
    # Line number of the definition (``node.lineno``).
    line_number: int
class FastAPIDocVisitor(ast.NodeVisitor):
"""AST visitor to find FastAPI endpoint functions."""
def __init__(self, target_functions: Set[str]):
self.target_functions = target_functions
self.functions: Dict[str, FunctionInfo] = {}
self.current_file = ""
def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
"""Visit function definitions (both async and sync) and collect info if they match target functions."""
if node.name in self.target_functions:
# Extract docstring
docstring = ast.get_docstring(node)
# Extract parameters
parameters = set()
for arg in node.args.args:
if arg.annotation is not None:
# Get the parameter type from annotation
if isinstance(arg.annotation, ast.Name):
parameters.add((arg.arg, arg.annotation.id))
elif isinstance(arg.annotation, ast.Subscript):
if isinstance(arg.annotation.value, ast.Name):
parameters.add((arg.arg, arg.annotation.value.id))
self.functions[node.name] = FunctionInfo(
name=node.name,
docstring=docstring,
parameters=parameters,
file_path=self.current_file,
line_number=node.lineno,
)
# Also need to add this to handle async functions
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
"""Handle async functions by delegating to the regular function visitor."""
return self.visit_FunctionDef(node)
def find_functions_in_file(
    file_path: str, target_functions: Set[str]
) -> Dict[str, FunctionInfo]:
    """Parse a Python file and collect metadata for the requested functions.

    Args:
        file_path: Path of the ``.py`` file to scan.
        target_functions: Names of the functions to look for.

    Returns:
        Mapping of function name to FunctionInfo for each target defined in
        the file. If the file cannot be read, decoded or parsed, the error is
        printed and an empty dict is returned so the directory scan can go on.
        Errors raised while visiting the tree are not swallowed, so a bug
        there cannot silently turn into "no endpoints in this file".
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        # Passing the filename makes SyntaxError messages name the file.
        tree = ast.parse(content, filename=file_path)
    except (OSError, SyntaxError, ValueError) as e:
        # ValueError covers UnicodeDecodeError and source with null bytes.
        print(f"Error parsing {file_path}: {str(e)}")
        return {}

    visitor = FastAPIDocVisitor(target_functions)
    visitor.current_file = file_path
    visitor.visit(tree)
    return visitor.functions
def extract_docstring_params(docstring: Optional[str]) -> Set[str]:
    """Return every identifier that appears in ``name:`` form in a docstring.

    Recognised shapes include ``- name: description``, ``name: description``
    and ``name (type): description``. The pattern is not anchored to the start
    of a line. Any identifier followed by a colon anywhere in the text is
    returned, optionally with whitespace or a parenthesised type in between.
    For example, ``https`` in a URL is returned.

    Args:
        docstring: Docstring text, or None.

    Returns:
        The matched identifiers; an empty set when ``docstring`` is falsy.
    """
    if docstring:
        matcher = re.compile(r"-?\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:\([^)]*\))?\s*:")
        return {hit.group(1) for hit in matcher.finditer(docstring)}
    return set()
def analyze_function(func_info: FunctionInfo) -> Dict:
    """Compare an endpoint's Pydantic input fields against its docstring.

    A parameter is treated as a Pydantic input model when two things hold.
    First, its annotation's base type name ends in ``Request`` or ``Response``.
    Second, an attribute of that name in ``litellm.proxy._types`` exposes
    ``model_fields``.

    Args:
        func_info: Function metadata collected by FastAPIDocVisitor.

    Returns:
        A dict with the following keys:

        - ``function``, ``file_path``, ``line_number``: identify the function.
        - ``has_docstring``: whether a docstring was found.
        - ``pydantic_params``, ``documented_params``, ``missing_params``:
          sorted lists of field names.
        - ``is_valid``: True when no Pydantic field is missing from the
          docstring.
    """
    import importlib

    # Import the submodule explicitly rather than relying on ``import litellm``
    # having loaded ``litellm.proxy._types`` as a side effect.
    proxy_types = importlib.import_module("litellm.proxy._types")

    docstring_params = extract_docstring_params(func_info.docstring)
    print(f"func_info.parameters: {func_info.parameters}")

    pydantic_params: Set[str] = set()
    for _name, type_name in func_info.parameters:
        if not type_name.endswith(("Request", "Response")):
            continue
        model = getattr(proxy_types, type_name, None)
        # Anything without ``model_fields`` (not a Pydantic model) is skipped
        # instead of raising AttributeError.
        fields = getattr(model, "model_fields", None)
        if fields:
            pydantic_params.update(fields.keys())
    print(f"pydantic_params: {pydantic_params}")

    missing_params = pydantic_params - docstring_params
    return {
        "function": func_info.name,
        "file_path": func_info.file_path,
        "line_number": func_info.line_number,
        "has_docstring": bool(func_info.docstring),
        # Sorted so output and failure messages are stable across runs.
        "pydantic_params": sorted(pydantic_params),
        "documented_params": sorted(docstring_params),
        "missing_params": sorted(missing_params),
        "is_valid": len(missing_params) == 0,
    }
def print_validation_results(results: Dict) -> None:
    """Print one function's validation report to stdout.

    A header names the function and its location. It is followed by exactly
    one of four outcomes: no docstring, no Pydantic input models, all fields
    documented, or the sorted list of undocumented fields.

    Args:
        results: A result dict in the shape produced by ``analyze_function``.
    """
    header_lines = (
        f"\nChecking function: {results['function']}",
        f"File: {results['file_path']}:{results['line_number']}",
        "-" * 50,
    )
    for line in header_lines:
        print(line)

    if not results["has_docstring"]:
        print("❌ No docstring found!")
    elif not results["pydantic_params"]:
        print(" No Pydantic input models found.")
    elif results["is_valid"]:
        print("✅ All Pydantic parameters are documented!")
    else:
        print("❌ Missing documentation for parameters:")
        for undocumented in sorted(results["missing_params"]):
            print(f" - {undocumented}")
def main(directory: Optional[str] = None) -> None:
    """Check that selected management endpoints document their Pydantic fields.

    Walks ``directory`` for ``.py`` files and locates each endpoint listed in
    ``function_names`` by AST. For each one, it checks that every field of its
    ``*Request``/``*Response`` Pydantic input model is mentioned in the
    docstring. All failing endpoints are reported before raising.

    Args:
        directory: Root directory to scan. Defaults to
            ``./litellm/proxy/management_endpoints``, resolved against the
            current working directory. For a local run from this file's
            directory, pass ``"../../litellm/proxy/management_endpoints"``.

    Raises:
        FileNotFoundError: If ``directory`` does not exist. Without this check,
            the walk would find nothing and the check would pass vacuously.
        Exception: If any endpoint is missing parameter documentation; the
            message lists the failing endpoints.
    """
    function_names = [
        "new_end_user",
        "end_user_info",
        "update_end_user",
        "delete_end_user",
        "generate_key_fn",
        "info_key_fn",
        "update_key_fn",
        "delete_key_fn",
        "new_user",
        "new_team",
        "team_info",
        "update_team",
        "delete_team",
        "new_organization",
        "update_organization",
        "delete_organization",
        "list_organization",
        "user_update",
    ]
    if directory is None:
        directory = "./litellm/proxy/management_endpoints"
    if not os.path.isdir(directory):
        raise FileNotFoundError(f"Endpoint directory not found: {directory}")

    # Set for O(1) membership tests inside the AST visitor.
    target_functions = set(function_names)
    found_functions: Dict[str, FunctionInfo] = {}
    for root, _, files in os.walk(directory):
        for file_name in files:
            if file_name.endswith(".py"):
                file_path = os.path.join(root, file_name)
                found_functions.update(
                    find_functions_in_file(file_path, target_functions)
                )

    failures: List[str] = []
    for func_name in function_names:
        func_info = found_functions.get(func_name)
        if func_info is None:
            # Surface renamed/removed endpoints instead of silently skipping them.
            print(f"WARNING: function not found, skipping: {func_name}")
            continue
        result = analyze_function(func_info)
        if not result["is_valid"]:
            print_validation_results(result)
            failures.append(func_name)

    if failures:
        raise Exception(
            "Missing parameter documentation for: " + ", ".join(failures)
        )


if __name__ == "__main__":
    main()