Merge pull request #1748 from BerriAI/litellm_custom_auth_fixes

[Fix] user_custom_auth fixes when user passed bad api_keys
This commit is contained in:
Ishaan Jaff 2024-02-01 15:42:39 -08:00 committed by GitHub
commit b6a709fd8d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 157 additions and 4 deletions

View file

@ -233,9 +233,13 @@ def usage_telemetry(
).start()
def _get_bearer_token(api_key: str):
assert api_key.startswith("Bearer ") # ensure Bearer token passed in
api_key = api_key.replace("Bearer ", "") # extract the token
def _get_bearer_token(
api_key: str,
):
if api_key.startswith("Bearer "): # ensure Bearer token passed in
api_key = api_key.replace("Bearer ", "") # extract the token
else:
api_key = ""
return api_key
@ -253,11 +257,14 @@ async def user_api_key_auth(
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
try:
if isinstance(api_key, str):
passed_in_key = api_key
api_key = _get_bearer_token(api_key=api_key)
### USER-DEFINED AUTH FUNCTION ###
if user_custom_auth is not None:
response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ###
if master_key is None:
if isinstance(api_key, str):
@ -265,6 +272,14 @@ async def user_api_key_auth(
else:
return UserAPIKeyAuth()
if api_key is None:
raise Exception("No API Key passed in. api_key is None")
if secrets.compare_digest(api_key, ""):
# missing 'Bearer ' prefix
raise Exception(
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
)
route: str = request.url.path
if route == "/user/auth":
if general_settings.get("allow_user_auth", False) == True: