litellm-mirror/tests/code_coverage_tests/azure_client_usage_test.py
2025-03-18 09:51:28 -07:00

108 lines
3.6 KiB
Python

import ast
import os
import re
def find_azure_files(base_dir):
    """
    Collect the paths of all ``.py`` files located anywhere beneath ``base_dir``.

    The directory tree is traversed with ``os.walk``; each returned path is the
    walked directory joined with the file name. A missing ``base_dir`` yields an
    empty list because ``os.walk`` ignores the error by default.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(base_dir)
        for filename in filenames
        if filename.endswith(".py")
    ]
def check_direct_instantiation(file_path):
"""
Check if a file directly instantiates AzureOpenAI or AsyncAzureOpenAI
outside of the BaseAzureLLM class methods.
"""
with open(file_path, "r") as file:
content = file.read()
# Parse the file
tree = ast.parse(content)
# Track issues found
issues = []
# Find all class definitions
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
class_name = node.name
# Skip BaseAzureLLM class since it's allowed to define the client creation methods
if class_name == "BaseAzureLLM":
continue
# Check method bodies for direct instantiation
for method in node.body:
if isinstance(method, ast.FunctionDef) or isinstance(
method, ast.AsyncFunctionDef
):
method_name = method.name
# Skip methods that are specifically for client creation
if method_name in [
"get_azure_openai_client",
"initialize_azure_sdk_client",
]:
continue
# Look for direct instantiation in the method body
for subnode in ast.walk(method):
if isinstance(subnode, ast.Call):
if hasattr(subnode, "func") and hasattr(subnode.func, "id"):
if subnode.func.id in [
"AzureOpenAI",
"AsyncAzureOpenAI",
]:
issues.append(
f"Direct instantiation of {subnode.func.id} in {class_name}.{method_name}"
)
elif hasattr(subnode, "func") and hasattr(
subnode.func, "attr"
):
if subnode.func.attr in [
"AzureOpenAI",
"AsyncAzureOpenAI",
]:
issues.append(
f"Direct instantiation of {subnode.func.attr} in {class_name}.{method_name}"
)
return issues
def main(base_dir=None):
    """
    Main function to run the test.

    Scans every Python file under the Azure provider package and fails if any
    of them constructs AzureOpenAI/AsyncAzureOpenAI outside the allowed
    factory methods.

    Args:
        base_dir: Directory to scan. Defaults to ``litellm/llms/azure``,
            resolved relative to this script's own location
            (``tests/code_coverage_tests``), so the check works regardless of
            the current working directory.

    Raises:
        Exception: If no Python files are found under ``base_dir`` (the check
            would otherwise pass vacuously), or if any direct instantiation is
            found.
    """
    if base_dir is None:
        # Resolve against this file, not the CWD: a CWD-relative path silently
        # scanned nothing when the script was launched from another directory.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        base_dir = os.path.join(script_dir, "..", "..", "litellm", "llms", "azure")

    azure_files = find_azure_files(base_dir)
    print(f"Found {len(azure_files)} Azure Python files to check")
    if not azure_files:
        # An empty scan means the path is wrong, not that the code is clean.
        raise Exception(
            f"No Azure Python files found under {base_dir!r}; check the scan path."
        )

    all_issues = []
    for file_path in azure_files:
        issues = check_direct_instantiation(file_path)
        if issues:
            all_issues.extend([f"{file_path}: {issue}" for issue in issues])

    if all_issues:
        print("Found direct instantiations of AzureOpenAI or AsyncAzureOpenAI:")
        for issue in all_issues:
            print(f" - {issue}")
        raise Exception(
            f"Found {len(all_issues)} direct instantiations of AzureOpenAI or AsyncAzureOpenAI classes. Use get_azure_openai_client instead."
        )
    else:
        print("All Azure modules are correctly using get_azure_openai_client!")


if __name__ == "__main__":
    main()