From 37dd1114f2e86699cad606131d86b616ffb2133e Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 23 Jan 2024 17:30:01 -0800 Subject: [PATCH 1/4] (test) update prisma test --- litellm/tests/test_key_generate_prisma.py | 30 +++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index f50df4570..c62d41cca 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -136,8 +136,8 @@ def test_call_with_invalid_key(prisma_client): asyncio.run(test()) except Exception as e: print("Got Exception", e) - print(e.detail) - assert "Authentication Error" in e.detail + print(e.message) + assert "Authentication Error" in e.message pass @@ -171,7 +171,7 @@ def test_call_with_invalid_model(prisma_client): asyncio.run(test()) except Exception as e: assert ( - e.detail + e.message == "Authentication Error, API Key not allowed to access model. This token can only access models=['mistral']. Tried to access gemini-pro-vision" ) pass @@ -274,7 +274,7 @@ def test_call_with_user_over_budget(prisma_client): asyncio.run(test()) except Exception as e: - error_detail = e.detail + error_detail = e.message assert "Authentication Error, ExceededBudget:" in error_detail print(vars(e)) @@ -350,7 +350,7 @@ def test_call_with_user_over_budget_stream(prisma_client): asyncio.run(test()) except Exception as e: - error_detail = e.detail + error_detail = e.message assert "Authentication Error, ExceededBudget:" in error_detail print(vars(e)) @@ -414,8 +414,8 @@ def test_generate_and_call_with_expired_key(prisma_client): asyncio.run(test()) except Exception as e: print("Got Exception", e) - print(e.detail) - assert "Authentication Error" in e.detail + print(e.message) + assert "Authentication Error" in e.message pass @@ -486,8 +486,8 @@ def test_delete_key_auth(prisma_client): asyncio.run(test()) except Exception as e: print("Got Exception", e) - print(e.detail) - assert "Authentication Error" in e.detail + print(e.message) + assert "Authentication Error" in e.message pass @@ -599,7 +599,7 @@ def test_generate_and_update_key(prisma_client): asyncio.run(test()) except Exception as e: print("Got Exception", e) - print(e.detail) + print(e.message) pytest.fail(f"An exception occurred - {str(e)}") @@ -665,11 +665,11 @@ def test_key_generate_with_custom_auth(prisma_client): except Exception as e: # this should fail print("Got Exception", e) - print(e.detail) + print(e.message) print("First request failed!. This is expected") assert ( "This violates LiteLLM Proxy Rules. No team id provided." - in e.detail + in e.message ) request_2 = GenerateKeyRequest( @@ -683,7 +683,7 @@ def test_key_generate_with_custom_auth(prisma_client): asyncio.run(test()) except Exception as e: print("Got Exception", e) - print(e.detail) + print(e.message) pytest.fail(f"An exception occurred - {str(e)}") @@ -752,7 +752,7 @@ def test_call_with_key_over_budget(prisma_client): asyncio.run(test()) except Exception as e: - error_detail = e.detail + error_detail = e.message assert "Authentication Error, ExceededTokenBudget:" in error_detail print(vars(e)) @@ -827,6 +827,6 @@ def test_call_with_key_over_budget_stream(prisma_client): pytest.fail(f"This should have failed!. They key crossed it's budget") except Exception as e: - error_detail = e.detail + error_detail = e.message assert "Authentication Error, ExceededTokenBudget:" in error_detail print(vars(e)) From feb367bbb974a3ac4797125102a45bafb509e099 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 23 Jan 2024 17:32:47 -0800 Subject: [PATCH 2/4] (feat) all endpoints raise OpenAI compatible exceptions --- litellm/proxy/proxy_server.py | 210 ++++++++++++++++++++++++++-------- 1 file changed, 163 insertions(+), 47 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a1790f49c..9500b663c 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -483,12 +483,20 @@ async def user_api_key_auth( # verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}") traceback.print_exc() if isinstance(e, HTTPException): - raise e - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=f"Authentication Error, {str(e)}", + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED), ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_401_UNAUTHORIZED, + ) def prisma_setup(database_url: Optional[str]): @@ -2194,26 +2202,47 @@ async def generate_key_fn( - expires: (datetime) Datetime object for when key expires. - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. """ - global user_custom_key_generate - verbose_proxy_logger.debug("entered /key/generate") + try: + global user_custom_key_generate + verbose_proxy_logger.debug("entered /key/generate") - if user_custom_key_generate is not None: - result = await user_custom_key_generate(data) - decision = result.get("decision", True) - message = result.get("message", "Authentication Failed - Custom Auth Rule") - if not decision: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message) + if user_custom_key_generate is not None: + result = await user_custom_key_generate(data) + decision = result.get("decision", True) + message = result.get("message", "Authentication Failed - Custom Auth Rule") + if not decision: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=message + ) - data_json = data.json() # type: ignore + data_json = data.json() # type: ignore - # if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users - if "max_budget" in data_json: - data_json["key_max_budget"] = data_json.pop("max_budget", None) + # if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users + if "max_budget" in data_json: + data_json["key_max_budget"] = data_json.pop("max_budget", None) - response = await generate_key_helper_fn(**data_json) - return GenerateKeyResponse( - key=response["token"], expires=response["expires"], user_id=response["user_id"] - ) + response = await generate_key_helper_fn(**data_json) + return GenerateKeyResponse( + key=response["token"], + expires=response["expires"], + user_id=response["user_id"], + ) + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) @router.post( @@ -2238,9 +2267,20 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest): return {"key": key, **non_default_values} # update based on remaining passed in values except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": str(e)}, + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, ) @@ -2271,9 +2311,20 @@ async def delete_key_fn(data: DeleteKeyRequest): assert len(keys) == number_deleted_keys return {"deleted_keys": keys} except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": str(e)}, + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, ) @@ -2299,9 +2350,20 @@ async def info_key_fn( key_info.pop("token") return {"key": key, "info": key_info} except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": str(e)}, + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, ) @@ -2442,9 +2504,20 @@ async def user_info( key.pop("token", None) return {"user_id": user_id, "user_info": user_info, "keys": keys} except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": str(e)}, + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, ) @@ -2497,11 +2570,20 @@ async def add_new_model(model_params: ModelParams): except Exception as e: traceback.print_exc() if isinstance(e, HTTPException): - raise e - else: - raise HTTPException( - status_code=500, detail=f"Internal Server Error: {str(e)}" + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) #### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info @@ -2590,11 +2672,22 @@ async def delete_model(model_info: ModelInfoDelete): config = await proxy_config.save_config(new_config=config) return {"message": "Model deleted successfully"} - except HTTPException as e: - # Re-raise the HTTP exceptions to be handled by FastAPI - raise except Exception as e: - raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) #### EXPERIMENTAL QUEUING #### @@ -2739,9 +2832,20 @@ async def async_queue_request( await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e ) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": str(e)}, + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, ) @@ -2811,11 +2915,23 @@ async def update_config(config_info: ConfigYAML): message="This is a test", level="Low" ) return {"message": "Config updated successfully"} - except HTTPException as e: - raise e except Exception as e: traceback.print_exc() - raise HTTPException(status_code=500, detail=f"An error occurred - {str(e)}") + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) @router.get( From d2675a1ff3ea9dc70cb9dbcb02257a48148e3106 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 23 Jan 2024 17:34:18 -0800 Subject: [PATCH 3/4] (test) switch dynamodb test to use new exceptions --- litellm/tests/test_key_generate_dynamodb.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/litellm/tests/test_key_generate_dynamodb.py b/litellm/tests/test_key_generate_dynamodb.py index 8b3077909..be55595fa 100644 --- a/litellm/tests/test_key_generate_dynamodb.py +++ b/litellm/tests/test_key_generate_dynamodb.py @@ -109,8 +109,8 @@ def test_call_with_invalid_key(custom_db_client): asyncio.run(test()) except Exception as e: print("Got Exception", e) - print(e.detail) - assert "Authentication Error" in e.detail + print(e.message) + assert "Authentication Error" in e.message pass @@ -143,7 +143,7 @@ def test_call_with_invalid_model(custom_db_client): asyncio.run(test()) except Exception as e: assert ( - e.detail + e.message == "Authentication Error, API Key not allowed to access model. This token can only access models=['mistral']. Tried to access gemini-pro-vision" ) pass @@ -248,7 +248,7 @@ def test_call_with_user_over_budget(custom_db_client): asyncio.run(test()) except Exception as e: - error_detail = e.detail + error_detail = e.message assert "Authentication Error, ExceededBudget:" in error_detail print(vars(e)) @@ -321,7 +321,7 @@ def test_call_with_user_over_budget_stream(custom_db_client): asyncio.run(test()) except Exception as e: - error_detail = e.detail + error_detail = e.message assert "Authentication Error, ExceededBudget:" in error_detail print(vars(e)) @@ -392,7 +392,7 @@ def test_call_with_user_key_budget(custom_db_client): asyncio.run(test()) except Exception as e: - error_detail = e.detail + error_detail = e.message assert "Authentication Error, ExceededTokenBudget:" in error_detail print(vars(e)) @@ -465,6 +465,6 @@ def test_call_with_key_over_budget_stream(custom_db_client): asyncio.run(test()) except Exception as e: - error_detail = e.detail + error_detail = e.message assert "Authentication Error, ExceededTokenBudget:" in error_detail print(vars(e)) From b3ce0ac728a7cd28d43b518120dbf84f2e7ccd4b Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 23 Jan 2024 17:36:13 -0800 Subject: [PATCH 4/4] (test) proxy_custom_auth use new exceptions --- litellm/tests/test_configs/custom_auth.py | 2 +- litellm/tests/test_proxy_custom_auth.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/litellm/tests/test_configs/custom_auth.py b/litellm/tests/test_configs/custom_auth.py index f3825038e..e4747ee53 100644 --- a/litellm/tests/test_configs/custom_auth.py +++ b/litellm/tests/test_configs/custom_auth.py @@ -13,4 +13,4 @@ async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: return UserAPIKeyAuth(api_key=api_key) raise Exception except: - raise Exception + raise Exception("Failed custom auth") diff --git a/litellm/tests/test_proxy_custom_auth.py b/litellm/tests/test_proxy_custom_auth.py index ceb3d1c93..b6b833e17 100644 --- a/litellm/tests/test_proxy_custom_auth.py +++ b/litellm/tests/test_proxy_custom_auth.py @@ -58,9 +58,10 @@ def test_custom_auth(client): headers = {"Authorization": f"Bearer {token}"} response = client.post("/chat/completions", json=test_data, headers=headers) - print(f"response: {response.text}") - assert response.status_code == 401 - result = response.json() - print(f"Received response: {result}") + pytest.fail("LiteLLM Proxy test failed. This request should have been rejected") except Exception as e: - pytest.fail("LiteLLM Proxy test failed. Exception", e) + print(vars(e)) + print("got an exception") + assert e.code == 401 + assert e.message == "Authentication Error, Failed custom auth" + pass