(fix) add literals to _get_bearer_token

This commit is contained in:
ishaan-jaff 2024-02-01 14:14:13 -08:00
parent 1b4b5185bb
commit ce37ead178

View file

@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta, timezone
from typing import Optional, List
from typing import Optional, List, Literal
import secrets, subprocess
import hashlib, uuid
import warnings
@ -141,6 +141,11 @@ class ProxyException(Exception):
self.code = code
# Literals - used by _get_bearer_token
NO_API_KEY = "no-api-key"
MISSING_BEARER = "missing-bearer"
@app.exception_handler(ProxyException)
async def openai_exception_handler(request: Request, exc: ProxyException):
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
@ -233,11 +238,15 @@ def usage_telemetry(
).start()
def _get_bearer_token(api_key: str):
def _get_bearer_token(
api_key: str,
) -> Union[str, Literal["no-api-key", "missing-bearer"]]:
if api_key == "":
return NO_API_KEY
if api_key.startswith("Bearer "): # ensure Bearer token passed in
api_key = api_key.replace("Bearer ", "") # extract the token
else:
api_key = ""
api_key = MISSING_BEARER
return api_key
@ -258,17 +267,11 @@ async def user_api_key_auth(
passed_in_key = api_key
api_key = _get_bearer_token(api_key=api_key)
### USER-DEFINED AUTH FUNCTION -> This should always be run first if a user has defined it ###
### USER-DEFINED AUTH FUNCTION ###
if user_custom_auth is not None:
response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response)
if api_key == "":
# missing 'Bearer ' prefix
raise Exception(
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
)
### LITELLM-DEFINED AUTH FUNCTION ###
if master_key is None:
if isinstance(api_key, str):
@ -276,6 +279,17 @@ async def user_api_key_auth(
else:
return UserAPIKeyAuth()
if api_key is None:
raise Exception("No API Key passed in")
if secrets.compare_digest(api_key, MISSING_BEARER):
# missing 'Bearer ' prefix
raise Exception(
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
)
elif secrets.compare_digest(api_key, NO_API_KEY):
# no api key passed in
raise Exception("No API Key passed in. Passed in: {passed_in_key}")
route: str = request.url.path
if route == "/user/auth":
if general_settings.get("allow_user_auth", False) == True: