mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
test: initial test to enforce all functions in user_api_key_auth.py h… (#7797)
* test: initial test to enforce all functions in user_api_key_auth.py have direct testing * test(test_user_api_key_auth.py): add is_allowed_route unit test * test(test_user_api_key_auth.py): add more tests * test(test_user_api_key_auth.py): add complete testing coverage for all functions in `user_api_key_auth.py` * test(test_db_schema_changes.py): add a unit test to ensure all db schema changes are backwards compatible gives user an easy rollback path * test: fix schema compatibility test filepath * test: fix test
This commit is contained in:
parent
6473f9ad02
commit
fbdd88d79c
6 changed files with 417 additions and 5 deletions
|
@ -1,6 +1,5 @@
|
|||
model_list:
|
||||
- model_name: anthropic-vertex
|
||||
- model_name: embedding-small
|
||||
litellm_params:
|
||||
model: vertex_ai/claude-3-5-sonnet@20240620
|
||||
vertex_ai_project: "pathrise-convert-1606954137718"
|
||||
vertex_ai_location: "europe-west1"
|
||||
model: openai/text-embedding-3-small
|
||||
|
|
@ -155,6 +155,7 @@ def _is_allowed_route(
|
|||
"""
|
||||
- Route b/w ui token check and normal token check
|
||||
"""
|
||||
|
||||
if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj):
|
||||
return True
|
||||
else:
|
||||
|
|
128
tests/code_coverage_tests/user_api_key_auth_code_coverage.py
Normal file
128
tests/code_coverage_tests/user_api_key_auth_code_coverage.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
"""
|
||||
Enforce all functions in user_api_key_auth.py are covered by tests
|
||||
"""
|
||||
|
||||
import ast
|
||||
import os
|
||||
|
||||
|
||||
def get_function_names_from_file(file_path):
|
||||
"""
|
||||
Extracts all function names from a given Python file.
|
||||
"""
|
||||
with open(file_path, "r") as file:
|
||||
tree = ast.parse(file.read())
|
||||
|
||||
function_names = []
|
||||
|
||||
for node in tree.body:
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
# Top-level functions
|
||||
function_names.append(node.name)
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
# Functions inside classes
|
||||
for class_node in node.body:
|
||||
if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
function_names.append(class_node.name)
|
||||
|
||||
return function_names
|
||||
|
||||
|
||||
def get_all_functions_called_in_tests(base_dir):
|
||||
"""
|
||||
Returns a set of function names that are called in test functions
|
||||
inside 'local_testing' and 'proxy_unit_tests' directories,
|
||||
specifically in files containing the word 'router'.
|
||||
"""
|
||||
called_functions = set()
|
||||
test_dirs = ["local_testing", "proxy_unit_tests"]
|
||||
|
||||
for test_dir in test_dirs:
|
||||
dir_path = os.path.join(base_dir, test_dir)
|
||||
if not os.path.exists(dir_path):
|
||||
print(f"Warning: Directory {dir_path} does not exist.")
|
||||
continue
|
||||
|
||||
print("dir_path: ", dir_path)
|
||||
for root, _, files in os.walk(dir_path):
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
print("file: ", file)
|
||||
file_path = os.path.join(root, file)
|
||||
with open(file_path, "r") as f:
|
||||
try:
|
||||
tree = ast.parse(f.read())
|
||||
except SyntaxError:
|
||||
print(f"Warning: Syntax error in file {file_path}")
|
||||
continue
|
||||
if file == "test_router_validate_fallbacks.py":
|
||||
print(f"tree: {tree}")
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Call) and isinstance(
|
||||
node.func, ast.Name
|
||||
):
|
||||
called_functions.add(node.func.id)
|
||||
elif isinstance(node, ast.Call) and isinstance(
|
||||
node.func, ast.Attribute
|
||||
):
|
||||
called_functions.add(node.func.attr)
|
||||
|
||||
return called_functions
|
||||
|
||||
|
||||
def get_functions_from_router(file_path):
|
||||
"""
|
||||
Extracts all functions defined in user_api_key_auth.py.
|
||||
"""
|
||||
return get_function_names_from_file(file_path)
|
||||
|
||||
|
||||
ignored_function_names = [
|
||||
"__init__",
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
# router_file = [
|
||||
# "./litellm/user_api_key_auth.py",
|
||||
# ]
|
||||
router_file = [
|
||||
"../../litellm/proxy/auth/user_api_key_auth.py",
|
||||
]
|
||||
# router_file = [
|
||||
# "../../litellm/router.py",
|
||||
# "../../litellm/router_utils/pattern_match_deployments.py",
|
||||
# "../../litellm/router_utils/batch_utils.py",
|
||||
# ] ## LOCAL TESTING
|
||||
# tests_dir = (
|
||||
# "./tests/" # Update this path if your tests directory is located elsewhere
|
||||
# )
|
||||
tests_dir = "../../tests/" # LOCAL TESTING
|
||||
|
||||
router_functions = []
|
||||
for file in router_file:
|
||||
router_functions.extend(get_functions_from_router(file))
|
||||
print("router_functions: ", router_functions)
|
||||
called_functions_in_tests = get_all_functions_called_in_tests(tests_dir)
|
||||
untested_functions = [
|
||||
fn for fn in router_functions if fn not in called_functions_in_tests
|
||||
]
|
||||
|
||||
if untested_functions:
|
||||
all_untested_functions = []
|
||||
for func in untested_functions:
|
||||
if func not in ignored_function_names:
|
||||
all_untested_functions.append(func)
|
||||
untested_perc = (len(all_untested_functions)) / len(router_functions)
|
||||
print("untested_perc: ", untested_perc)
|
||||
if untested_perc > 0:
|
||||
print("The following functions in user_api_key_auth.py are not tested:")
|
||||
raise Exception(
|
||||
f"{untested_perc * 100:.2f}% of functions in user_api_key_auth.py are not tested: {all_untested_functions}"
|
||||
)
|
||||
else:
|
||||
print("All functions in user_api_key_auth.py are covered by tests.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -444,7 +444,11 @@ async def test_async_vertexai_response():
|
|||
f"model being tested in async call: {model}, litellm.vertex_language_models: {litellm.vertex_language_models}"
|
||||
)
|
||||
if model in VERTEX_MODELS_TO_NOT_TEST or (
|
||||
"gecko" in model or "32k" in model or "ultra" in model or "002" in model
|
||||
"gecko" in model
|
||||
or "32k" in model
|
||||
or "ultra" in model
|
||||
or "002" in model
|
||||
or "gemini-2.0-flash-thinking-exp" == model
|
||||
):
|
||||
# our account does not have access to this model
|
||||
continue
|
||||
|
|
115
tests/proxy_unit_tests/test_db_schema_changes.py
Normal file
115
tests/proxy_unit_tests/test_db_schema_changes.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
import pytest
|
||||
import subprocess
|
||||
import re
|
||||
from typing import Dict, List, Set
|
||||
|
||||
|
||||
def get_schema_from_branch(branch: str = "main") -> str:
|
||||
"""Get schema from specified git branch"""
|
||||
result = subprocess.run(
|
||||
["git", "show", f"{branch}:schema.prisma"], capture_output=True, text=True
|
||||
)
|
||||
return result.stdout
|
||||
|
||||
|
||||
def parse_model_fields(schema: str) -> Dict[str, Dict[str, str]]:
|
||||
"""Parse Prisma schema into dict of models and their fields"""
|
||||
models = {}
|
||||
current_model = None
|
||||
|
||||
for line in schema.split("\n"):
|
||||
line = line.strip()
|
||||
|
||||
# Find model definition
|
||||
if line.startswith("model "):
|
||||
current_model = line.split(" ")[1]
|
||||
models[current_model] = {}
|
||||
continue
|
||||
|
||||
# Inside model definition
|
||||
if current_model and line and not line.startswith("}"):
|
||||
# Split field definition into name and type
|
||||
parts = line.split()
|
||||
if len(parts) >= 2:
|
||||
field_name = parts[0]
|
||||
field_type = " ".join(parts[1:])
|
||||
models[current_model][field_name] = field_type
|
||||
|
||||
# End of model definition
|
||||
if line.startswith("}"):
|
||||
current_model = None
|
||||
|
||||
return models
|
||||
|
||||
|
||||
def check_breaking_changes(
|
||||
old_schema: Dict[str, Dict[str, str]], new_schema: Dict[str, Dict[str, str]]
|
||||
) -> List[str]:
|
||||
"""Check for breaking changes between schemas"""
|
||||
breaking_changes = []
|
||||
|
||||
# Check each model in old schema
|
||||
for model_name, old_fields in old_schema.items():
|
||||
if model_name not in new_schema:
|
||||
breaking_changes.append(f"Breaking: Model {model_name} was removed")
|
||||
continue
|
||||
|
||||
new_fields = new_schema[model_name]
|
||||
|
||||
# Check each field in old model
|
||||
for field_name, old_type in old_fields.items():
|
||||
if field_name not in new_fields:
|
||||
breaking_changes.append(
|
||||
f"Breaking: Field {model_name}.{field_name} was removed"
|
||||
)
|
||||
continue
|
||||
|
||||
new_type = new_fields[field_name]
|
||||
|
||||
# Check for type changes
|
||||
if old_type != new_type:
|
||||
# Check specific type changes that are breaking
|
||||
if "?" in old_type and "?" not in new_type:
|
||||
breaking_changes.append(
|
||||
f"Breaking: Field {model_name}.{field_name} changed from optional to required"
|
||||
)
|
||||
if not old_type.startswith(new_type.split("?")[0]):
|
||||
breaking_changes.append(
|
||||
f"Breaking: Field {model_name}.{field_name} changed type from {old_type} to {new_type}"
|
||||
)
|
||||
|
||||
return breaking_changes
|
||||
|
||||
|
||||
def test_aaaaaschema_compatibility():
|
||||
"""Test if current schema has breaking changes compared to main"""
|
||||
import os
|
||||
|
||||
print("Current directory:", os.getcwd())
|
||||
|
||||
# Get schemas
|
||||
old_schema = get_schema_from_branch("main")
|
||||
with open("./schema.prisma", "r") as f:
|
||||
new_schema = f.read()
|
||||
|
||||
# Parse schemas
|
||||
old_models = parse_model_fields(old_schema)
|
||||
new_models = parse_model_fields(new_schema)
|
||||
|
||||
# Check for breaking changes
|
||||
breaking_changes = check_breaking_changes(old_models, new_models)
|
||||
|
||||
# Fail if breaking changes found
|
||||
if breaking_changes:
|
||||
pytest.fail("\n".join(breaking_changes))
|
||||
|
||||
# Print informational diff
|
||||
print("\nNon-breaking changes detected:")
|
||||
for model_name, new_fields in new_models.items():
|
||||
if model_name not in old_models:
|
||||
print(f"Added new model: {model_name}")
|
||||
continue
|
||||
|
||||
for field_name, new_type in new_fields.items():
|
||||
if field_name not in old_models[model_name]:
|
||||
print(f"Added new field: {model_name}.{field_name}")
|
|
@ -20,6 +20,9 @@ from litellm.proxy.auth.user_api_key_auth import (
|
|||
UserAPIKeyAuth,
|
||||
get_api_key_from_custom_header,
|
||||
)
|
||||
from fastapi import WebSocket, HTTPException, status
|
||||
|
||||
from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles
|
||||
|
||||
|
||||
class Request:
|
||||
|
@ -629,3 +632,165 @@ async def test_soft_budget_alert():
|
|||
"budget_alerts",
|
||||
original_budget_alerts,
|
||||
)
|
||||
|
||||
|
||||
def test_is_allowed_route():
|
||||
from litellm.proxy.auth.user_api_key_auth import _is_allowed_route
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
import datetime
|
||||
|
||||
request = MagicMock()
|
||||
|
||||
args = {
|
||||
"route": "/embeddings",
|
||||
"token_type": "api",
|
||||
"request": request,
|
||||
"request_data": {"input": ["hello world"], "model": "embedding-small"},
|
||||
"api_key": "9644159bc181998825c44c788b1526341ed2e825d1b6f562e23173759e14bb86",
|
||||
"valid_token": UserAPIKeyAuth(
|
||||
token="9644159bc181998825c44c788b1526341ed2e825d1b6f562e23173759e14bb86",
|
||||
key_name="sk-...CJjQ",
|
||||
key_alias=None,
|
||||
spend=0.0,
|
||||
max_budget=None,
|
||||
expires=None,
|
||||
models=[],
|
||||
aliases={},
|
||||
config={},
|
||||
user_id=None,
|
||||
team_id=None,
|
||||
max_parallel_requests=None,
|
||||
metadata={},
|
||||
tpm_limit=None,
|
||||
rpm_limit=None,
|
||||
budget_duration=None,
|
||||
budget_reset_at=None,
|
||||
allowed_cache_controls=[],
|
||||
permissions={},
|
||||
model_spend={},
|
||||
model_max_budget={},
|
||||
soft_budget_cooldown=False,
|
||||
blocked=None,
|
||||
litellm_budget_table=None,
|
||||
org_id=None,
|
||||
created_at=MagicMock(),
|
||||
updated_at=MagicMock(),
|
||||
team_spend=None,
|
||||
team_alias=None,
|
||||
team_tpm_limit=None,
|
||||
team_rpm_limit=None,
|
||||
team_max_budget=None,
|
||||
team_models=[],
|
||||
team_blocked=False,
|
||||
soft_budget=None,
|
||||
team_model_aliases=None,
|
||||
team_member_spend=None,
|
||||
team_member=None,
|
||||
team_metadata=None,
|
||||
end_user_id=None,
|
||||
end_user_tpm_limit=None,
|
||||
end_user_rpm_limit=None,
|
||||
end_user_max_budget=None,
|
||||
last_refreshed_at=1736990277.432638,
|
||||
api_key=None,
|
||||
user_role=None,
|
||||
allowed_model_region=None,
|
||||
parent_otel_span=None,
|
||||
rpm_limit_per_model=None,
|
||||
tpm_limit_per_model=None,
|
||||
user_tpm_limit=None,
|
||||
user_rpm_limit=None,
|
||||
),
|
||||
"user_obj": None,
|
||||
}
|
||||
|
||||
assert _is_allowed_route(**args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user_obj, expected_result",
|
||||
[
|
||||
(None, False), # Case 1: user_obj is None
|
||||
(
|
||||
LiteLLM_UserTable(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN.value,
|
||||
user_id="1234",
|
||||
user_email="test@test.com",
|
||||
max_budget=None,
|
||||
spend=0.0,
|
||||
),
|
||||
True,
|
||||
), # Case 2: user_role is PROXY_ADMIN
|
||||
(
|
||||
LiteLLM_UserTable(
|
||||
user_role="OTHER_ROLE",
|
||||
user_id="1234",
|
||||
user_email="test@test.com",
|
||||
max_budget=None,
|
||||
spend=0.0,
|
||||
),
|
||||
False,
|
||||
), # Case 3: user_role is not PROXY_ADMIN
|
||||
],
|
||||
)
|
||||
def test_is_user_proxy_admin(user_obj, expected_result):
|
||||
from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin
|
||||
|
||||
assert _is_user_proxy_admin(user_obj) == expected_result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user_obj, expected_role",
|
||||
[
|
||||
(None, None), # Case 1: user_obj is None (should return None)
|
||||
(
|
||||
LiteLLM_UserTable(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN.value,
|
||||
user_id="1234",
|
||||
user_email="test@test.com",
|
||||
max_budget=None,
|
||||
spend=0.0,
|
||||
),
|
||||
LitellmUserRoles.PROXY_ADMIN,
|
||||
), # Case 2: user_role is PROXY_ADMIN (should return LitellmUserRoles.PROXY_ADMIN)
|
||||
(
|
||||
LiteLLM_UserTable(
|
||||
user_role="OTHER_ROLE",
|
||||
user_id="1234",
|
||||
user_email="test@test.com",
|
||||
max_budget=None,
|
||||
spend=0.0,
|
||||
),
|
||||
LitellmUserRoles.INTERNAL_USER,
|
||||
), # Case 3: invalid user_role (should return LitellmUserRoles.INTERNAL_USER)
|
||||
],
|
||||
)
|
||||
def test_get_user_role(user_obj, expected_role):
|
||||
from litellm.proxy.auth.user_api_key_auth import _get_user_role
|
||||
|
||||
assert _get_user_role(user_obj) == expected_role
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_api_key_auth_websocket():
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth_websocket
|
||||
|
||||
# Prepare a mock WebSocket object
|
||||
mock_websocket = MagicMock(spec=WebSocket)
|
||||
mock_websocket.query_params = {"model": "some_model"}
|
||||
mock_websocket.headers = {"authorization": "Bearer some_api_key"}
|
||||
|
||||
# Mock the return value of `user_api_key_auth` when it's called within the `user_api_key_auth_websocket` function
|
||||
with patch(
|
||||
"litellm.proxy.auth.user_api_key_auth.user_api_key_auth", autospec=True
|
||||
) as mock_user_api_key_auth:
|
||||
|
||||
# Make the call to the WebSocket function
|
||||
await user_api_key_auth_websocket(mock_websocket)
|
||||
|
||||
# Assert that `user_api_key_auth` was called with the correct parameters
|
||||
mock_user_api_key_auth.assert_called_once()
|
||||
|
||||
assert (
|
||||
mock_user_api_key_auth.call_args.kwargs["api_key"] == "Bearer some_api_key"
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue