forked from phoenix/litellm-mirror
Merge pull request #1576 from BerriAI/litellm_map_openai_auth_errors
[Feat] Make Proxy Auth Exceptions OpenAI compatible
This commit is contained in:
commit
d00548ee5d
5 changed files with 194 additions and 77 deletions
|
@ -485,12 +485,20 @@ async def user_api_key_auth(
|
|||
# verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}")
|
||||
traceback.print_exc()
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Authentication Error, {str(e)}",
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message="Authentication Error, " + str(e),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
|
||||
def prisma_setup(database_url: Optional[str]):
|
||||
|
@ -2216,30 +2224,51 @@ async def generate_key_fn(
|
|||
- expires: (datetime) Datetime object for when key expires.
|
||||
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
|
||||
"""
|
||||
global user_custom_key_generate
|
||||
verbose_proxy_logger.debug("entered /key/generate")
|
||||
try:
|
||||
global user_custom_key_generate
|
||||
verbose_proxy_logger.debug("entered /key/generate")
|
||||
|
||||
if user_custom_key_generate is not None:
|
||||
result = await user_custom_key_generate(data)
|
||||
decision = result.get("decision", True)
|
||||
message = result.get("message", "Authentication Failed - Custom Auth Rule")
|
||||
if not decision:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message)
|
||||
if user_custom_key_generate is not None:
|
||||
result = await user_custom_key_generate(data)
|
||||
decision = result.get("decision", True)
|
||||
message = result.get("message", "Authentication Failed - Custom Auth Rule")
|
||||
if not decision:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=message
|
||||
)
|
||||
|
||||
data_json = data.json() # type: ignore
|
||||
data_json = data.json() # type: ignore
|
||||
|
||||
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
|
||||
if "max_budget" in data_json:
|
||||
data_json["key_max_budget"] = data_json.pop("max_budget", None)
|
||||
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
|
||||
if "max_budget" in data_json:
|
||||
data_json["key_max_budget"] = data_json.pop("max_budget", None)
|
||||
|
||||
if "budget_duration" in data_json:
|
||||
data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
|
||||
|
||||
response = await generate_key_helper_fn(**data_json)
|
||||
return GenerateKeyResponse(
|
||||
key=response["token"], expires=response["expires"], user_id=response["user_id"]
|
||||
)
|
||||
if "budget_duration" in data_json:
|
||||
data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
|
||||
|
||||
response = await generate_key_helper_fn(**data_json)
|
||||
return GenerateKeyResponse(
|
||||
key=response["token"], expires=response["expires"], user_id=response["user_id"]
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message="Authentication Error, " + str(e),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
|
||||
|
@ -2263,9 +2292,20 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest):
|
|||
return {"key": key, **non_default_values}
|
||||
# update based on remaining passed in values
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": str(e)},
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message="Authentication Error, " + str(e),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
|
@ -2296,9 +2336,20 @@ async def delete_key_fn(data: DeleteKeyRequest):
|
|||
assert len(keys) == number_deleted_keys
|
||||
return {"deleted_keys": keys}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": str(e)},
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message="Authentication Error, " + str(e),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
|
@ -2324,9 +2375,20 @@ async def info_key_fn(
|
|||
key_info.pop("token")
|
||||
return {"key": key, "info": key_info}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": str(e)},
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message="Authentication Error, " + str(e),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
|
@ -2555,9 +2617,20 @@ async def user_info(
|
|||
key.pop("token", None)
|
||||
return {"user_id": user_id, "user_info": user_info, "keys": keys}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": str(e)},
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message="Authentication Error, " + str(e),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
|
@ -2610,11 +2683,20 @@ async def add_new_model(model_params: ModelParams):
|
|||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Internal Server Error: {str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message="Authentication Error, " + str(e),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
|
||||
|
@ -2703,11 +2785,22 @@ async def delete_model(model_info: ModelInfoDelete):
|
|||
config = await proxy_config.save_config(new_config=config)
|
||||
return {"message": "Model deleted successfully"}
|
||||
|
||||
except HTTPException as e:
|
||||
# Re-raise the HTTP exceptions to be handled by FastAPI
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message="Authentication Error, " + str(e),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
#### EXPERIMENTAL QUEUING ####
|
||||
|
@ -2852,9 +2945,20 @@ async def async_queue_request(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": str(e)},
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message="Authentication Error, " + str(e),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
|
@ -2924,11 +3028,23 @@ async def update_config(config_info: ConfigYAML):
|
|||
message="This is a test", level="Low"
|
||||
)
|
||||
return {"message": "Config updated successfully"}
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"An error occurred - {str(e)}")
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message="Authentication Error, " + str(e),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
|
|
|
@ -13,4 +13,4 @@ async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
|||
return UserAPIKeyAuth(api_key=api_key)
|
||||
raise Exception
|
||||
except:
|
||||
raise Exception
|
||||
raise Exception("Failed custom auth")
|
||||
|
|
|
@ -109,8 +109,8 @@ def test_call_with_invalid_key(custom_db_client):
|
|||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
print("Got Exception", e)
|
||||
print(e.detail)
|
||||
assert "Authentication Error" in e.detail
|
||||
print(e.message)
|
||||
assert "Authentication Error" in e.message
|
||||
pass
|
||||
|
||||
|
||||
|
@ -143,7 +143,7 @@ def test_call_with_invalid_model(custom_db_client):
|
|||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
assert (
|
||||
e.detail
|
||||
e.message
|
||||
== "Authentication Error, API Key not allowed to access model. This token can only access models=['mistral']. Tried to access gemini-pro-vision"
|
||||
)
|
||||
pass
|
||||
|
@ -248,7 +248,7 @@ def test_call_with_user_over_budget(custom_db_client):
|
|||
|
||||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
error_detail = e.detail
|
||||
error_detail = e.message
|
||||
assert "Authentication Error, ExceededBudget:" in error_detail
|
||||
print(vars(e))
|
||||
|
||||
|
@ -321,7 +321,7 @@ def test_call_with_user_over_budget_stream(custom_db_client):
|
|||
|
||||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
error_detail = e.detail
|
||||
error_detail = e.message
|
||||
assert "Authentication Error, ExceededBudget:" in error_detail
|
||||
print(vars(e))
|
||||
|
||||
|
@ -392,7 +392,7 @@ def test_call_with_user_key_budget(custom_db_client):
|
|||
|
||||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
error_detail = e.detail
|
||||
error_detail = e.message
|
||||
assert "Authentication Error, ExceededTokenBudget:" in error_detail
|
||||
print(vars(e))
|
||||
|
||||
|
@ -465,6 +465,6 @@ def test_call_with_key_over_budget_stream(custom_db_client):
|
|||
|
||||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
error_detail = e.detail
|
||||
error_detail = e.message
|
||||
assert "Authentication Error, ExceededTokenBudget:" in error_detail
|
||||
print(vars(e))
|
||||
|
|
|
@ -136,8 +136,8 @@ def test_call_with_invalid_key(prisma_client):
|
|||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
print("Got Exception", e)
|
||||
print(e.detail)
|
||||
assert "Authentication Error" in e.detail
|
||||
print(e.message)
|
||||
assert "Authentication Error" in e.message
|
||||
pass
|
||||
|
||||
|
||||
|
@ -171,7 +171,7 @@ def test_call_with_invalid_model(prisma_client):
|
|||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
assert (
|
||||
e.detail
|
||||
e.message
|
||||
== "Authentication Error, API Key not allowed to access model. This token can only access models=['mistral']. Tried to access gemini-pro-vision"
|
||||
)
|
||||
pass
|
||||
|
@ -274,7 +274,7 @@ def test_call_with_user_over_budget(prisma_client):
|
|||
|
||||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
error_detail = e.detail
|
||||
error_detail = e.message
|
||||
assert "Authentication Error, ExceededBudget:" in error_detail
|
||||
print(vars(e))
|
||||
|
||||
|
@ -350,7 +350,7 @@ def test_call_with_user_over_budget_stream(prisma_client):
|
|||
|
||||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
error_detail = e.detail
|
||||
error_detail = e.message
|
||||
assert "Authentication Error, ExceededBudget:" in error_detail
|
||||
print(vars(e))
|
||||
|
||||
|
@ -414,8 +414,8 @@ def test_generate_and_call_with_expired_key(prisma_client):
|
|||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
print("Got Exception", e)
|
||||
print(e.detail)
|
||||
assert "Authentication Error" in e.detail
|
||||
print(e.message)
|
||||
assert "Authentication Error" in e.message
|
||||
pass
|
||||
|
||||
|
||||
|
@ -486,8 +486,8 @@ def test_delete_key_auth(prisma_client):
|
|||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
print("Got Exception", e)
|
||||
print(e.detail)
|
||||
assert "Authentication Error" in e.detail
|
||||
print(e.message)
|
||||
assert "Authentication Error" in e.message
|
||||
pass
|
||||
|
||||
|
||||
|
@ -599,7 +599,7 @@ def test_generate_and_update_key(prisma_client):
|
|||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
print("Got Exception", e)
|
||||
print(e.detail)
|
||||
print(e.message)
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
|
@ -665,11 +665,11 @@ def test_key_generate_with_custom_auth(prisma_client):
|
|||
except Exception as e:
|
||||
# this should fail
|
||||
print("Got Exception", e)
|
||||
print(e.detail)
|
||||
print(e.message)
|
||||
print("First request failed!. This is expected")
|
||||
assert (
|
||||
"This violates LiteLLM Proxy Rules. No team id provided."
|
||||
in e.detail
|
||||
in e.message
|
||||
)
|
||||
|
||||
request_2 = GenerateKeyRequest(
|
||||
|
@ -683,7 +683,7 @@ def test_key_generate_with_custom_auth(prisma_client):
|
|||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
print("Got Exception", e)
|
||||
print(e.detail)
|
||||
print(e.message)
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
|
@ -752,7 +752,7 @@ def test_call_with_key_over_budget(prisma_client):
|
|||
|
||||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
error_detail = e.detail
|
||||
error_detail = e.message
|
||||
assert "Authentication Error, ExceededTokenBudget:" in error_detail
|
||||
print(vars(e))
|
||||
|
||||
|
@ -827,6 +827,6 @@ def test_call_with_key_over_budget_stream(prisma_client):
|
|||
pytest.fail(f"This should have failed!. They key crossed it's budget")
|
||||
|
||||
except Exception as e:
|
||||
error_detail = e.detail
|
||||
error_detail = e.message
|
||||
assert "Authentication Error, ExceededTokenBudget:" in error_detail
|
||||
print(vars(e))
|
||||
|
|
|
@ -58,9 +58,10 @@ def test_custom_auth(client):
|
|||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
response = client.post("/chat/completions", json=test_data, headers=headers)
|
||||
print(f"response: {response.text}")
|
||||
assert response.status_code == 401
|
||||
result = response.json()
|
||||
print(f"Received response: {result}")
|
||||
pytest.fail("LiteLLM Proxy test failed. This request should have been rejected")
|
||||
except Exception as e:
|
||||
pytest.fail("LiteLLM Proxy test failed. Exception", e)
|
||||
print(vars(e))
|
||||
print("got an exception")
|
||||
assert e.code == 401
|
||||
assert e.message == "Authentication Error, Failed custom auth"
|
||||
pass
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue