diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index ae6d4774be..e9c56d7950 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -899,12 +899,10 @@ async def user_api_key_auth( # noqa: PLR0915 # the validation will occur when checking the team has access to this model pass else: - try: - data = await request.json() - except json.JSONDecodeError: - data = {} # Provide a default value, such as an empty dictionary - model = data.get("model", None) - fallback_models: Optional[List[str]] = data.get("fallbacks", None) + model = request_data.get("model", None) + fallback_models: Optional[List[str]] = request_data.get( + "fallbacks", None + ) if model is not None: await can_key_call_model( diff --git a/litellm/proxy/common_utils/http_parsing_utils.py b/litellm/proxy/common_utils/http_parsing_utils.py index deb259895d..36056d316d 100644 --- a/litellm/proxy/common_utils/http_parsing_utils.py +++ b/litellm/proxy/common_utils/http_parsing_utils.py @@ -21,19 +21,23 @@ async def _read_request_body(request: Optional[Request]) -> Dict: try: if request is None: return {} + _request_headers: dict = _safe_get_request_headers(request=request) + content_type = _request_headers.get("content-type", "") + if "form" in content_type: + return dict(await request.form()) + else: + # Read the request body + body = await request.body() - # Read the request body - body = await request.body() + # Return empty dict if body is empty or None + if not body: + return {} - # Return empty dict if body is empty or None - if not body: - return {} + # Decode the body to a string + body_str = body.decode() - # Decode the body to a string - body_str = body.decode() - - # Attempt JSON parsing (safe for untrusted input) - return json.loads(body_str) + # Attempt JSON parsing (safe for untrusted input) + return json.loads(body_str) except json.JSONDecodeError: # Log detailed information for debugging @@ -48,6 +52,21 @@ async def _read_request_body(request: Optional[Request]) -> Dict: return {} +def _safe_get_request_headers(request: Optional[Request]) -> dict: + """ + [Non-Blocking] Safely get the request headers + """ + try: + if request is None: + return {} + return dict(request.headers) + except Exception as e: + verbose_proxy_logger.exception( + "Unexpected error reading request headers - {}".format(e) + ) + return {} + + def check_file_size_under_limit( request_data: dict, file: UploadFile, diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index c547885c38..a5bc06db62 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -496,3 +496,57 @@ def test_read_request_body(): request.body = return_body result = _read_request_body(request) assert result is not None + + +@pytest.mark.asyncio +async def test_auth_with_form_data_and_model(): + """ + Test user_api_key_auth when: + 1. Request has form data instead of JSON body + 2. Virtual key has a model set + """ + from fastapi import Request + from starlette.datastructures import URL, FormData + from litellm.proxy.proxy_server import ( + hash_token, + user_api_key_cache, + user_api_key_auth, + ) + + # Setup + user_key = "sk-12345678" + + # Create a virtual key with a specific model + valid_token = UserAPIKeyAuth( + token=hash_token(user_key), + models=["gpt-4"], + ) + + # Store the virtual key in cache + user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token) + + setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world") + + # Create request with form data + request = Request( + scope={ + "type": "http", + "method": "POST", + "headers": [(b"content-type", b"application/x-www-form-urlencoded")], + } + ) + request._url = URL(url="/chat/completions") + + # Mock form data + form_data = FormData([("key1", "value1"), ("key2", "value2")]) + + async def return_form_data(): + return form_data + + request.form = return_form_data + + # Test user_api_key_auth with form data request + response = await user_api_key_auth(request=request, api_key="Bearer " + user_key) + assert response.models == ["gpt-4"], "Model from virtual key should be preserved"