From fbdd88d79c9d7e2855918f15c9b5d947afd3008c Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Wed, 15 Jan 2025 21:52:45 -0800 Subject: [PATCH] =?UTF-8?q?test:=20initial=20test=20to=20enforce=20all=20f?= =?UTF-8?q?unctions=20in=20user=5Fapi=5Fkey=5Fauth.py=20h=E2=80=A6=20(#779?= =?UTF-8?q?7)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test: initial test to enforce all functions in user_api_key_auth.py have direct testing * test(test_user_api_key_auth.py): add is_allowed_route unit test * test(test_user_api_key_auth.py): add more tests * test(test_user_api_key_auth.py): add complete testing coverage for all functions in `user_api_key_auth.py` * test(test_db_schema_changes.py): add a unit test to ensure all db schema changes are backwards compatible gives user an easy rollback path * test: fix schema compatibility test filepath * test: fix test --- litellm/proxy/_new_secret_config.yaml | 7 +- litellm/proxy/auth/user_api_key_auth.py | 1 + .../user_api_key_auth_code_coverage.py | 128 ++++++++++++++ .../test_amazing_vertex_completion.py | 6 +- .../test_db_schema_changes.py | 115 ++++++++++++ .../test_user_api_key_auth.py | 165 ++++++++++++++++++ 6 files changed, 417 insertions(+), 5 deletions(-) create mode 100644 tests/code_coverage_tests/user_api_key_auth_code_coverage.py create mode 100644 tests/proxy_unit_tests/test_db_schema_changes.py diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 498e199261..5ea6dd6a6c 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,6 +1,5 @@ model_list: - - model_name: anthropic-vertex + - model_name: embedding-small litellm_params: - model: vertex_ai/claude-3-5-sonnet@20240620 - vertex_ai_project: "pathrise-convert-1606954137718" - vertex_ai_location: "europe-west1" \ No newline at end of file + model: openai/text-embedding-3-small + \ No newline at end of file diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index d2191946dc..b3c0b17ab4 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -155,6 +155,7 @@ def _is_allowed_route( """ - Route b/w ui token check and normal token check """ + if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj): return True else: diff --git a/tests/code_coverage_tests/user_api_key_auth_code_coverage.py b/tests/code_coverage_tests/user_api_key_auth_code_coverage.py new file mode 100644 index 0000000000..a9c2f8ef15 --- /dev/null +++ b/tests/code_coverage_tests/user_api_key_auth_code_coverage.py @@ -0,0 +1,128 @@ +""" +Enforce all functions in user_api_key_auth.py are covered by tests +""" + +import ast +import os + + +def get_function_names_from_file(file_path): + """ + Extracts all function names from a given Python file. + """ + with open(file_path, "r") as file: + tree = ast.parse(file.read()) + + function_names = [] + + for node in tree.body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + # Top-level functions + function_names.append(node.name) + elif isinstance(node, ast.ClassDef): + # Functions inside classes + for class_node in node.body: + if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)): + function_names.append(class_node.name) + + return function_names + + +def get_all_functions_called_in_tests(base_dir): + """ + Returns a set of function names that are called in test functions + inside 'local_testing' and 'proxy_unit_tests' directories, + specifically in files containing the word 'router'. + """ + called_functions = set() + test_dirs = ["local_testing", "proxy_unit_tests"] + + for test_dir in test_dirs: + dir_path = os.path.join(base_dir, test_dir) + if not os.path.exists(dir_path): + print(f"Warning: Directory {dir_path} does not exist.") + continue + + print("dir_path: ", dir_path) + for root, _, files in os.walk(dir_path): + for file in files: + if file.endswith(".py"): + print("file: ", file) + file_path = os.path.join(root, file) + with open(file_path, "r") as f: + try: + tree = ast.parse(f.read()) + except SyntaxError: + print(f"Warning: Syntax error in file {file_path}") + continue + if file == "test_router_validate_fallbacks.py": + print(f"tree: {tree}") + for node in ast.walk(tree): + if isinstance(node, ast.Call) and isinstance( + node.func, ast.Name + ): + called_functions.add(node.func.id) + elif isinstance(node, ast.Call) and isinstance( + node.func, ast.Attribute + ): + called_functions.add(node.func.attr) + + return called_functions + + +def get_functions_from_router(file_path): + """ + Extracts all functions defined in user_api_key_auth.py. + """ + return get_function_names_from_file(file_path) + + +ignored_function_names = [ + "__init__", +] + + +def main(): + # router_file = [ + # "./litellm/user_api_key_auth.py", + # ] + router_file = [ + "../../litellm/proxy/auth/user_api_key_auth.py", + ] + # router_file = [ + # "../../litellm/router.py", + # "../../litellm/router_utils/pattern_match_deployments.py", + # "../../litellm/router_utils/batch_utils.py", + # ] ## LOCAL TESTING + # tests_dir = ( + # "./tests/" # Update this path if your tests directory is located elsewhere + # ) + tests_dir = "../../tests/" # LOCAL TESTING + + router_functions = [] + for file in router_file: + router_functions.extend(get_functions_from_router(file)) + print("router_functions: ", router_functions) + called_functions_in_tests = get_all_functions_called_in_tests(tests_dir) + untested_functions = [ + fn for fn in router_functions if fn not in called_functions_in_tests + ] + + if untested_functions: + all_untested_functions = [] + for func in untested_functions: + if func not in ignored_function_names: + all_untested_functions.append(func) + untested_perc = (len(all_untested_functions)) / len(router_functions) + print("untested_perc: ", untested_perc) + if untested_perc > 0: + print("The following functions in user_api_key_auth.py are not tested:") + raise Exception( + f"{untested_perc * 100:.2f}% of functions in user_api_key_auth.py are not tested: {all_untested_functions}" + ) + else: + print("All functions in user_api_key_auth.py are covered by tests.") + + +if __name__ == "__main__": + main() diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py index 08afb09a18..2c12900f30 100644 --- a/tests/local_testing/test_amazing_vertex_completion.py +++ b/tests/local_testing/test_amazing_vertex_completion.py @@ -444,7 +444,11 @@ async def test_async_vertexai_response(): f"model being tested in async call: {model}, litellm.vertex_language_models: {litellm.vertex_language_models}" ) if model in VERTEX_MODELS_TO_NOT_TEST or ( - "gecko" in model or "32k" in model or "ultra" in model or "002" in model + "gecko" in model + or "32k" in model + or "ultra" in model + or "002" in model + or "gemini-2.0-flash-thinking-exp" == model ): # our account does not have access to this model continue diff --git a/tests/proxy_unit_tests/test_db_schema_changes.py b/tests/proxy_unit_tests/test_db_schema_changes.py new file mode 100644 index 0000000000..c7685763a2 --- /dev/null +++ b/tests/proxy_unit_tests/test_db_schema_changes.py @@ -0,0 +1,115 @@ +import pytest +import subprocess +import re +from typing import Dict, List, Set + + +def get_schema_from_branch(branch: str = "main") -> str: + """Get schema from specified git branch""" + result = subprocess.run( + ["git", "show", f"{branch}:schema.prisma"], capture_output=True, text=True + ) + return result.stdout + + +def parse_model_fields(schema: str) -> Dict[str, Dict[str, str]]: + """Parse Prisma schema into dict of models and their fields""" + models = {} + current_model = None + + for line in schema.split("\n"): + line = line.strip() + + # Find model definition + if line.startswith("model "): + current_model = line.split(" ")[1] + models[current_model] = {} + continue + + # Inside model definition + if current_model and line and not line.startswith("}"): + # Split field definition into name and type + parts = line.split() + if len(parts) >= 2: + field_name = parts[0] + field_type = " ".join(parts[1:]) + models[current_model][field_name] = field_type + + # End of model definition + if line.startswith("}"): + current_model = None + + return models + + +def check_breaking_changes( + old_schema: Dict[str, Dict[str, str]], new_schema: Dict[str, Dict[str, str]] +) -> List[str]: + """Check for breaking changes between schemas""" + breaking_changes = [] + + # Check each model in old schema + for model_name, old_fields in old_schema.items(): + if model_name not in new_schema: + breaking_changes.append(f"Breaking: Model {model_name} was removed") + continue + + new_fields = new_schema[model_name] + + # Check each field in old model + for field_name, old_type in old_fields.items(): + if field_name not in new_fields: + breaking_changes.append( + f"Breaking: Field {model_name}.{field_name} was removed" + ) + continue + + new_type = new_fields[field_name] + + # Check for type changes + if old_type != new_type: + # Check specific type changes that are breaking + if "?" in old_type and "?" not in new_type: + breaking_changes.append( + f"Breaking: Field {model_name}.{field_name} changed from optional to required" + ) + if not old_type.startswith(new_type.split("?")[0]): + breaking_changes.append( + f"Breaking: Field {model_name}.{field_name} changed type from {old_type} to {new_type}" + ) + + return breaking_changes + + +def test_aaaaaschema_compatibility(): + """Test if current schema has breaking changes compared to main""" + import os + + print("Current directory:", os.getcwd()) + + # Get schemas + old_schema = get_schema_from_branch("main") + with open("./schema.prisma", "r") as f: + new_schema = f.read() + + # Parse schemas + old_models = parse_model_fields(old_schema) + new_models = parse_model_fields(new_schema) + + # Check for breaking changes + breaking_changes = check_breaking_changes(old_models, new_models) + + # Fail if breaking changes found + if breaking_changes: + pytest.fail("\n".join(breaking_changes)) + + # Print informational diff + print("\nNon-breaking changes detected:") + for model_name, new_fields in new_models.items(): + if model_name not in old_models: + print(f"Added new model: {model_name}") + continue + + for field_name, new_type in new_fields.items(): + if field_name not in old_models[model_name]: + print(f"Added new field: {model_name}.{field_name}") diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index 2a77f27c8e..d1a25dec16 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -20,6 +20,9 @@ from litellm.proxy.auth.user_api_key_auth import ( UserAPIKeyAuth, get_api_key_from_custom_header, ) +from fastapi import WebSocket, HTTPException, status + +from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles class Request: @@ -629,3 +632,165 @@ async def test_soft_budget_alert(): "budget_alerts", original_budget_alerts, ) + + +def test_is_allowed_route(): + from litellm.proxy.auth.user_api_key_auth import _is_allowed_route + from litellm.proxy._types import UserAPIKeyAuth + import datetime + + request = MagicMock() + + args = { + "route": "/embeddings", + "token_type": "api", + "request": request, + "request_data": {"input": ["hello world"], "model": "embedding-small"}, + "api_key": "9644159bc181998825c44c788b1526341ed2e825d1b6f562e23173759e14bb86", + "valid_token": UserAPIKeyAuth( + token="9644159bc181998825c44c788b1526341ed2e825d1b6f562e23173759e14bb86", + key_name="sk-...CJjQ", + key_alias=None, + spend=0.0, + max_budget=None, + expires=None, + models=[], + aliases={}, + config={}, + user_id=None, + team_id=None, + max_parallel_requests=None, + metadata={}, + tpm_limit=None, + rpm_limit=None, + budget_duration=None, + budget_reset_at=None, + allowed_cache_controls=[], + permissions={}, + model_spend={}, + model_max_budget={}, + soft_budget_cooldown=False, + blocked=None, + litellm_budget_table=None, + org_id=None, + created_at=MagicMock(), + updated_at=MagicMock(), + team_spend=None, + team_alias=None, + team_tpm_limit=None, + team_rpm_limit=None, + team_max_budget=None, + team_models=[], + team_blocked=False, + soft_budget=None, + team_model_aliases=None, + team_member_spend=None, + team_member=None, + team_metadata=None, + end_user_id=None, + end_user_tpm_limit=None, + end_user_rpm_limit=None, + end_user_max_budget=None, + last_refreshed_at=1736990277.432638, + api_key=None, + user_role=None, + allowed_model_region=None, + parent_otel_span=None, + rpm_limit_per_model=None, + tpm_limit_per_model=None, + user_tpm_limit=None, + user_rpm_limit=None, + ), + "user_obj": None, + } + + assert _is_allowed_route(**args) + + +@pytest.mark.parametrize( + "user_obj, expected_result", + [ + (None, False), # Case 1: user_obj is None + ( + LiteLLM_UserTable( + user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_id="1234", + user_email="test@test.com", + max_budget=None, + spend=0.0, + ), + True, + ), # Case 2: user_role is PROXY_ADMIN + ( + LiteLLM_UserTable( + user_role="OTHER_ROLE", + user_id="1234", + user_email="test@test.com", + max_budget=None, + spend=0.0, + ), + False, + ), # Case 3: user_role is not PROXY_ADMIN + ], +) +def test_is_user_proxy_admin(user_obj, expected_result): + from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin + + assert _is_user_proxy_admin(user_obj) == expected_result + + +@pytest.mark.parametrize( + "user_obj, expected_role", + [ + (None, None), # Case 1: user_obj is None (should return None) + ( + LiteLLM_UserTable( + user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_id="1234", + user_email="test@test.com", + max_budget=None, + spend=0.0, + ), + LitellmUserRoles.PROXY_ADMIN, + ), # Case 2: user_role is PROXY_ADMIN (should return LitellmUserRoles.PROXY_ADMIN) + ( + LiteLLM_UserTable( + user_role="OTHER_ROLE", + user_id="1234", + user_email="test@test.com", + max_budget=None, + spend=0.0, + ), + LitellmUserRoles.INTERNAL_USER, + ), # Case 3: invalid user_role (should return LitellmUserRoles.INTERNAL_USER) + ], +) +def test_get_user_role(user_obj, expected_role): + from litellm.proxy.auth.user_api_key_auth import _get_user_role + + assert _get_user_role(user_obj) == expected_role + + +@pytest.mark.asyncio +async def test_user_api_key_auth_websocket(): + from litellm.proxy.auth.user_api_key_auth import user_api_key_auth_websocket + + # Prepare a mock WebSocket object + mock_websocket = MagicMock(spec=WebSocket) + mock_websocket.query_params = {"model": "some_model"} + mock_websocket.headers = {"authorization": "Bearer some_api_key"} + + # Mock the return value of `user_api_key_auth` when it's called within the `user_api_key_auth_websocket` function + with patch( + "litellm.proxy.auth.user_api_key_auth.user_api_key_auth", autospec=True + ) as mock_user_api_key_auth: + + # Make the call to the WebSocket function + await user_api_key_auth_websocket(mock_websocket) + + # Assert that `user_api_key_auth` was called with the correct parameters + mock_user_api_key_auth.assert_called_once() + + assert ( + mock_user_api_key_auth.call_args.kwargs["api_key"] == "Bearer some_api_key" + )