diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 9f9a88d38a..e9f2fc6d0c 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -1373,15 +1373,27 @@ def _get_user_role( def get_api_key_from_custom_header( request: Request, custom_litellm_key_header_name: str -): +) -> str: + """ + Get API key from custom header + + Args: + request (Request): Request object + custom_litellm_key_header_name (str): Custom header name + + Returns: + Optional[str]: API key + """ + api_key: str = "" # use this as the virtual key passed to litellm proxy custom_litellm_key_header_name = custom_litellm_key_header_name.lower() + _headers = {k.lower(): v for k, v in request.headers.items()} verbose_proxy_logger.debug( "searching for custom_litellm_key_header_name= %s, in headers=%s", custom_litellm_key_header_name, - request.headers, + _headers, ) - custom_api_key = request.headers.get(custom_litellm_key_header_name) + custom_api_key = _headers.get(custom_litellm_key_header_name) if custom_api_key: api_key = _get_bearer_token(api_key=custom_api_key) verbose_proxy_logger.debug( @@ -1390,7 +1402,7 @@ def get_api_key_from_custom_header( ) ) else: - raise ValueError( + verbose_proxy_logger.exception( f"No LiteLLM Virtual Key pass. Please set header={custom_litellm_key_header_name}: Bearer " ) return api_key diff --git a/tests/local_testing/test_user_api_key_auth.py b/tests/local_testing/test_user_api_key_auth.py deleted file mode 100644 index 167809da16..0000000000 --- a/tests/local_testing/test_user_api_key_auth.py +++ /dev/null @@ -1,432 +0,0 @@ -# What is this? -## Unit tests for user_api_key_auth helper functions - -import os -import sys - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -from typing import Dict, List, Optional -from unittest.mock import MagicMock, patch, AsyncMock - -import pytest -from starlette.datastructures import URL - -import litellm -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth - - -class Request: - def __init__(self, client_ip: Optional[str] = None, headers: Optional[dict] = None): - self.client = MagicMock() - self.client.host = client_ip - self.headers: Dict[str, str] = {} - - -@pytest.mark.parametrize( - "allowed_ips, client_ip, expected_result", - [ - (None, "127.0.0.1", True), # No IP restrictions, should be allowed - (["127.0.0.1"], "127.0.0.1", True), # IP in allowed list - (["192.168.1.1"], "127.0.0.1", False), # IP not in allowed list - ([], "127.0.0.1", False), # Empty allowed list, no IP should be allowed - (["192.168.1.1", "10.0.0.1"], "10.0.0.1", True), # IP in allowed list - ( - ["192.168.1.1"], - None, - False, - ), # Request with no client IP should not be allowed - ], -) -def test_check_valid_ip( - allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool -): - from litellm.proxy.auth.auth_utils import _check_valid_ip - - request = Request(client_ip) - - assert _check_valid_ip(allowed_ips, request)[0] == expected_result # type: ignore - - -# test x-forwarder for is used when user has opted in - - -@pytest.mark.parametrize( - "allowed_ips, client_ip, expected_result", - [ - (None, "127.0.0.1", True), # No IP restrictions, should be allowed - (["127.0.0.1"], "127.0.0.1", True), # IP in allowed list - (["192.168.1.1"], "127.0.0.1", False), # IP not in allowed list - ([], "127.0.0.1", False), # Empty allowed list, no IP should be allowed - (["192.168.1.1", "10.0.0.1"], "10.0.0.1", True), # IP in allowed list - ( - ["192.168.1.1"], - None, - False, - ), # Request with no client IP should not be allowed - ], -) -def test_check_valid_ip_sent_with_x_forwarded_for( - allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool -): - from litellm.proxy.auth.auth_utils import _check_valid_ip - - request = Request(client_ip, headers={"X-Forwarded-For": client_ip}) - - assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True)[0] == expected_result # type: ignore - - -@pytest.mark.asyncio -async def test_check_blocked_team(): - """ - cached valid_token obj has team_blocked = true - - cached team obj has team_blocked = false - - assert team is not blocked - """ - import asyncio - import time - - from fastapi import Request - from starlette.datastructures import URL - - from litellm.proxy._types import ( - LiteLLM_TeamTable, - LiteLLM_TeamTableCachedObj, - UserAPIKeyAuth, - ) - from litellm.proxy.auth.user_api_key_auth import user_api_key_auth - from litellm.proxy.proxy_server import hash_token, user_api_key_cache - - _team_id = "1234" - user_key = "sk-12345678" - - valid_token = UserAPIKeyAuth( - team_id=_team_id, - team_blocked=True, - token=hash_token(user_key), - last_refreshed_at=time.time(), - ) - await asyncio.sleep(1) - team_obj = LiteLLM_TeamTableCachedObj( - team_id=_team_id, blocked=False, last_refreshed_at=time.time() - ) - hashed_token = hash_token(user_key) - print(f"STORING TOKEN UNDER KEY={hashed_token}") - user_api_key_cache.set_cache(key=hashed_token, value=valid_token) - user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj) - - setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world") - - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - await user_api_key_auth(request=request, api_key="Bearer " + user_key) - - -@pytest.mark.parametrize( - "user_role, expected_role", - [ - ("app_user", "internal_user"), - ("internal_user", "internal_user"), - ("proxy_admin_viewer", "proxy_admin_viewer"), - ], -) -def test_returned_user_api_key_auth(user_role, expected_role): - from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles - from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj - from datetime import datetime - - new_obj = _return_user_api_key_auth_obj( - user_obj=LiteLLM_UserTable( - user_role=user_role, user_id="", max_budget=None, user_email="" - ), - api_key="hello-world", - parent_otel_span=None, - valid_token_dict={}, - route="/chat/completion", - start_time=datetime.now(), - ) - - assert new_obj.user_role == expected_role - - -@pytest.mark.parametrize("key_ownership", ["user_key", "team_key"]) -@pytest.mark.asyncio -async def test_aaauser_personal_budgets(key_ownership): - """ - Set a personal budget on a user - - - have it only apply when key belongs to user -> raises BudgetExceededError - - if key belongs to team, have key respect team budget -> allows call to go through - """ - import asyncio - import time - - from fastapi import Request - from starlette.datastructures import URL - import litellm - - from litellm.proxy._types import LiteLLM_UserTable, UserAPIKeyAuth - from litellm.proxy.auth.user_api_key_auth import user_api_key_auth - from litellm.proxy.proxy_server import hash_token, user_api_key_cache - - _user_id = "1234" - user_key = "sk-12345678" - - if key_ownership == "user_key": - valid_token = UserAPIKeyAuth( - token=hash_token(user_key), - last_refreshed_at=time.time(), - user_id=_user_id, - spend=20, - ) - elif key_ownership == "team_key": - valid_token = UserAPIKeyAuth( - token=hash_token(user_key), - last_refreshed_at=time.time(), - user_id=_user_id, - team_id="my-special-team", - team_max_budget=100, - spend=20, - ) - - user_obj = LiteLLM_UserTable( - user_id=_user_id, spend=11, max_budget=10, user_email="" - ) - user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token) - user_api_key_cache.set_cache(key="{}".format(_user_id), value=user_obj) - - setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world") - - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - test_user_cache = getattr(litellm.proxy.proxy_server, "user_api_key_cache") - - assert test_user_cache.get_cache(key=hash_token(user_key)) == valid_token - - try: - await user_api_key_auth(request=request, api_key="Bearer " + user_key) - - if key_ownership == "user_key": - pytest.fail("Expected this call to fail. User is over limit.") - except Exception: - if key_ownership == "team_key": - pytest.fail("Expected this call to work. Key is below team budget.") - - -@pytest.mark.asyncio -@pytest.mark.parametrize("prohibited_param", ["api_base", "base_url"]) -async def test_user_api_key_auth_fails_with_prohibited_params(prohibited_param): - """ - Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997 - """ - import json - - from fastapi import Request - - # Setup - user_key = "sk-1234" - - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - - # Create request with prohibited parameter in body - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - async def return_body(): - body = {prohibited_param: "https://custom-api.com"} - return bytes(json.dumps(body), "utf-8") - - request.body = return_body - try: - response = await user_api_key_auth( - request=request, api_key="Bearer " + user_key - ) - except Exception as e: - print("error str=", str(e)) - error_message = str(e.message) - print("error message=", error_message) - assert "is not allowed in request body" in error_message - - -@pytest.mark.asyncio() -@pytest.mark.parametrize( - "route, should_raise_error", - [ - ("/embeddings", False), - ("/chat/completions", True), - ("/completions", True), - ("/models", True), - ("/v1/embeddings", True), - ], -) -async def test_auth_with_allowed_routes(route, should_raise_error): - # Setup - user_key = "sk-1234" - - general_settings = {"allowed_routes": ["/embeddings"]} - from fastapi import Request - - from litellm.proxy import proxy_server - - initial_general_settings = getattr(proxy_server, "general_settings") - - setattr(proxy_server, "master_key", "sk-1234") - setattr(proxy_server, "general_settings", general_settings) - - request = Request(scope={"type": "http"}) - request._url = URL(url=route) - - if should_raise_error: - try: - await user_api_key_auth(request=request, api_key="Bearer " + user_key) - pytest.fail("Expected this call to fail. User is over limit.") - except Exception as e: - print("error str=", str(e.message)) - error_str = str(e.message) - assert "Route" in error_str and "not allowed" in error_str - pass - else: - await user_api_key_auth(request=request, api_key="Bearer " + user_key) - - setattr(proxy_server, "general_settings", initial_general_settings) - - -@pytest.mark.parametrize( - "route, user_role, expected_result", - [ - # Proxy Admin checks - ("/global/spend/logs", "proxy_admin", True), - ("/key/delete", "proxy_admin", False), - ("/key/generate", "proxy_admin", False), - ("/key/regenerate", "proxy_admin", False), - # Internal User checks - allowed routes - ("/global/spend/logs", "internal_user", True), - ("/key/delete", "internal_user", False), - ("/key/generate", "internal_user", False), - ("/key/82akk800000000jjsk/regenerate", "internal_user", False), - # Internal User Viewer - ("/key/generate", "internal_user_viewer", False), - # Internal User checks - disallowed routes - ("/organization/member_add", "internal_user", False), - ], -) -def test_is_ui_route_allowed(route, user_role, expected_result): - from litellm.proxy.auth.user_api_key_auth import _is_ui_route - from litellm.proxy._types import LiteLLM_UserTable - - user_obj = LiteLLM_UserTable( - user_id="3b803c0e-666e-4e99-bd5c-6e534c07e297", - max_budget=None, - spend=0.0, - model_max_budget={}, - model_spend={}, - user_email="my-test-email@1234.com", - models=[], - tpm_limit=None, - rpm_limit=None, - user_role=user_role, - organization_memberships=[], - ) - - received_args: dict = { - "route": route, - "user_obj": user_obj, - } - try: - assert _is_ui_route(**received_args) == expected_result - except Exception as e: - # If expected result is False, we expect an error - if expected_result is False: - pass - else: - raise e - - -@pytest.mark.parametrize( - "route, user_role, expected_result", - [ - ("/key/generate", "internal_user_viewer", False), - ], -) -def test_is_api_route_allowed(route, user_role, expected_result): - from litellm.proxy.auth.user_api_key_auth import _is_api_route_allowed - from litellm.proxy._types import LiteLLM_UserTable - - user_obj = LiteLLM_UserTable( - user_id="3b803c0e-666e-4e99-bd5c-6e534c07e297", - max_budget=None, - spend=0.0, - model_max_budget={}, - model_spend={}, - user_email="my-test-email@1234.com", - models=[], - tpm_limit=None, - rpm_limit=None, - user_role=user_role, - organization_memberships=[], - ) - - received_args: dict = { - "route": route, - "user_obj": user_obj, - } - try: - assert _is_api_route_allowed(**received_args) == expected_result - except Exception as e: - # If expected result is False, we expect an error - if expected_result is False: - pass - else: - raise e - - -from litellm.proxy._types import LitellmUserRoles - - -@pytest.mark.parametrize( - "user_role, auth_user_id, requested_user_id, expected_result", - [ - (LitellmUserRoles.PROXY_ADMIN, "1234", None, True), - (LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, None, "1234", True), - (LitellmUserRoles.TEAM, "1234", None, False), - (LitellmUserRoles.TEAM, None, None, False), - (LitellmUserRoles.TEAM, "1234", "1234", True), - ], -) -def test_allowed_route_inside_route( - user_role, auth_user_id, requested_user_id, expected_result -): - from litellm.proxy.auth.auth_checks import allowed_route_check_inside_route - from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles - - assert ( - allowed_route_check_inside_route( - user_api_key_dict=UserAPIKeyAuth(user_role=user_role, user_id=auth_user_id), - requested_user_id=requested_user_id, - ) - == expected_result - ) - - -def test_read_request_body(): - from litellm.proxy.common_utils.http_parsing_utils import _read_request_body - from fastapi import Request - - payload = "()" * 1000000 - request = Request(scope={"type": "http"}) - - async def return_body(): - return payload - - request.body = return_body - result = _read_request_body(request) - assert result is not None diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py index 2c8ba5b2ab..73a9c4bd58 100644 --- a/tests/proxy_unit_tests/test_key_generate_prisma.py +++ b/tests/proxy_unit_tests/test_key_generate_prisma.py @@ -2811,8 +2811,7 @@ async def test_custom_api_key_header_name(prisma_client): except Exception as e: print("failed with error", e) assert ( - "No LiteLLM Virtual Key pass. Please set header=x-litellm-key: Bearer " - in e.message + "Malformed API Key passed in. Ensure Key has `Bearer ` prefix" in e.message ) pass diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index 9cd52e591c..c547885c38 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -12,9 +12,14 @@ from unittest.mock import MagicMock, patch, AsyncMock import pytest from starlette.datastructures import URL - +from litellm._logging import verbose_proxy_logger +import logging import litellm -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth, UserAPIKeyAuth +from litellm.proxy.auth.user_api_key_auth import ( + user_api_key_auth, + UserAPIKeyAuth, + get_api_key_from_custom_header, +) class Request: @@ -418,3 +423,76 @@ async def test_auth_not_connected_to_db(): print("got valid token", valid_token) assert valid_token.key_name == "failed-to-connect-to-db" assert valid_token.token == "failed-to-connect-to-db" + + +@pytest.mark.parametrize( + "headers, custom_header_name, expected_api_key", + [ + # Test with valid Bearer token + ({"x-custom-api-key": "Bearer sk-12345678"}, "x-custom-api-key", "sk-12345678"), + # Test with raw token (no Bearer prefix) + ({"x-custom-api-key": "Bearer sk-12345678"}, "x-custom-api-key", "sk-12345678"), + # Test with empty header value + ({"x-custom-api-key": ""}, "x-custom-api-key", ""), + # Test with missing header + ({}, "X-Custom-API-Key", ""), + # Test with different header casing + ({"X-CUSTOM-API-KEY": "Bearer sk-12345678"}, "X-Custom-API-Key", "sk-12345678"), + ], +) +def test_get_api_key_from_custom_header(headers, custom_header_name, expected_api_key): + verbose_proxy_logger.setLevel(logging.DEBUG) + + # Mock the Request object + request = MagicMock(spec=Request) + request.headers = headers + + # Call the function and verify it doesn't raise an exception + + api_key = get_api_key_from_custom_header( + request=request, custom_litellm_key_header_name=custom_header_name + ) + assert api_key == expected_api_key + + +from litellm.proxy._types import LitellmUserRoles + + +@pytest.mark.parametrize( + "user_role, auth_user_id, requested_user_id, expected_result", + [ + (LitellmUserRoles.PROXY_ADMIN, "1234", None, True), + (LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, None, "1234", True), + (LitellmUserRoles.TEAM, "1234", None, False), + (LitellmUserRoles.TEAM, None, None, False), + (LitellmUserRoles.TEAM, "1234", "1234", True), + ], +) +def test_allowed_route_inside_route( + user_role, auth_user_id, requested_user_id, expected_result +): + from litellm.proxy.auth.auth_checks import allowed_route_check_inside_route + from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles + + assert ( + allowed_route_check_inside_route( + user_api_key_dict=UserAPIKeyAuth(user_role=user_role, user_id=auth_user_id), + requested_user_id=requested_user_id, + ) + == expected_result + ) + + +def test_read_request_body(): + from litellm.proxy.common_utils.http_parsing_utils import _read_request_body + from fastapi import Request + + payload = "()" * 1000000 + request = Request(scope={"type": "http"}) + + async def return_body(): + return payload + + request.body = return_body + result = _read_request_body(request) + assert result is not None