fix(proxy_server.py): allow user to connect their proxy to a postgres db

This commit is contained in:
Krrish Dholakia 2023-11-18 15:57:31 -08:00
parent c9445db22f
commit 229e5ea083
3 changed files with 92 additions and 49 deletions

View file

@ -1,7 +1,9 @@
import sys, os, platform, time, copy
import sys, os, platform, time, copy, re
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta
from typing import Optional
import secrets, subprocess
messages: list = []
sys.path.insert(
0, os.path.abspath("../..")
@ -16,7 +18,6 @@ try:
import backoff
import yaml
except ImportError:
import subprocess
import sys
subprocess.check_call(
@ -214,6 +215,14 @@ def save_params_to_config(data: dict):
with open(user_config_path, "wb") as f:
tomli_w.dump(config, f)
def prisma_setup(database_url: Optional[str]):
    """Install the Prisma client and push the schema when a database is configured.

    Args:
        database_url: Database connection string taken from the proxy config.
            It is only checked for truthiness here; it is not passed to the
            ``prisma`` command. When it is ``None`` or empty the function
            returns without doing anything.

    Raises:
        ImportError: If ``prisma`` is still not importable after the install
            step has run.
    """
    if not database_url:
        return

    # Install through the interpreter that is running this process, so the
    # package ends up in the same environment the import below reads from.
    # A bare `pip` or `python3` on PATH may belong to a different environment.
    subprocess.run([sys.executable, '-m', 'pip', 'install', 'prisma'])
    subprocess.run(['prisma', 'db', 'push'])
    # Importing here makes a failed install surface immediately as ImportError.
    from prisma import Client  # noqa: F401
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
global master_key
@ -230,10 +239,22 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print(f"Loaded config YAML:\n{json.dumps(config, indent=2)}")
## ENVIRONMENT VARIABLES
environment_variables = config.get('environment_variables', None)
if environment_variables:
for key, value in environment_variables.items():
os.environ[key] = value
## GENERAL SERVER SETTINGS (e.g. master key,..)
general_settings = config.get("general_settings", None)
if general_settings:
### MASTER KEY ###
master_key = general_settings.get("master_key", None)
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
prisma_setup(database_url=database_url)
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get('litellm_settings', None)
@ -250,12 +271,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print(f"\033[32m {model.get('model_name', '')}\033[0m")
print()
## ENVIRONMENT VARIABLES
environment_variables = config.get('environment_variables', None)
if environment_variables:
for key, value in environment_variables.items():
os.environ[key] = value
return router, model_list, server_settings
def load_config():
@ -585,29 +600,54 @@ async def chat_completion(request: Request, model: Optional[str] = None):
detail=error_msg
)
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
async def generate_key(request: Request):
data = await request.json()
@router.post("/router/chat/completions", dependencies=[Depends(user_api_key_auth)])
async def router_completion(request: Request):
try:
body = await request.body()
body_str = body.decode()
try:
data = ast.literal_eval(body_str)
except:
data = json.loads(body_str)
return {"data": data}
except Exception as e:
print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
try:
status = e.status_code # type: ignore
except:
status = status.HTTP_500_INTERNAL_SERVER_ERROR,
raise HTTPException(
status_code=status,
detail=error_msg
token = f"sk-{secrets.token_urlsafe(16)}"
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
def _duration_in_seconds(duration: str):
match = re.match(r"(\d+)([smhd]?)", duration)
if not match:
raise ValueError("Invalid duration format")
value, unit = match.groups()
value = int(value)
if unit == "s":
return value
elif unit == "m":
return value * 60
elif unit == "h":
return value * 3600
elif unit == "d":
return value * 86400
else:
raise ValueError("Unsupported duration unit")
duration = _duration_in_seconds(duration=duration_str)
expires = datetime.utcnow() + timedelta(seconds=duration)
try:
from prisma import Client
db = Client()
await db.connect()
# Create a new verification token (you may want to enhance this logic based on your needs)
print(dir(db))
verification_token_data = {
"token": token,
"expires": expires
}
new_verification_token = await db.litellm_verificationtoken.create(
{**verification_token_data}
)
print(f"new_verification_token: {new_verification_token}")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
async def retrieve_server_log(request: Request):