(fix) add literals to _get_bearer_token

This commit is contained in:
ishaan-jaff 2024-02-01 14:14:13 -08:00
parent 1b4b5185bb
commit ce37ead178

View file

@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
import threading, ast import threading, ast
import shutil, random, traceback, requests import shutil, random, traceback, requests
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Optional, List from typing import Optional, List, Literal
import secrets, subprocess import secrets, subprocess
import hashlib, uuid import hashlib, uuid
import warnings import warnings
@ -141,6 +141,11 @@ class ProxyException(Exception):
self.code = code self.code = code
# Literals - used by _get_bearer_token
NO_API_KEY = "no-api-key"
MISSING_BEARER = "missing-bearer"
@app.exception_handler(ProxyException) @app.exception_handler(ProxyException)
async def openai_exception_handler(request: Request, exc: ProxyException): async def openai_exception_handler(request: Request, exc: ProxyException):
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
@ -233,11 +238,15 @@ def usage_telemetry(
).start() ).start()
def _get_bearer_token(api_key: str): def _get_bearer_token(
api_key: str,
) -> Union[str, Literal["no-api-key", "missing-bearer"]]:
if api_key == "":
return NO_API_KEY
if api_key.startswith("Bearer "): # ensure Bearer token passed in if api_key.startswith("Bearer "): # ensure Bearer token passed in
api_key = api_key.replace("Bearer ", "") # extract the token api_key = api_key.replace("Bearer ", "") # extract the token
else: else:
api_key = "" api_key = MISSING_BEARER
return api_key return api_key
@ -258,17 +267,11 @@ async def user_api_key_auth(
passed_in_key = api_key passed_in_key = api_key
api_key = _get_bearer_token(api_key=api_key) api_key = _get_bearer_token(api_key=api_key)
### USER-DEFINED AUTH FUNCTION -> This should always be run first if a user has defined it ### ### USER-DEFINED AUTH FUNCTION ###
if user_custom_auth is not None: if user_custom_auth is not None:
response = await user_custom_auth(request=request, api_key=api_key) response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response) return UserAPIKeyAuth.model_validate(response)
if api_key == "":
# missing 'Bearer ' prefix
raise Exception(
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
)
### LITELLM-DEFINED AUTH FUNCTION ### ### LITELLM-DEFINED AUTH FUNCTION ###
if master_key is None: if master_key is None:
if isinstance(api_key, str): if isinstance(api_key, str):
@ -276,6 +279,17 @@ async def user_api_key_auth(
else: else:
return UserAPIKeyAuth() return UserAPIKeyAuth()
if api_key is None:
raise Exception("No API Key passed in")
if secrets.compare_digest(api_key, MISSING_BEARER):
# missing 'Bearer ' prefix
raise Exception(
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
)
elif secrets.compare_digest(api_key, NO_API_KEY):
# no api key passed in
raise Exception("No API Key passed in. Passed in: {passed_in_key}")
route: str = request.url.path route: str = request.url.path
if route == "/user/auth": if route == "/user/auth":
if general_settings.get("allow_user_auth", False) == True: if general_settings.get("allow_user_auth", False) == True: