forked from phoenix/litellm-mirror
(fix) add literals to _get_bearer_token
This commit is contained in:
parent
1b4b5185bb
commit
ce37ead178
1 changed files with 24 additions and 10 deletions
|
@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
|
||||||
import threading, ast
|
import threading, ast
|
||||||
import shutil, random, traceback, requests
|
import shutil, random, traceback, requests
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Literal
|
||||||
import secrets, subprocess
|
import secrets, subprocess
|
||||||
import hashlib, uuid
|
import hashlib, uuid
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -141,6 +141,11 @@ class ProxyException(Exception):
|
||||||
self.code = code
|
self.code = code
|
||||||
|
|
||||||
|
|
||||||
|
# Literals - used by _get_bearer_token
|
||||||
|
NO_API_KEY = "no-api-key"
|
||||||
|
MISSING_BEARER = "missing-bearer"
|
||||||
|
|
||||||
|
|
||||||
@app.exception_handler(ProxyException)
|
@app.exception_handler(ProxyException)
|
||||||
async def openai_exception_handler(request: Request, exc: ProxyException):
|
async def openai_exception_handler(request: Request, exc: ProxyException):
|
||||||
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
|
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
|
||||||
|
@ -233,11 +238,15 @@ def usage_telemetry(
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
|
|
||||||
def _get_bearer_token(api_key: str):
|
def _get_bearer_token(
|
||||||
|
api_key: str,
|
||||||
|
) -> Union[str, Literal["no-api-key", "missing-bearer"]]:
|
||||||
|
if api_key == "":
|
||||||
|
return NO_API_KEY
|
||||||
if api_key.startswith("Bearer "): # ensure Bearer token passed in
|
if api_key.startswith("Bearer "): # ensure Bearer token passed in
|
||||||
api_key = api_key.replace("Bearer ", "") # extract the token
|
api_key = api_key.replace("Bearer ", "") # extract the token
|
||||||
else:
|
else:
|
||||||
api_key = ""
|
api_key = MISSING_BEARER
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
@ -258,17 +267,11 @@ async def user_api_key_auth(
|
||||||
passed_in_key = api_key
|
passed_in_key = api_key
|
||||||
api_key = _get_bearer_token(api_key=api_key)
|
api_key = _get_bearer_token(api_key=api_key)
|
||||||
|
|
||||||
### USER-DEFINED AUTH FUNCTION -> This should always be run first if a user has defined it ###
|
### USER-DEFINED AUTH FUNCTION ###
|
||||||
if user_custom_auth is not None:
|
if user_custom_auth is not None:
|
||||||
response = await user_custom_auth(request=request, api_key=api_key)
|
response = await user_custom_auth(request=request, api_key=api_key)
|
||||||
return UserAPIKeyAuth.model_validate(response)
|
return UserAPIKeyAuth.model_validate(response)
|
||||||
|
|
||||||
if api_key == "":
|
|
||||||
# missing 'Bearer ' prefix
|
|
||||||
raise Exception(
|
|
||||||
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
|
|
||||||
)
|
|
||||||
|
|
||||||
### LITELLM-DEFINED AUTH FUNCTION ###
|
### LITELLM-DEFINED AUTH FUNCTION ###
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
|
@ -276,6 +279,17 @@ async def user_api_key_auth(
|
||||||
else:
|
else:
|
||||||
return UserAPIKeyAuth()
|
return UserAPIKeyAuth()
|
||||||
|
|
||||||
|
if api_key is None:
|
||||||
|
raise Exception("No API Key passed in")
|
||||||
|
if secrets.compare_digest(api_key, MISSING_BEARER):
|
||||||
|
# missing 'Bearer ' prefix
|
||||||
|
raise Exception(
|
||||||
|
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
|
||||||
|
)
|
||||||
|
elif secrets.compare_digest(api_key, NO_API_KEY):
|
||||||
|
# no api key passed in
|
||||||
|
raise Exception("No API Key passed in. Passed in: {passed_in_key}")
|
||||||
|
|
||||||
route: str = request.url.path
|
route: str = request.url.path
|
||||||
if route == "/user/auth":
|
if route == "/user/auth":
|
||||||
if general_settings.get("allow_user_auth", False) == True:
|
if general_settings.get("allow_user_auth", False) == True:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue