Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
test: initial test to enforce all functions in user_api_key_auth.py have direct testing (#7797)
* test: initial test to enforce all functions in user_api_key_auth.py have direct testing
* test(test_user_api_key_auth.py): add is_allowed_route unit test
* test(test_user_api_key_auth.py): add more tests
* test(test_user_api_key_auth.py): add complete testing coverage for all functions in `user_api_key_auth.py`
* test(test_db_schema_changes.py): add a unit test to ensure all db schema changes are backwards compatible, giving users an easy rollback path
* test: fix schema compatibility test filepath
* test: fix test
parent 6473f9ad02
commit fbdd88d79c
6 changed files with 417 additions and 5 deletions
@@ -1,6 +1,5 @@
 model_list:
-  - model_name: anthropic-vertex
+  - model_name: embedding-small
     litellm_params:
-      model: vertex_ai/claude-3-5-sonnet@20240620
-      vertex_ai_project: "pathrise-convert-1606954137718"
-      vertex_ai_location: "europe-west1"
+      model: openai/text-embedding-3-small
@@ -155,6 +155,7 @@ def _is_allowed_route(
     """
     - Route b/w ui token check and normal token check
     """

     if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj):
         return True
     else:
128  tests/code_coverage_tests/user_api_key_auth_code_coverage.py  Normal file
@@ -0,0 +1,128 @@
"""
Enforce all functions in user_api_key_auth.py are covered by tests
"""

import ast
import os


def get_function_names_from_file(file_path):
    """
    Extracts all function names from a given Python file.
    """
    with open(file_path, "r") as file:
        tree = ast.parse(file.read())

    function_names = []

    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Top-level functions
            function_names.append(node.name)
        elif isinstance(node, ast.ClassDef):
            # Functions inside classes
            for class_node in node.body:
                if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    function_names.append(class_node.name)

    return function_names


def get_all_functions_called_in_tests(base_dir):
    """
    Returns a set of function names that are called in test functions
    inside 'local_testing' and 'proxy_unit_tests' directories,
    specifically in files containing the word 'router'.
    """
    called_functions = set()
    test_dirs = ["local_testing", "proxy_unit_tests"]

    for test_dir in test_dirs:
        dir_path = os.path.join(base_dir, test_dir)
        if not os.path.exists(dir_path):
            print(f"Warning: Directory {dir_path} does not exist.")
            continue

        print("dir_path: ", dir_path)
        for root, _, files in os.walk(dir_path):
            for file in files:
                if file.endswith(".py"):
                    print("file: ", file)
                    file_path = os.path.join(root, file)
                    with open(file_path, "r") as f:
                        try:
                            tree = ast.parse(f.read())
                        except SyntaxError:
                            print(f"Warning: Syntax error in file {file_path}")
                            continue
                        if file == "test_router_validate_fallbacks.py":
                            print(f"tree: {tree}")
                        for node in ast.walk(tree):
                            if isinstance(node, ast.Call) and isinstance(
                                node.func, ast.Name
                            ):
                                called_functions.add(node.func.id)
                            elif isinstance(node, ast.Call) and isinstance(
                                node.func, ast.Attribute
                            ):
                                called_functions.add(node.func.attr)

    return called_functions


def get_functions_from_router(file_path):
    """
    Extracts all functions defined in user_api_key_auth.py.
    """
    return get_function_names_from_file(file_path)


ignored_function_names = [
    "__init__",
]


def main():
    # router_file = [
    #     "./litellm/user_api_key_auth.py",
    # ]
    router_file = [
        "../../litellm/proxy/auth/user_api_key_auth.py",
    ]
    # router_file = [
    #     "../../litellm/router.py",
    #     "../../litellm/router_utils/pattern_match_deployments.py",
    #     "../../litellm/router_utils/batch_utils.py",
    # ]  ## LOCAL TESTING
    # tests_dir = (
    #     "./tests/"  # Update this path if your tests directory is located elsewhere
    # )
    tests_dir = "../../tests/"  # LOCAL TESTING

    router_functions = []
    for file in router_file:
        router_functions.extend(get_functions_from_router(file))
    print("router_functions: ", router_functions)
    called_functions_in_tests = get_all_functions_called_in_tests(tests_dir)
    untested_functions = [
        fn for fn in router_functions if fn not in called_functions_in_tests
    ]

    if untested_functions:
        all_untested_functions = []
        for func in untested_functions:
            if func not in ignored_function_names:
                all_untested_functions.append(func)
        untested_perc = (len(all_untested_functions)) / len(router_functions)
        print("untested_perc: ", untested_perc)
        if untested_perc > 0:
            print("The following functions in user_api_key_auth.py are not tested:")
            raise Exception(
                f"{untested_perc * 100:.2f}% of functions in user_api_key_auth.py are not tested: {all_untested_functions}"
            )
    else:
        print("All functions in user_api_key_auth.py are covered by tests.")


if __name__ == "__main__":
    main()
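The coverage script above keys entirely on Python's ast module: function definitions are read from the top level (and class bodies) of the source file, and call sites are harvested from every test file, so an untested function is simply a name in the first set that never appears in the second. Below is a minimal, self-contained sketch of that same technique on made-up inline sources; the function and test names are invented for illustration and are not taken from the repo.

import ast

# Invented example source: two module-level functions and one class method.
source_under_test = """
def check_key(key): ...
async def refresh_cache(): ...
class Helper:
    def validate(self): ...
"""

# Invented example test: calls check_key and Helper().validate, never refresh_cache.
test_source = """
def test_check_key():
    assert check_key("sk-123")
    Helper().validate()
"""

# Collect defined names the same way get_function_names_from_file does.
defined = set()
for node in ast.parse(source_under_test).body:
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        defined.add(node.name)
    elif isinstance(node, ast.ClassDef):
        defined.update(
            n.name
            for n in node.body
            if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
        )

# Collect called names the same way get_all_functions_called_in_tests does.
called = set()
for node in ast.walk(ast.parse(test_source)):
    if isinstance(node, ast.Call):
        if isinstance(node.func, ast.Name):
            called.add(node.func.id)
        elif isinstance(node.func, ast.Attribute):
            called.add(node.func.attr)

print(sorted(defined - called))  # ['refresh_cache'] -- the function with no test call

Given the relative paths hard-coded in main() ("../../litellm/..." and "../../tests/"), the script presumably expects to be invoked from inside tests/code_coverage_tests/; that is an inference from the paths, not something stated in the diff.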
@@ -444,7 +444,11 @@ async def test_async_vertexai_response():
             f"model being tested in async call: {model}, litellm.vertex_language_models: {litellm.vertex_language_models}"
         )
         if model in VERTEX_MODELS_TO_NOT_TEST or (
-            "gecko" in model or "32k" in model or "ultra" in model or "002" in model
+            "gecko" in model
+            or "32k" in model
+            or "ultra" in model
+            or "002" in model
+            or "gemini-2.0-flash-thinking-exp" == model
         ):
             # our account does not have access to this model
             continue
115  tests/proxy_unit_tests/test_db_schema_changes.py  Normal file
@@ -0,0 +1,115 @@
import pytest
import subprocess
import re
from typing import Dict, List, Set


def get_schema_from_branch(branch: str = "main") -> str:
    """Get schema from specified git branch"""
    result = subprocess.run(
        ["git", "show", f"{branch}:schema.prisma"], capture_output=True, text=True
    )
    return result.stdout


def parse_model_fields(schema: str) -> Dict[str, Dict[str, str]]:
    """Parse Prisma schema into dict of models and their fields"""
    models = {}
    current_model = None

    for line in schema.split("\n"):
        line = line.strip()

        # Find model definition
        if line.startswith("model "):
            current_model = line.split(" ")[1]
            models[current_model] = {}
            continue

        # Inside model definition
        if current_model and line and not line.startswith("}"):
            # Split field definition into name and type
            parts = line.split()
            if len(parts) >= 2:
                field_name = parts[0]
                field_type = " ".join(parts[1:])
                models[current_model][field_name] = field_type

        # End of model definition
        if line.startswith("}"):
            current_model = None

    return models


def check_breaking_changes(
    old_schema: Dict[str, Dict[str, str]], new_schema: Dict[str, Dict[str, str]]
) -> List[str]:
    """Check for breaking changes between schemas"""
    breaking_changes = []

    # Check each model in old schema
    for model_name, old_fields in old_schema.items():
        if model_name not in new_schema:
            breaking_changes.append(f"Breaking: Model {model_name} was removed")
            continue

        new_fields = new_schema[model_name]

        # Check each field in old model
        for field_name, old_type in old_fields.items():
            if field_name not in new_fields:
                breaking_changes.append(
                    f"Breaking: Field {model_name}.{field_name} was removed"
                )
                continue

            new_type = new_fields[field_name]

            # Check for type changes
            if old_type != new_type:
                # Check specific type changes that are breaking
                if "?" in old_type and "?" not in new_type:
                    breaking_changes.append(
                        f"Breaking: Field {model_name}.{field_name} changed from optional to required"
                    )
                if not old_type.startswith(new_type.split("?")[0]):
                    breaking_changes.append(
                        f"Breaking: Field {model_name}.{field_name} changed type from {old_type} to {new_type}"
                    )

    return breaking_changes


def test_aaaaaschema_compatibility():
    """Test if current schema has breaking changes compared to main"""
    import os

    print("Current directory:", os.getcwd())

    # Get schemas
    old_schema = get_schema_from_branch("main")
    with open("./schema.prisma", "r") as f:
        new_schema = f.read()

    # Parse schemas
    old_models = parse_model_fields(old_schema)
    new_models = parse_model_fields(new_schema)

    # Check for breaking changes
    breaking_changes = check_breaking_changes(old_models, new_models)

    # Fail if breaking changes found
    if breaking_changes:
        pytest.fail("\n".join(breaking_changes))

    # Print informational diff
    print("\nNon-breaking changes detected:")
    for model_name, new_fields in new_models.items():
        if model_name not in old_models:
            print(f"Added new model: {model_name}")
            continue

        for field_name, new_type in new_fields.items():
            if field_name not in old_models[model_name]:
                print(f"Added new field: {model_name}.{field_name}")
@@ -20,6 +20,9 @@ from litellm.proxy.auth.user_api_key_auth import (
     UserAPIKeyAuth,
     get_api_key_from_custom_header,
 )
+from fastapi import WebSocket, HTTPException, status
+
+from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles


 class Request:
@@ -629,3 +632,165 @@ async def test_soft_budget_alert():
        "budget_alerts",
        original_budget_alerts,
    )


def test_is_allowed_route():
    from litellm.proxy.auth.user_api_key_auth import _is_allowed_route
    from litellm.proxy._types import UserAPIKeyAuth
    import datetime

    request = MagicMock()

    args = {
        "route": "/embeddings",
        "token_type": "api",
        "request": request,
        "request_data": {"input": ["hello world"], "model": "embedding-small"},
        "api_key": "9644159bc181998825c44c788b1526341ed2e825d1b6f562e23173759e14bb86",
        "valid_token": UserAPIKeyAuth(
            token="9644159bc181998825c44c788b1526341ed2e825d1b6f562e23173759e14bb86",
            key_name="sk-...CJjQ",
            key_alias=None,
            spend=0.0,
            max_budget=None,
            expires=None,
            models=[],
            aliases={},
            config={},
            user_id=None,
            team_id=None,
            max_parallel_requests=None,
            metadata={},
            tpm_limit=None,
            rpm_limit=None,
            budget_duration=None,
            budget_reset_at=None,
            allowed_cache_controls=[],
            permissions={},
            model_spend={},
            model_max_budget={},
            soft_budget_cooldown=False,
            blocked=None,
            litellm_budget_table=None,
            org_id=None,
            created_at=MagicMock(),
            updated_at=MagicMock(),
            team_spend=None,
            team_alias=None,
            team_tpm_limit=None,
            team_rpm_limit=None,
            team_max_budget=None,
            team_models=[],
            team_blocked=False,
            soft_budget=None,
            team_model_aliases=None,
            team_member_spend=None,
            team_member=None,
            team_metadata=None,
            end_user_id=None,
            end_user_tpm_limit=None,
            end_user_rpm_limit=None,
            end_user_max_budget=None,
            last_refreshed_at=1736990277.432638,
            api_key=None,
            user_role=None,
            allowed_model_region=None,
            parent_otel_span=None,
            rpm_limit_per_model=None,
            tpm_limit_per_model=None,
            user_tpm_limit=None,
            user_rpm_limit=None,
        ),
        "user_obj": None,
    }

    assert _is_allowed_route(**args)


@pytest.mark.parametrize(
    "user_obj, expected_result",
    [
        (None, False),  # Case 1: user_obj is None
        (
            LiteLLM_UserTable(
                user_role=LitellmUserRoles.PROXY_ADMIN.value,
                user_id="1234",
                user_email="test@test.com",
                max_budget=None,
                spend=0.0,
            ),
            True,
        ),  # Case 2: user_role is PROXY_ADMIN
        (
            LiteLLM_UserTable(
                user_role="OTHER_ROLE",
                user_id="1234",
                user_email="test@test.com",
                max_budget=None,
                spend=0.0,
            ),
            False,
        ),  # Case 3: user_role is not PROXY_ADMIN
    ],
)
def test_is_user_proxy_admin(user_obj, expected_result):
    from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin

    assert _is_user_proxy_admin(user_obj) == expected_result


@pytest.mark.parametrize(
    "user_obj, expected_role",
    [
        (None, None),  # Case 1: user_obj is None (should return None)
        (
            LiteLLM_UserTable(
                user_role=LitellmUserRoles.PROXY_ADMIN.value,
                user_id="1234",
                user_email="test@test.com",
                max_budget=None,
                spend=0.0,
            ),
            LitellmUserRoles.PROXY_ADMIN,
        ),  # Case 2: user_role is PROXY_ADMIN (should return LitellmUserRoles.PROXY_ADMIN)
        (
            LiteLLM_UserTable(
                user_role="OTHER_ROLE",
                user_id="1234",
                user_email="test@test.com",
                max_budget=None,
                spend=0.0,
            ),
            LitellmUserRoles.INTERNAL_USER,
        ),  # Case 3: invalid user_role (should return LitellmUserRoles.INTERNAL_USER)
    ],
)
def test_get_user_role(user_obj, expected_role):
    from litellm.proxy.auth.user_api_key_auth import _get_user_role

    assert _get_user_role(user_obj) == expected_role


@pytest.mark.asyncio
async def test_user_api_key_auth_websocket():
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth_websocket

    # Prepare a mock WebSocket object
    mock_websocket = MagicMock(spec=WebSocket)
    mock_websocket.query_params = {"model": "some_model"}
    mock_websocket.headers = {"authorization": "Bearer some_api_key"}

    # Mock the return value of `user_api_key_auth` when it's called within the `user_api_key_auth_websocket` function
    with patch(
        "litellm.proxy.auth.user_api_key_auth.user_api_key_auth", autospec=True
    ) as mock_user_api_key_auth:
        # Make the call to the WebSocket function
        await user_api_key_auth_websocket(mock_websocket)

        # Assert that `user_api_key_auth` was called with the correct parameters
        mock_user_api_key_auth.assert_called_once()

        assert (
            mock_user_api_key_auth.call_args.kwargs["api_key"] == "Bearer some_api_key"
        )