mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix - correctly re-use azure openai client
This commit is contained in:
parent
d991b3c398
commit
8b54873e9f
2 changed files with 203 additions and 34 deletions
108
tests/code_coverage_tests/azure_client_usage_test.py
Normal file
108
tests/code_coverage_tests/azure_client_usage_test.py
Normal file
|
@ -0,0 +1,108 @@
|
|||
import ast
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def find_azure_files(base_dir):
|
||||
"""
|
||||
Find all Python files in the Azure directory.
|
||||
"""
|
||||
azure_files = []
|
||||
for root, _, files in os.walk(base_dir):
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
azure_files.append(os.path.join(root, file))
|
||||
return azure_files
|
||||
|
||||
|
||||
def check_direct_instantiation(file_path):
|
||||
"""
|
||||
Check if a file directly instantiates AzureOpenAI or AsyncAzureOpenAI
|
||||
outside of the BaseAzureLLM class methods.
|
||||
"""
|
||||
with open(file_path, "r") as file:
|
||||
content = file.read()
|
||||
|
||||
# Parse the file
|
||||
tree = ast.parse(content)
|
||||
|
||||
# Track issues found
|
||||
issues = []
|
||||
|
||||
# Find all class definitions
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
class_name = node.name
|
||||
|
||||
# Skip BaseAzureLLM class since it's allowed to define the client creation methods
|
||||
if class_name == "BaseAzureLLM":
|
||||
continue
|
||||
|
||||
# Check method bodies for direct instantiation
|
||||
for method in node.body:
|
||||
if isinstance(method, ast.FunctionDef) or isinstance(
|
||||
method, ast.AsyncFunctionDef
|
||||
):
|
||||
method_name = method.name
|
||||
|
||||
# Skip methods that are specifically for client creation
|
||||
if method_name in [
|
||||
"get_azure_openai_client",
|
||||
"initialize_azure_sdk_client",
|
||||
]:
|
||||
continue
|
||||
|
||||
# Look for direct instantiation in the method body
|
||||
for subnode in ast.walk(method):
|
||||
if isinstance(subnode, ast.Call):
|
||||
if hasattr(subnode, "func") and hasattr(subnode.func, "id"):
|
||||
if subnode.func.id in [
|
||||
"AzureOpenAI",
|
||||
"AsyncAzureOpenAI",
|
||||
]:
|
||||
issues.append(
|
||||
f"Direct instantiation of {subnode.func.id} in {class_name}.{method_name}"
|
||||
)
|
||||
elif hasattr(subnode, "func") and hasattr(
|
||||
subnode.func, "attr"
|
||||
):
|
||||
if subnode.func.attr in [
|
||||
"AzureOpenAI",
|
||||
"AsyncAzureOpenAI",
|
||||
]:
|
||||
issues.append(
|
||||
f"Direct instantiation of {subnode.func.attr} in {class_name}.{method_name}"
|
||||
)
|
||||
|
||||
return issues
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to run the test.
|
||||
"""
|
||||
# local
|
||||
base_dir = "../../litellm/llms/azure"
|
||||
azure_files = find_azure_files(base_dir)
|
||||
print(f"Found {len(azure_files)} Azure Python files to check")
|
||||
|
||||
all_issues = []
|
||||
|
||||
for file_path in azure_files:
|
||||
issues = check_direct_instantiation(file_path)
|
||||
if issues:
|
||||
all_issues.extend([f"{file_path}: {issue}" for issue in issues])
|
||||
|
||||
if all_issues:
|
||||
print("Found direct instantiations of AzureOpenAI or AsyncAzureOpenAI:")
|
||||
for issue in all_issues:
|
||||
print(f" - {issue}")
|
||||
raise Exception(
|
||||
f"Found {len(all_issues)} direct instantiations of AzureOpenAI or AsyncAzureOpenAI classes. Use get_azure_openai_client instead."
|
||||
)
|
||||
else:
|
||||
print("All Azure modules are correctly using get_azure_openai_client!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Add table
Add a link
Reference in a new issue