mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
108 lines
3.6 KiB
Python
108 lines
3.6 KiB
Python
import ast
|
|
import os
|
|
import re
|
|
|
|
|
|
def find_azure_files(base_dir):
    """
    Collect the paths of every ``.py`` file under *base_dir*, recursively.

    Each path is ``os.path.join(<directory reported by os.walk>, <file name>)``,
    in ``os.walk``'s traversal order. Because ``os.walk`` ignores errors by
    default, a missing or unreadable *base_dir* produces an empty list rather
    than an exception.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(base_dir)
        for filename in filenames
        if filename.endswith(".py")
    ]
|
|
|
|
|
|
def check_direct_instantiation(file_path):
    """
    Check if a file directly instantiates AzureOpenAI or AsyncAzureOpenAI
    outside of the BaseAzureLLM class methods.

    Every class in the file except ``BaseAzureLLM`` is inspected. Within each
    class, every sync or async method defined directly in the class body is
    walked (including any functions nested inside it). The client-factory
    methods ``get_azure_openai_client`` and ``initialize_azure_sdk_client``
    are exempt. A call is flagged when its callee is one of two forms:

    - the bare name ``AzureOpenAI`` / ``AsyncAzureOpenAI``;
    - an attribute access ending in one of those names, e.g.
      ``openai.AzureOpenAI(...)``.

    Args:
        file_path: Path of the Python source file to inspect.

    Returns:
        A list of issue strings of the form
        ``"Direct instantiation of <Client> in <Class>.<method>"``. The list
        is empty when nothing is found.

    Raises:
        OSError: If the file cannot be read.
        SyntaxError: If the file is not valid Python (including bytes that
            do not decode under the file's declared/default encoding).
    """
    client_classes = ("AzureOpenAI", "AsyncAzureOpenAI")
    # Methods that exist precisely to build clients, so direct construction
    # inside them is allowed.
    exempt_methods = ("get_azure_openai_client", "initialize_azure_sdk_client")

    # Read raw bytes so ast.parse decodes the source the way Python itself
    # would (PEP 263 coding cookie, UTF-8 default), independent of the
    # platform's locale encoding.
    with open(file_path, "rb") as file:
        content = file.read()

    tree = ast.parse(content, filename=file_path)

    issues = []
    for node in ast.walk(tree):
        # BaseAzureLLM is the one class allowed to define client creation.
        if not isinstance(node, ast.ClassDef) or node.name == "BaseAzureLLM":
            continue

        for method in node.body:
            if not isinstance(method, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            if method.name in exempt_methods:
                continue

            # ast.walk also descends into functions/classes nested in the method.
            for subnode in ast.walk(method):
                if not isinstance(subnode, ast.Call):
                    continue
                func = subnode.func
                if isinstance(func, ast.Name):
                    callee = func.id  # e.g. AzureOpenAI(...)
                elif isinstance(func, ast.Attribute):
                    callee = func.attr  # e.g. openai.AzureOpenAI(...)
                else:
                    continue
                if callee in client_classes:
                    issues.append(
                        f"Direct instantiation of {callee} in {node.name}.{method.name}"
                    )

    return issues
|
|
|
|
|
|
def main():
    """
    Main function to run the test.

    Scans every Python file under ``../../litellm/llms/azure`` (resolved
    against the current working directory) with ``check_direct_instantiation``
    and prints each finding.

    Raises:
        Exception: In either of two cases:

            - No Python files are found under the base directory. Without
              this check the run would pass vacuously, e.g. when launched
              from an unexpected working directory.
            - Any direct instantiation of AzureOpenAI / AsyncAzureOpenAI is
              found.
    """
    # Relative to the current working directory, not to this file.
    base_dir = "../../litellm/llms/azure"
    azure_files = find_azure_files(base_dir)
    print(f"Found {len(azure_files)} Azure Python files to check")

    # os.walk yields nothing for a missing directory, so an empty scan almost
    # certainly means base_dir did not resolve; fail rather than report success.
    if not azure_files:
        raise Exception(
            f"No Python files found under {os.path.abspath(base_dir)}; "
            "check the working directory this script is run from."
        )

    all_issues = []
    for file_path in azure_files:
        all_issues.extend(
            f"{file_path}: {issue}" for issue in check_direct_instantiation(file_path)
        )

    if all_issues:
        print("Found direct instantiations of AzureOpenAI or AsyncAzureOpenAI:")
        for issue in all_issues:
            print(f" - {issue}")
        raise Exception(
            f"Found {len(all_issues)} direct instantiations of AzureOpenAI or AsyncAzureOpenAI classes. Use get_azure_openai_client instead."
        )
    else:
        print("All Azure modules are correctly using get_azure_openai_client!")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the check only when executed as a script, not when imported.
    main()
|