fix(proxy_server.py): allow user to connect their proxy to a postgres db

This commit is contained in:
Krrish Dholakia 2023-11-18 15:57:31 -08:00
parent 4aa95f9d43
commit 8ae855e008
3 changed files with 92 additions and 49 deletions

View file

@ -2,27 +2,16 @@ model_list:
- model_name: zephyr-alpha - model_name: zephyr-alpha
litellm_params: # params for litellm.completion() - https://docs.litellm.ai/docs/completion/input#input---request-body litellm_params: # params for litellm.completion() - https://docs.litellm.ai/docs/completion/input#input---request-body
model: huggingface/HuggingFaceH4/zephyr-7b-alpha model: huggingface/HuggingFaceH4/zephyr-7b-alpha
max_tokens: 20 api_base: http://0.0.0.0:8001
temperature: 0 - model_name: zephyr-beta
- model_name: gpt-4-team1
litellm_params: litellm_params:
model: azure/chatgpt-v-2 model: huggingface/HuggingFaceH4/zephyr-7b-beta
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_base: https://<my-hosted-endpoint>
api_version: "2023-05-15"
azure_ad_token: eyJ0eXAiOiJ
- model_name: gpt-4-team2
litellm_params:
model: azure/gpt-4
api_key: sk-123
api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
- model_name: gpt-4-team3
litellm_params:
model: azure/gpt-4
api_key: sk-123
- model_name: ollama/zephyr
litellm_params:
model: ollama/zephyr
litellm_settings: litellm_settings:
drop_params: True drop_params: True
success_callback: ["langfuse"] # https://docs.litellm.ai/docs/observability/langfuse_integration set_verbose: True
general_settings:
master_key: sk-1234 # [OPTIONAL] Only use this if you want to require all calls to contain this key (Authorization: Bearer sk-1234)
database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy

View file

@ -1,7 +1,9 @@
import sys, os, platform, time, copy import sys, os, platform, time, copy, re
import threading, ast import threading, ast
import shutil, random, traceback, requests import shutil, random, traceback, requests
from datetime import datetime, timedelta
from typing import Optional from typing import Optional
import secrets, subprocess
messages: list = [] messages: list = []
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -16,7 +18,6 @@ try:
import backoff import backoff
import yaml import yaml
except ImportError: except ImportError:
import subprocess
import sys import sys
subprocess.check_call( subprocess.check_call(
@ -214,6 +215,14 @@ def save_params_to_config(data: dict):
with open(user_config_path, "wb") as f: with open(user_config_path, "wb") as f:
tomli_w.dump(config, f) tomli_w.dump(config, f)
def prisma_setup(database_url: Optional[str]):
    """Prepare the Prisma client when the proxy is configured with a database.

    Args:
        database_url: Postgres connection string from the config's
            ``general_settings.database_url``. When ``None``/empty, the proxy
            runs without a database and this function is a no-op.

    Raises:
        subprocess.CalledProcessError: if installing the client or pushing the
            schema fails (better to fail fast than run with a broken DB layer).
    """
    if database_url:
        # `python3 -m pip` reliably targets this interpreter's environment;
        # the original also ran a bare `pip install prisma`, which was redundant.
        subprocess.run(['python3', '-m', 'pip', 'install', 'prisma'], check=True)
        # Generate the client and sync the schema to the configured database.
        # NOTE(review): `prisma db push` reads DATABASE_URL from the environment —
        # presumably set elsewhere from `database_url`; confirm against caller.
        subprocess.run(['prisma', 'db', 'push'], check=True)
        # Import to fail fast if client generation did not succeed.
        from prisma import Client  # noqa: F401
def load_router_config(router: Optional[litellm.Router], config_file_path: str): def load_router_config(router: Optional[litellm.Router], config_file_path: str):
global master_key global master_key
@ -230,10 +239,22 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print(f"Loaded config YAML:\n{json.dumps(config, indent=2)}") print(f"Loaded config YAML:\n{json.dumps(config, indent=2)}")
## ENVIRONMENT VARIABLES
environment_variables = config.get('environment_variables', None)
if environment_variables:
for key, value in environment_variables.items():
os.environ[key] = value
## GENERAL SERVER SETTINGS (e.g. master key,..) ## GENERAL SERVER SETTINGS (e.g. master key,..)
general_settings = config.get("general_settings", None) general_settings = config.get("general_settings", None)
if general_settings: if general_settings:
### MASTER KEY ###
master_key = general_settings.get("master_key", None) master_key = general_settings.get("master_key", None)
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
prisma_setup(database_url=database_url)
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get('litellm_settings', None) litellm_settings = config.get('litellm_settings', None)
@ -250,12 +271,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print(f"\033[32m {model.get('model_name', '')}\033[0m") print(f"\033[32m {model.get('model_name', '')}\033[0m")
print() print()
## ENVIRONMENT VARIABLES
environment_variables = config.get('environment_variables', None)
if environment_variables:
for key, value in environment_variables.items():
os.environ[key] = value
return router, model_list, server_settings return router, model_list, server_settings
def load_config(): def load_config():
@ -585,29 +600,54 @@ async def chat_completion(request: Request, model: Optional[str] = None):
detail=error_msg detail=error_msg
) )
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
async def generate_key(request: Request):
data = await request.json()
@router.post("/router/chat/completions", dependencies=[Depends(user_api_key_auth)]) token = f"sk-{secrets.token_urlsafe(16)}"
async def router_completion(request: Request): duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
try:
body = await request.body() def _duration_in_seconds(duration: str):
body_str = body.decode() match = re.match(r"(\d+)([smhd]?)", duration)
try: if not match:
data = ast.literal_eval(body_str) raise ValueError("Invalid duration format")
except:
data = json.loads(body_str) value, unit = match.groups()
return {"data": data} value = int(value)
except Exception as e:
print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`") if unit == "s":
error_traceback = traceback.format_exc() return value
error_msg = f"{str(e)}\n\n{error_traceback}" elif unit == "m":
try: return value * 60
status = e.status_code # type: ignore elif unit == "h":
except: return value * 3600
status = status.HTTP_500_INTERNAL_SERVER_ERROR, elif unit == "d":
raise HTTPException( return value * 86400
status_code=status, else:
detail=error_msg raise ValueError("Unsupported duration unit")
duration = _duration_in_seconds(duration=duration_str)
expires = datetime.utcnow() + timedelta(seconds=duration)
try:
from prisma import Client
db = Client()
await db.connect()
# Create a new verification token (you may want to enhance this logic based on your needs)
print(dir(db))
verification_token_data = {
"token": token,
"expires": expires
}
new_verification_token = await db.litellm_verificationtoken.create(
{**verification_token_data}
) )
print(f"new_verification_token: {new_verification_token}")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)]) @router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
async def retrieve_server_log(request: Request): async def retrieve_server_log(request: Request):

View file

@ -0,0 +1,14 @@
datasource client {
provider = "postgresql"
url = env("DATABASE_URL")
}
generator client {
provider = "prisma-client-py"
}
// required for token gen
model LiteLLM_VerificationToken {
token String @unique
expires DateTime
}