forked from phoenix/litellm-mirror
fix(proxy_server.py): allow user to connect their proxy to a postgres db
This commit is contained in:
parent
4aa95f9d43
commit
8ae855e008
3 changed files with 92 additions and 49 deletions
|
@ -1,7 +1,9 @@
|
|||
import sys, os, platform, time, copy
|
||||
import sys, os, platform, time, copy, re
|
||||
import threading, ast
|
||||
import shutil, random, traceback, requests
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
import secrets, subprocess
|
||||
messages: list = []
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
|
@ -16,7 +18,6 @@ try:
|
|||
import backoff
|
||||
import yaml
|
||||
except ImportError:
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
subprocess.check_call(
|
||||
|
@ -214,6 +215,14 @@ def save_params_to_config(data: dict):
|
|||
with open(user_config_path, "wb") as f:
|
||||
tomli_w.dump(config, f)
|
||||
|
||||
def prisma_setup(database_url: Optional[str]):
|
||||
if database_url:
|
||||
subprocess.run(['pip', 'install', 'prisma'])
|
||||
subprocess.run(['python3', '-m', 'pip', 'install', 'prisma'])
|
||||
subprocess.run(['prisma', 'db', 'push'])
|
||||
# Now you can import the Prisma Client
|
||||
from prisma import Client
|
||||
|
||||
|
||||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||
global master_key
|
||||
|
@ -230,10 +239,22 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
|
||||
print(f"Loaded config YAML:\n{json.dumps(config, indent=2)}")
|
||||
|
||||
## ENVIRONMENT VARIABLES
|
||||
environment_variables = config.get('environment_variables', None)
|
||||
if environment_variables:
|
||||
for key, value in environment_variables.items():
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
## GENERAL SERVER SETTINGS (e.g. master key,..)
|
||||
general_settings = config.get("general_settings", None)
|
||||
if general_settings:
|
||||
### MASTER KEY ###
|
||||
master_key = general_settings.get("master_key", None)
|
||||
### CONNECT TO DATABASE ###
|
||||
database_url = general_settings.get("database_url", None)
|
||||
prisma_setup(database_url=database_url)
|
||||
|
||||
|
||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||
litellm_settings = config.get('litellm_settings', None)
|
||||
|
@ -250,12 +271,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
print(f"\033[32m {model.get('model_name', '')}\033[0m")
|
||||
print()
|
||||
|
||||
## ENVIRONMENT VARIABLES
|
||||
environment_variables = config.get('environment_variables', None)
|
||||
if environment_variables:
|
||||
for key, value in environment_variables.items():
|
||||
os.environ[key] = value
|
||||
|
||||
return router, model_list, server_settings
|
||||
|
||||
def load_config():
|
||||
|
@ -585,29 +600,54 @@ async def chat_completion(request: Request, model: Optional[str] = None):
|
|||
detail=error_msg
|
||||
)
|
||||
|
||||
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
|
||||
async def generate_key(request: Request):
|
||||
data = await request.json()
|
||||
|
||||
@router.post("/router/chat/completions", dependencies=[Depends(user_api_key_auth)])
|
||||
async def router_completion(request: Request):
|
||||
try:
|
||||
body = await request.body()
|
||||
body_str = body.decode()
|
||||
try:
|
||||
data = ast.literal_eval(body_str)
|
||||
except:
|
||||
data = json.loads(body_str)
|
||||
return {"data": data}
|
||||
except Exception as e:
|
||||
print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
try:
|
||||
status = e.status_code # type: ignore
|
||||
except:
|
||||
status = status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
raise HTTPException(
|
||||
status_code=status,
|
||||
detail=error_msg
|
||||
token = f"sk-{secrets.token_urlsafe(16)}"
|
||||
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
|
||||
|
||||
def _duration_in_seconds(duration: str):
|
||||
match = re.match(r"(\d+)([smhd]?)", duration)
|
||||
if not match:
|
||||
raise ValueError("Invalid duration format")
|
||||
|
||||
value, unit = match.groups()
|
||||
value = int(value)
|
||||
|
||||
if unit == "s":
|
||||
return value
|
||||
elif unit == "m":
|
||||
return value * 60
|
||||
elif unit == "h":
|
||||
return value * 3600
|
||||
elif unit == "d":
|
||||
return value * 86400
|
||||
else:
|
||||
raise ValueError("Unsupported duration unit")
|
||||
|
||||
duration = _duration_in_seconds(duration=duration_str)
|
||||
expires = datetime.utcnow() + timedelta(seconds=duration)
|
||||
try:
|
||||
from prisma import Client
|
||||
db = Client()
|
||||
await db.connect()
|
||||
# Create a new verification token (you may want to enhance this logic based on your needs)
|
||||
print(dir(db))
|
||||
verification_token_data = {
|
||||
"token": token,
|
||||
"expires": expires
|
||||
}
|
||||
new_verification_token = await db.litellm_verificationtoken.create(
|
||||
{**verification_token_data}
|
||||
)
|
||||
print(f"new_verification_token: {new_verification_token}")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
|
||||
|
||||
|
||||
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
|
||||
async def retrieve_server_log(request: Request):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue