feat(proxy_server.py): enable token based authentication for server endpoints

This commit is contained in:
Krrish Dholakia 2023-11-18 16:13:53 -08:00
parent 8ae855e008
commit c02794d3ff

View file

@ -134,6 +134,7 @@ server_settings: dict = {}
log_file = "api_log.json" log_file = "api_log.json"
worker_config = None worker_config = None
master_key = None master_key = None
prisma_client = None
#### HELPER FUNCTIONS #### #### HELPER FUNCTIONS ####
def print_verbose(print_statement): def print_verbose(print_statement):
global user_debug global user_debug
@ -151,21 +152,36 @@ def usage_telemetry(
).start() ).start()
async def user_api_key_auth(request: Request): async def user_api_key_auth(request: Request):
global master_key global master_key, prisma_client
if master_key is None: if master_key is None:
return return
try: try:
api_key = await oauth2_scheme(request=request) api_key = await oauth2_scheme(request=request)
if api_key == master_key: if api_key == master_key:
return return
except: print(f"prisma_client: {prisma_client}")
pass if prisma_client:
await prisma_client.connect()
valid_token = await prisma_client.litellm_verificationtoken.find_first(
where={
"token": api_key,
"expires": {"gte": datetime.utcnow()} # Check if the token is not expired
}
)
print(f"valid_token: {valid_token}")
if valid_token:
return
else:
raise Exception
except Exception as e:
print(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail={"error": "invalid user key"}, detail={"error": "invalid user key"},
) )
def add_keys_to_config(key, value): def add_keys_to_config(key, value):
#### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml
# Check if file exists # Check if file exists
if os.path.exists(user_config_path): if os.path.exists(user_config_path):
# Load existing file # Load existing file
@ -184,6 +200,7 @@ def add_keys_to_config(key, value):
def save_params_to_config(data: dict): def save_params_to_config(data: dict):
#### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml
# Check if file exists # Check if file exists
if os.path.exists(user_config_path): if os.path.exists(user_config_path):
# Load existing file # Load existing file
@ -215,13 +232,14 @@ def save_params_to_config(data: dict):
with open(user_config_path, "wb") as f: with open(user_config_path, "wb") as f:
tomli_w.dump(config, f) tomli_w.dump(config, f)
def prisma_setup(database_url: Optional[str]): def prisma_setup():
if database_url: global prisma_client
subprocess.run(['pip', 'install', 'prisma']) subprocess.run(['pip', 'install', 'prisma'])
subprocess.run(['python3', '-m', 'pip', 'install', 'prisma']) subprocess.run(['python3', '-m', 'pip', 'install', 'prisma'])
subprocess.run(['prisma', 'db', 'push']) subprocess.run(['prisma', 'db', 'push'])
# Now you can import the Prisma Client # Now you can import the Prisma Client
from prisma import Client from prisma import Client
prisma_client = Client()
def load_router_config(router: Optional[litellm.Router], config_file_path: str): def load_router_config(router: Optional[litellm.Router], config_file_path: str):
@ -244,6 +262,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if environment_variables: if environment_variables:
for key, value in environment_variables.items(): for key, value in environment_variables.items():
os.environ[key] = value os.environ[key] = value
### CONNECT TO DATABASE ###
if key == "DATABASE_URL":
prisma_setup()
## GENERAL SERVER SETTINGS (e.g. master key,..) ## GENERAL SERVER SETTINGS (e.g. master key,..)
@ -251,9 +272,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if general_settings: if general_settings:
### MASTER KEY ### ### MASTER KEY ###
master_key = general_settings.get("master_key", None) master_key = general_settings.get("master_key", None)
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
prisma_setup(database_url=database_url)
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
@ -493,7 +511,13 @@ def startup_event():
import json import json
worker_config = json.loads(os.getenv("WORKER_CONFIG")) worker_config = json.loads(os.getenv("WORKER_CONFIG"))
initialize(**worker_config) initialize(**worker_config)
# print(f"\033[32mWorker Initialized\033[0m\n")
@app.on_event("shutdown")
async def shutdown_event():
global prisma_client
if prisma_client:
print("Disconnecting from Prisma")
await prisma_client.disconnect()
#### API ENDPOINTS #### #### API ENDPOINTS ####
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)]) @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
@ -629,8 +653,7 @@ async def generate_key(request: Request):
duration = _duration_in_seconds(duration=duration_str) duration = _duration_in_seconds(duration=duration_str)
expires = datetime.utcnow() + timedelta(seconds=duration) expires = datetime.utcnow() + timedelta(seconds=duration)
try: try:
from prisma import Client db = prisma_client
db = Client()
await db.connect() await db.connect()
# Create a new verification token (you may want to enhance this logic based on your needs) # Create a new verification token (you may want to enhance this logic based on your needs)
print(dir(db)) print(dir(db))