fix(proxy_server.py): allow user to connect their proxy to a postgres db

This commit is contained in:
Krrish Dholakia 2023-11-18 15:57:31 -08:00
parent 4aa95f9d43
commit 8ae855e008
3 changed files with 92 additions and 49 deletions

View file

@ -2,27 +2,16 @@ model_list:
- model_name: zephyr-alpha
litellm_params: # params for litellm.completion() - https://docs.litellm.ai/docs/completion/input#input---request-body
model: huggingface/HuggingFaceH4/zephyr-7b-alpha
max_tokens: 20
temperature: 0
- model_name: gpt-4-team1
api_base: http://0.0.0.0:8001
- model_name: zephyr-beta
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
azure_ad_token: eyJ0eXAiOiJ
- model_name: gpt-4-team2
litellm_params:
model: azure/gpt-4
api_key: sk-123
api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
- model_name: gpt-4-team3
litellm_params:
model: azure/gpt-4
api_key: sk-123
- model_name: ollama/zephyr
litellm_params:
model: ollama/zephyr
model: huggingface/HuggingFaceH4/zephyr-7b-beta
api_base: https://<my-hosted-endpoint>
litellm_settings:
drop_params: True
success_callback: ["langfuse"] # https://docs.litellm.ai/docs/observability/langfuse_integration
set_verbose: True
general_settings:
master_key: sk-1234 # [OPTIONAL] Only use this if you want to require all calls to contain this key (Authorization: Bearer sk-1234)
database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy

View file

@ -1,7 +1,9 @@
import sys, os, platform, time, copy
import sys, os, platform, time, copy, re
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta
from typing import Optional
import secrets, subprocess
messages: list = []
sys.path.insert(
0, os.path.abspath("../..")
@ -16,7 +18,6 @@ try:
import backoff
import yaml
except ImportError:
import subprocess
import sys
subprocess.check_call(
@ -214,6 +215,14 @@ def save_params_to_config(data: dict):
with open(user_config_path, "wb") as f:
tomli_w.dump(config, f)
def prisma_setup(database_url: Optional[str]):
    """Install the Prisma client and push the schema to the configured database.

    Args:
        database_url: Postgres connection string taken from the proxy config.
            When it is ``None`` or empty the function does nothing.

    Raises:
        ImportError: if the ``prisma`` package is still not importable after
            the installation attempt.
    """
    if not database_url:
        return

    # The Prisma schema resolves its connection string through
    # env("DATABASE_URL"), so the configured value has to be exported before
    # the Prisma CLI is invoked; otherwise `prisma db push` cannot see it.
    os.environ["DATABASE_URL"] = database_url

    # Install into the interpreter that is running the proxy. A bare `pip` or
    # `python3` on PATH may belong to a different environment.
    subprocess.run([sys.executable, "-m", "pip", "install", "prisma"])
    subprocess.run(["prisma", "db", "push"])

    # Importing here confirms the client is usable once the steps above ran.
    from prisma import Client  # noqa: F401
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
global master_key
@ -230,10 +239,22 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print(f"Loaded config YAML:\n{json.dumps(config, indent=2)}")
## ENVIRONMENT VARIABLES
environment_variables = config.get('environment_variables', None)
if environment_variables:
for key, value in environment_variables.items():
os.environ[key] = value
## GENERAL SERVER SETTINGS (e.g. master key,..)
general_settings = config.get("general_settings", None)
if general_settings:
### MASTER KEY ###
master_key = general_settings.get("master_key", None)
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
prisma_setup(database_url=database_url)
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get('litellm_settings', None)
@ -250,12 +271,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print(f"\033[32m {model.get('model_name', '')}\033[0m")
print()
## ENVIRONMENT VARIABLES
environment_variables = config.get('environment_variables', None)
if environment_variables:
for key, value in environment_variables.items():
os.environ[key] = value
return router, model_list, server_settings
def load_config():
@ -585,29 +600,54 @@ async def chat_completion(request: Request, model: Optional[str] = None):
detail=error_msg
)
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
async def generate_key(request: Request):
data = await request.json()
@router.post("/router/chat/completions", dependencies=[Depends(user_api_key_auth)])
async def router_completion(request: Request):
token = f"sk-{secrets.token_urlsafe(16)}"
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
def _duration_in_seconds(duration: str):
    """Convert a duration string such as "30s", "15m", "1h" or "7d" to seconds.

    Args:
        duration: A non-negative integer immediately followed by one unit
            letter: ``s`` (seconds), ``m`` (minutes), ``h`` (hours) or
            ``d`` (days).

    Returns:
        The duration expressed in whole seconds.

    Raises:
        ValueError: if ``duration`` is not a string, does not have the
            ``<digits><unit>`` shape, or carries no unit.
    """
    seconds_per_unit = {"s": 1, "m": 60, "h": 3600, "d": 86400}

    if not isinstance(duration, str):
        raise ValueError("Invalid duration format")

    # fullmatch (not match) so trailing text such as "1hour" or "1h foo" is
    # rejected instead of being silently read as "1h".
    match = re.fullmatch(r"(\d+)([smhd]?)", duration)
    if not match:
        raise ValueError("Invalid duration format")

    value, unit = match.groups()
    if unit not in seconds_per_unit:
        # Only reachable for a bare number: the unit group matched "".
        raise ValueError("Unsupported duration unit")
    return int(value) * seconds_per_unit[unit]
duration = _duration_in_seconds(duration=duration_str)
expires = datetime.utcnow() + timedelta(seconds=duration)
try:
body = await request.body()
body_str = body.decode()
try:
data = ast.literal_eval(body_str)
except:
data = json.loads(body_str)
return {"data": data}
except Exception as e:
print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
try:
status = e.status_code # type: ignore
except:
status = status.HTTP_500_INTERNAL_SERVER_ERROR,
raise HTTPException(
status_code=status,
detail=error_msg
from prisma import Client
db = Client()
await db.connect()
# Create a new verification token (you may want to enhance this logic based on your needs)
print(dir(db))
verification_token_data = {
"token": token,
"expires": expires
}
new_verification_token = await db.litellm_verificationtoken.create(
{**verification_token_data}
)
print(f"new_verification_token: {new_verification_token}")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
async def retrieve_server_log(request: Request):

View file

@ -0,0 +1,14 @@
datasource client {
provider = "postgresql"
url = env("DATABASE_URL")
}
generator client {
provider = "prisma-client-py"
}
// required for token gen
// One row per API key issued by the proxy's key-generation endpoint.
model LiteLLM_VerificationToken {
// The generated key string; @unique so no two rows can share a token.
token String @unique
// Timestamp stored alongside the key as its expiry.
expires DateTime
}