forked from phoenix/litellm-mirror
(fix) add literals to _get_bearer_token
This commit is contained in:
parent
1b4b5185bb
commit
ce37ead178
1 changed files with 24 additions and 10 deletions
|
@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
|
|||
import threading, ast
|
||||
import shutil, random, traceback, requests
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Literal
|
||||
import secrets, subprocess
|
||||
import hashlib, uuid
|
||||
import warnings
|
||||
|
@ -141,6 +141,11 @@ class ProxyException(Exception):
|
|||
self.code = code
|
||||
|
||||
|
||||
# Literals - used by _get_bearer_token
|
||||
NO_API_KEY = "no-api-key"
|
||||
MISSING_BEARER = "missing-bearer"
|
||||
|
||||
|
||||
@app.exception_handler(ProxyException)
|
||||
async def openai_exception_handler(request: Request, exc: ProxyException):
|
||||
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
|
||||
|
@ -233,11 +238,15 @@ def usage_telemetry(
|
|||
).start()
|
||||
|
||||
|
||||
def _get_bearer_token(api_key: str):
|
||||
def _get_bearer_token(
|
||||
api_key: str,
|
||||
) -> Union[str, Literal["no-api-key", "missing-bearer"]]:
|
||||
if api_key == "":
|
||||
return NO_API_KEY
|
||||
if api_key.startswith("Bearer "): # ensure Bearer token passed in
|
||||
api_key = api_key.replace("Bearer ", "") # extract the token
|
||||
else:
|
||||
api_key = ""
|
||||
api_key = MISSING_BEARER
|
||||
return api_key
|
||||
|
||||
|
||||
|
@ -258,17 +267,11 @@ async def user_api_key_auth(
|
|||
passed_in_key = api_key
|
||||
api_key = _get_bearer_token(api_key=api_key)
|
||||
|
||||
### USER-DEFINED AUTH FUNCTION -> This should always be run first if a user has defined it ###
|
||||
### USER-DEFINED AUTH FUNCTION ###
|
||||
if user_custom_auth is not None:
|
||||
response = await user_custom_auth(request=request, api_key=api_key)
|
||||
return UserAPIKeyAuth.model_validate(response)
|
||||
|
||||
if api_key == "":
|
||||
# missing 'Bearer ' prefix
|
||||
raise Exception(
|
||||
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
|
||||
)
|
||||
|
||||
### LITELLM-DEFINED AUTH FUNCTION ###
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
|
@ -276,6 +279,17 @@ async def user_api_key_auth(
|
|||
else:
|
||||
return UserAPIKeyAuth()
|
||||
|
||||
if api_key is None:
|
||||
raise Exception("No API Key passed in")
|
||||
if secrets.compare_digest(api_key, MISSING_BEARER):
|
||||
# missing 'Bearer ' prefix
|
||||
raise Exception(
|
||||
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
|
||||
)
|
||||
elif secrets.compare_digest(api_key, NO_API_KEY):
|
||||
# no api key passed in
|
||||
raise Exception("No API Key passed in. Passed in: {passed_in_key}")
|
||||
|
||||
route: str = request.url.path
|
||||
if route == "/user/auth":
|
||||
if general_settings.get("allow_user_auth", False) == True:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue