Merge pull request #1576 from BerriAI/litellm_map_openai_auth_errors

[Feat] Make Proxy Auth Exceptions OpenAI compatible
This commit is contained in:
Ishaan Jaff 2024-01-23 18:32:00 -08:00 committed by GitHub
commit d00548ee5d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 194 additions and 77 deletions

View file

@ -485,11 +485,19 @@ async def user_api_key_auth(
# verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}")
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
)
elif isinstance(e, ProxyException):
raise e
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Authentication Error, {str(e)}",
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_401_UNAUTHORIZED,
)
@ -2216,6 +2224,7 @@ async def generate_key_fn(
- expires: (datetime) Datetime object for when key expires.
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
"""
try:
global user_custom_key_generate
verbose_proxy_logger.debug("entered /key/generate")
@ -2224,7 +2233,9 @@ async def generate_key_fn(
decision = result.get("decision", True)
message = result.get("message", "Authentication Failed - Custom Auth Rule")
if not decision:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=message
)
data_json = data.json() # type: ignore
@ -2232,6 +2243,7 @@ async def generate_key_fn(
if "max_budget" in data_json:
data_json["key_max_budget"] = data_json.pop("max_budget", None)
if "budget_duration" in data_json:
data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
@ -2239,6 +2251,23 @@ async def generate_key_fn(
return GenerateKeyResponse(
key=response["token"], expires=response["expires"], user_id=response["user_id"]
)
except Exception as e:
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
@router.post(
@ -2263,9 +2292,20 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest):
return {"key": key, **non_default_values}
# update based on remaining passed in values
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
@ -2296,9 +2336,20 @@ async def delete_key_fn(data: DeleteKeyRequest):
assert len(keys) == number_deleted_keys
return {"deleted_keys": keys}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
@ -2324,9 +2375,20 @@ async def info_key_fn(
key_info.pop("token")
return {"key": key, "info": key_info}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
@ -2555,9 +2617,20 @@ async def user_info(
key.pop("token", None)
return {"user_id": user_id, "user_info": user_info, "keys": keys}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
@ -2610,10 +2683,19 @@ async def add_new_model(model_params: ModelParams):
except Exception as e:
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
elif isinstance(e, ProxyException):
raise e
else:
raise HTTPException(
status_code=500, detail=f"Internal Server Error: {str(e)}"
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
@ -2703,11 +2785,22 @@ async def delete_model(model_info: ModelInfoDelete):
config = await proxy_config.save_config(new_config=config)
return {"message": "Model deleted successfully"}
except HTTPException as e:
# Re-raise the HTTP exceptions to be handled by FastAPI
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
#### EXPERIMENTAL QUEUING ####
@ -2852,9 +2945,20 @@ async def async_queue_request(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
@ -2924,11 +3028,23 @@ async def update_config(config_info: ConfigYAML):
message="This is a test", level="Low"
)
return {"message": "Config updated successfully"}
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"An error occurred - {str(e)}")
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
@router.get(

View file

@ -13,4 +13,4 @@ async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
return UserAPIKeyAuth(api_key=api_key)
raise Exception
except:
raise Exception
raise Exception("Failed custom auth")

View file

@ -109,8 +109,8 @@ def test_call_with_invalid_key(custom_db_client):
asyncio.run(test())
except Exception as e:
print("Got Exception", e)
print(e.detail)
assert "Authentication Error" in e.detail
print(e.message)
assert "Authentication Error" in e.message
pass
@ -143,7 +143,7 @@ def test_call_with_invalid_model(custom_db_client):
asyncio.run(test())
except Exception as e:
assert (
e.detail
e.message
== "Authentication Error, API Key not allowed to access model. This token can only access models=['mistral']. Tried to access gemini-pro-vision"
)
pass
@ -248,7 +248,7 @@ def test_call_with_user_over_budget(custom_db_client):
asyncio.run(test())
except Exception as e:
error_detail = e.detail
error_detail = e.message
assert "Authentication Error, ExceededBudget:" in error_detail
print(vars(e))
@ -321,7 +321,7 @@ def test_call_with_user_over_budget_stream(custom_db_client):
asyncio.run(test())
except Exception as e:
error_detail = e.detail
error_detail = e.message
assert "Authentication Error, ExceededBudget:" in error_detail
print(vars(e))
@ -392,7 +392,7 @@ def test_call_with_user_key_budget(custom_db_client):
asyncio.run(test())
except Exception as e:
error_detail = e.detail
error_detail = e.message
assert "Authentication Error, ExceededTokenBudget:" in error_detail
print(vars(e))
@ -465,6 +465,6 @@ def test_call_with_key_over_budget_stream(custom_db_client):
asyncio.run(test())
except Exception as e:
error_detail = e.detail
error_detail = e.message
assert "Authentication Error, ExceededTokenBudget:" in error_detail
print(vars(e))

View file

@ -136,8 +136,8 @@ def test_call_with_invalid_key(prisma_client):
asyncio.run(test())
except Exception as e:
print("Got Exception", e)
print(e.detail)
assert "Authentication Error" in e.detail
print(e.message)
assert "Authentication Error" in e.message
pass
@ -171,7 +171,7 @@ def test_call_with_invalid_model(prisma_client):
asyncio.run(test())
except Exception as e:
assert (
e.detail
e.message
== "Authentication Error, API Key not allowed to access model. This token can only access models=['mistral']. Tried to access gemini-pro-vision"
)
pass
@ -274,7 +274,7 @@ def test_call_with_user_over_budget(prisma_client):
asyncio.run(test())
except Exception as e:
error_detail = e.detail
error_detail = e.message
assert "Authentication Error, ExceededBudget:" in error_detail
print(vars(e))
@ -350,7 +350,7 @@ def test_call_with_user_over_budget_stream(prisma_client):
asyncio.run(test())
except Exception as e:
error_detail = e.detail
error_detail = e.message
assert "Authentication Error, ExceededBudget:" in error_detail
print(vars(e))
@ -414,8 +414,8 @@ def test_generate_and_call_with_expired_key(prisma_client):
asyncio.run(test())
except Exception as e:
print("Got Exception", e)
print(e.detail)
assert "Authentication Error" in e.detail
print(e.message)
assert "Authentication Error" in e.message
pass
@ -486,8 +486,8 @@ def test_delete_key_auth(prisma_client):
asyncio.run(test())
except Exception as e:
print("Got Exception", e)
print(e.detail)
assert "Authentication Error" in e.detail
print(e.message)
assert "Authentication Error" in e.message
pass
@ -599,7 +599,7 @@ def test_generate_and_update_key(prisma_client):
asyncio.run(test())
except Exception as e:
print("Got Exception", e)
print(e.detail)
print(e.message)
pytest.fail(f"An exception occurred - {str(e)}")
@ -665,11 +665,11 @@ def test_key_generate_with_custom_auth(prisma_client):
except Exception as e:
# this should fail
print("Got Exception", e)
print(e.detail)
print(e.message)
print("First request failed!. This is expected")
assert (
"This violates LiteLLM Proxy Rules. No team id provided."
in e.detail
in e.message
)
request_2 = GenerateKeyRequest(
@ -683,7 +683,7 @@ def test_key_generate_with_custom_auth(prisma_client):
asyncio.run(test())
except Exception as e:
print("Got Exception", e)
print(e.detail)
print(e.message)
pytest.fail(f"An exception occurred - {str(e)}")
@ -752,7 +752,7 @@ def test_call_with_key_over_budget(prisma_client):
asyncio.run(test())
except Exception as e:
error_detail = e.detail
error_detail = e.message
assert "Authentication Error, ExceededTokenBudget:" in error_detail
print(vars(e))
@ -827,6 +827,6 @@ def test_call_with_key_over_budget_stream(prisma_client):
pytest.fail(f"This should have failed!. They key crossed it's budget")
except Exception as e:
error_detail = e.detail
error_detail = e.message
assert "Authentication Error, ExceededTokenBudget:" in error_detail
print(vars(e))

View file

@ -58,9 +58,10 @@ def test_custom_auth(client):
headers = {"Authorization": f"Bearer {token}"}
response = client.post("/chat/completions", json=test_data, headers=headers)
print(f"response: {response.text}")
assert response.status_code == 401
result = response.json()
print(f"Received response: {result}")
pytest.fail("LiteLLM Proxy test failed. This request should have been rejected")
except Exception as e:
pytest.fail("LiteLLM Proxy test failed. Exception", e)
print(vars(e))
print("got an exception")
assert e.code == 401
assert e.message == "Authentication Error, Failed custom auth"
pass