test: initial test to enforce all functions in user_api_key_auth.py h… (#7797)

* test: initial test to enforce all functions in user_api_key_auth.py have direct testing

* test(test_user_api_key_auth.py): add is_allowed_route unit test

* test(test_user_api_key_auth.py): add more tests

* test(test_user_api_key_auth.py): add complete testing coverage for all functions in `user_api_key_auth.py`

* test(test_db_schema_changes.py): add a unit test to ensure all db schema changes are backwards compatible

gives user an easy rollback path

* test: fix schema compatibility test filepath

* test: fix test
This commit is contained in:
Krish Dholakia 2025-01-15 21:52:45 -08:00 committed by GitHub
parent 6473f9ad02
commit fbdd88d79c
6 changed files with 417 additions and 5 deletions

View file

@ -1,6 +1,5 @@
model_list:
- model_name: anthropic-vertex
- model_name: embedding-small
litellm_params:
model: vertex_ai/claude-3-5-sonnet@20240620
vertex_ai_project: "pathrise-convert-1606954137718"
vertex_ai_location: "europe-west1"
model: openai/text-embedding-3-small

View file

@ -155,6 +155,7 @@ def _is_allowed_route(
"""
- Route b/w ui token check and normal token check
"""
if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj):
return True
else:

View file

@ -0,0 +1,128 @@
"""
Enforce all functions in user_api_key_auth.py are covered by tests
"""
import ast
import os
def get_function_names_from_file(file_path):
"""
Extracts all function names from a given Python file.
"""
with open(file_path, "r") as file:
tree = ast.parse(file.read())
function_names = []
for node in tree.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
# Top-level functions
function_names.append(node.name)
elif isinstance(node, ast.ClassDef):
# Functions inside classes
for class_node in node.body:
if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
function_names.append(class_node.name)
return function_names
def get_all_functions_called_in_tests(base_dir):
"""
Returns a set of function names that are called in test functions
inside 'local_testing' and 'proxy_unit_tests' directories,
specifically in files containing the word 'router'.
"""
called_functions = set()
test_dirs = ["local_testing", "proxy_unit_tests"]
for test_dir in test_dirs:
dir_path = os.path.join(base_dir, test_dir)
if not os.path.exists(dir_path):
print(f"Warning: Directory {dir_path} does not exist.")
continue
print("dir_path: ", dir_path)
for root, _, files in os.walk(dir_path):
for file in files:
if file.endswith(".py"):
print("file: ", file)
file_path = os.path.join(root, file)
with open(file_path, "r") as f:
try:
tree = ast.parse(f.read())
except SyntaxError:
print(f"Warning: Syntax error in file {file_path}")
continue
if file == "test_router_validate_fallbacks.py":
print(f"tree: {tree}")
for node in ast.walk(tree):
if isinstance(node, ast.Call) and isinstance(
node.func, ast.Name
):
called_functions.add(node.func.id)
elif isinstance(node, ast.Call) and isinstance(
node.func, ast.Attribute
):
called_functions.add(node.func.attr)
return called_functions
def get_functions_from_router(file_path):
"""
Extracts all functions defined in user_api_key_auth.py.
"""
return get_function_names_from_file(file_path)
ignored_function_names = [
"__init__",
]
def main():
# router_file = [
# "./litellm/user_api_key_auth.py",
# ]
router_file = [
"../../litellm/proxy/auth/user_api_key_auth.py",
]
# router_file = [
# "../../litellm/router.py",
# "../../litellm/router_utils/pattern_match_deployments.py",
# "../../litellm/router_utils/batch_utils.py",
# ] ## LOCAL TESTING
# tests_dir = (
# "./tests/" # Update this path if your tests directory is located elsewhere
# )
tests_dir = "../../tests/" # LOCAL TESTING
router_functions = []
for file in router_file:
router_functions.extend(get_functions_from_router(file))
print("router_functions: ", router_functions)
called_functions_in_tests = get_all_functions_called_in_tests(tests_dir)
untested_functions = [
fn for fn in router_functions if fn not in called_functions_in_tests
]
if untested_functions:
all_untested_functions = []
for func in untested_functions:
if func not in ignored_function_names:
all_untested_functions.append(func)
untested_perc = (len(all_untested_functions)) / len(router_functions)
print("untested_perc: ", untested_perc)
if untested_perc > 0:
print("The following functions in user_api_key_auth.py are not tested:")
raise Exception(
f"{untested_perc * 100:.2f}% of functions in user_api_key_auth.py are not tested: {all_untested_functions}"
)
else:
print("All functions in user_api_key_auth.py are covered by tests.")
if __name__ == "__main__":
main()

View file

@ -444,7 +444,11 @@ async def test_async_vertexai_response():
f"model being tested in async call: {model}, litellm.vertex_language_models: {litellm.vertex_language_models}"
)
if model in VERTEX_MODELS_TO_NOT_TEST or (
"gecko" in model or "32k" in model or "ultra" in model or "002" in model
"gecko" in model
or "32k" in model
or "ultra" in model
or "002" in model
or "gemini-2.0-flash-thinking-exp" == model
):
# our account does not have access to this model
continue

View file

@ -0,0 +1,115 @@
import pytest
import subprocess
import re
from typing import Dict, List, Set
def get_schema_from_branch(branch: str = "main") -> str:
"""Get schema from specified git branch"""
result = subprocess.run(
["git", "show", f"{branch}:schema.prisma"], capture_output=True, text=True
)
return result.stdout
def parse_model_fields(schema: str) -> Dict[str, Dict[str, str]]:
"""Parse Prisma schema into dict of models and their fields"""
models = {}
current_model = None
for line in schema.split("\n"):
line = line.strip()
# Find model definition
if line.startswith("model "):
current_model = line.split(" ")[1]
models[current_model] = {}
continue
# Inside model definition
if current_model and line and not line.startswith("}"):
# Split field definition into name and type
parts = line.split()
if len(parts) >= 2:
field_name = parts[0]
field_type = " ".join(parts[1:])
models[current_model][field_name] = field_type
# End of model definition
if line.startswith("}"):
current_model = None
return models
def check_breaking_changes(
old_schema: Dict[str, Dict[str, str]], new_schema: Dict[str, Dict[str, str]]
) -> List[str]:
"""Check for breaking changes between schemas"""
breaking_changes = []
# Check each model in old schema
for model_name, old_fields in old_schema.items():
if model_name not in new_schema:
breaking_changes.append(f"Breaking: Model {model_name} was removed")
continue
new_fields = new_schema[model_name]
# Check each field in old model
for field_name, old_type in old_fields.items():
if field_name not in new_fields:
breaking_changes.append(
f"Breaking: Field {model_name}.{field_name} was removed"
)
continue
new_type = new_fields[field_name]
# Check for type changes
if old_type != new_type:
# Check specific type changes that are breaking
if "?" in old_type and "?" not in new_type:
breaking_changes.append(
f"Breaking: Field {model_name}.{field_name} changed from optional to required"
)
if not old_type.startswith(new_type.split("?")[0]):
breaking_changes.append(
f"Breaking: Field {model_name}.{field_name} changed type from {old_type} to {new_type}"
)
return breaking_changes
def test_aaaaaschema_compatibility():
"""Test if current schema has breaking changes compared to main"""
import os
print("Current directory:", os.getcwd())
# Get schemas
old_schema = get_schema_from_branch("main")
with open("./schema.prisma", "r") as f:
new_schema = f.read()
# Parse schemas
old_models = parse_model_fields(old_schema)
new_models = parse_model_fields(new_schema)
# Check for breaking changes
breaking_changes = check_breaking_changes(old_models, new_models)
# Fail if breaking changes found
if breaking_changes:
pytest.fail("\n".join(breaking_changes))
# Print informational diff
print("\nNon-breaking changes detected:")
for model_name, new_fields in new_models.items():
if model_name not in old_models:
print(f"Added new model: {model_name}")
continue
for field_name, new_type in new_fields.items():
if field_name not in old_models[model_name]:
print(f"Added new field: {model_name}.{field_name}")

View file

@ -20,6 +20,9 @@ from litellm.proxy.auth.user_api_key_auth import (
UserAPIKeyAuth,
get_api_key_from_custom_header,
)
from fastapi import WebSocket, HTTPException, status
from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles
class Request:
@ -629,3 +632,165 @@ async def test_soft_budget_alert():
"budget_alerts",
original_budget_alerts,
)
def test_is_allowed_route():
from litellm.proxy.auth.user_api_key_auth import _is_allowed_route
from litellm.proxy._types import UserAPIKeyAuth
import datetime
request = MagicMock()
args = {
"route": "/embeddings",
"token_type": "api",
"request": request,
"request_data": {"input": ["hello world"], "model": "embedding-small"},
"api_key": "9644159bc181998825c44c788b1526341ed2e825d1b6f562e23173759e14bb86",
"valid_token": UserAPIKeyAuth(
token="9644159bc181998825c44c788b1526341ed2e825d1b6f562e23173759e14bb86",
key_name="sk-...CJjQ",
key_alias=None,
spend=0.0,
max_budget=None,
expires=None,
models=[],
aliases={},
config={},
user_id=None,
team_id=None,
max_parallel_requests=None,
metadata={},
tpm_limit=None,
rpm_limit=None,
budget_duration=None,
budget_reset_at=None,
allowed_cache_controls=[],
permissions={},
model_spend={},
model_max_budget={},
soft_budget_cooldown=False,
blocked=None,
litellm_budget_table=None,
org_id=None,
created_at=MagicMock(),
updated_at=MagicMock(),
team_spend=None,
team_alias=None,
team_tpm_limit=None,
team_rpm_limit=None,
team_max_budget=None,
team_models=[],
team_blocked=False,
soft_budget=None,
team_model_aliases=None,
team_member_spend=None,
team_member=None,
team_metadata=None,
end_user_id=None,
end_user_tpm_limit=None,
end_user_rpm_limit=None,
end_user_max_budget=None,
last_refreshed_at=1736990277.432638,
api_key=None,
user_role=None,
allowed_model_region=None,
parent_otel_span=None,
rpm_limit_per_model=None,
tpm_limit_per_model=None,
user_tpm_limit=None,
user_rpm_limit=None,
),
"user_obj": None,
}
assert _is_allowed_route(**args)
@pytest.mark.parametrize(
"user_obj, expected_result",
[
(None, False), # Case 1: user_obj is None
(
LiteLLM_UserTable(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_id="1234",
user_email="test@test.com",
max_budget=None,
spend=0.0,
),
True,
), # Case 2: user_role is PROXY_ADMIN
(
LiteLLM_UserTable(
user_role="OTHER_ROLE",
user_id="1234",
user_email="test@test.com",
max_budget=None,
spend=0.0,
),
False,
), # Case 3: user_role is not PROXY_ADMIN
],
)
def test_is_user_proxy_admin(user_obj, expected_result):
from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin
assert _is_user_proxy_admin(user_obj) == expected_result
@pytest.mark.parametrize(
"user_obj, expected_role",
[
(None, None), # Case 1: user_obj is None (should return None)
(
LiteLLM_UserTable(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_id="1234",
user_email="test@test.com",
max_budget=None,
spend=0.0,
),
LitellmUserRoles.PROXY_ADMIN,
), # Case 2: user_role is PROXY_ADMIN (should return LitellmUserRoles.PROXY_ADMIN)
(
LiteLLM_UserTable(
user_role="OTHER_ROLE",
user_id="1234",
user_email="test@test.com",
max_budget=None,
spend=0.0,
),
LitellmUserRoles.INTERNAL_USER,
), # Case 3: invalid user_role (should return LitellmUserRoles.INTERNAL_USER)
],
)
def test_get_user_role(user_obj, expected_role):
from litellm.proxy.auth.user_api_key_auth import _get_user_role
assert _get_user_role(user_obj) == expected_role
@pytest.mark.asyncio
async def test_user_api_key_auth_websocket():
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth_websocket
# Prepare a mock WebSocket object
mock_websocket = MagicMock(spec=WebSocket)
mock_websocket.query_params = {"model": "some_model"}
mock_websocket.headers = {"authorization": "Bearer some_api_key"}
# Mock the return value of `user_api_key_auth` when it's called within the `user_api_key_auth_websocket` function
with patch(
"litellm.proxy.auth.user_api_key_auth.user_api_key_auth", autospec=True
) as mock_user_api_key_auth:
# Make the call to the WebSocket function
await user_api_key_auth_websocket(mock_websocket)
# Assert that `user_api_key_auth` was called with the correct parameters
mock_user_api_key_auth.assert_called_once()
assert (
mock_user_api_key_auth.call_args.kwargs["api_key"] == "Bearer some_api_key"
)