forked from phoenix/litellm-mirror
fix(proxy_server.py): allow user to connect their proxy to a postgres db
This commit is contained in:
parent
4aa95f9d43
commit
8ae855e008
3 changed files with 92 additions and 49 deletions
|
model_list:
  # Local HuggingFace-served zephyr endpoint
  - model_name: zephyr-alpha
    litellm_params: # params for litellm.completion() - https://docs.litellm.ai/docs/completion/input#input---request-body
      model: huggingface/HuggingFaceH4/zephyr-7b-alpha
      api_base: http://0.0.0.0:8001
  # Azure OpenAI deployments (per-team keys)
  - model_name: gpt-4-team1
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_version: "2023-05-15"
      azure_ad_token: eyJ0eXAiOiJ
  - model_name: gpt-4-team2
    litellm_params:
      model: azure/gpt-4
      api_key: sk-123
      api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
  - model_name: gpt-4-team3
    litellm_params:
      model: azure/gpt-4
      api_key: sk-123
  # Local ollama model
  - model_name: ollama/zephyr
    litellm_params:
      model: ollama/zephyr

litellm_settings:
  drop_params: True
  set_verbose: True

general_settings:
  master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
  database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy
|
@ -1,7 +1,9 @@
|
||||||
import sys, os, platform, time, copy
|
import sys, os, platform, time, copy, re
|
||||||
import threading, ast
|
import threading, ast
|
||||||
import shutil, random, traceback, requests
|
import shutil, random, traceback, requests
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import secrets, subprocess
|
||||||
messages: list = []
|
messages: list = []
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
|
@ -16,7 +18,6 @@ try:
|
||||||
import backoff
|
import backoff
|
||||||
import yaml
|
import yaml
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
subprocess.check_call(
|
subprocess.check_call(
|
||||||
|
@ -214,6 +215,14 @@ def save_params_to_config(data: dict):
|
||||||
with open(user_config_path, "wb") as f:
|
with open(user_config_path, "wb") as f:
|
||||||
tomli_w.dump(config, f)
|
tomli_w.dump(config, f)
|
||||||
|
|
||||||
|
def prisma_setup(database_url: Optional[str]) -> None:
    """Install the Prisma client and push the schema to the configured database.

    Called at config-load time when ``general_settings.database_url`` is set
    (the URL itself is read by prisma from the ``DATABASE_URL`` env var via
    ``schema.prisma``). No-op when *database_url* is falsy.

    Raises:
        subprocess.CalledProcessError: if the install or ``prisma db push`` fails.
        ImportError: if the prisma client is still unimportable after install.
    """
    if database_url:
        # Use the interpreter actually running the proxy. The original ran
        # both `pip install prisma` and `python3 -m pip install prisma`,
        # which double-installed and could target a different interpreter.
        subprocess.run([sys.executable, "-m", "pip", "install", "prisma"], check=True)
        # Generate the client and sync schema.prisma with the database.
        # check=True: fail loudly instead of silently continuing without a DB.
        subprocess.run(["prisma", "db", "push"], check=True)
        # Import here purely to verify the freshly installed client is usable.
        from prisma import Client  # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
global master_key
|
global master_key
|
||||||
|
@ -230,10 +239,22 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
|
|
||||||
print(f"Loaded config YAML:\n{json.dumps(config, indent=2)}")
|
print(f"Loaded config YAML:\n{json.dumps(config, indent=2)}")
|
||||||
|
|
||||||
|
## ENVIRONMENT VARIABLES
|
||||||
|
environment_variables = config.get('environment_variables', None)
|
||||||
|
if environment_variables:
|
||||||
|
for key, value in environment_variables.items():
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
## GENERAL SERVER SETTINGS (e.g. master key,..)
|
## GENERAL SERVER SETTINGS (e.g. master key,..)
|
||||||
general_settings = config.get("general_settings", None)
|
general_settings = config.get("general_settings", None)
|
||||||
if general_settings:
|
if general_settings:
|
||||||
|
### MASTER KEY ###
|
||||||
master_key = general_settings.get("master_key", None)
|
master_key = general_settings.get("master_key", None)
|
||||||
|
### CONNECT TO DATABASE ###
|
||||||
|
database_url = general_settings.get("database_url", None)
|
||||||
|
prisma_setup(database_url=database_url)
|
||||||
|
|
||||||
|
|
||||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||||
litellm_settings = config.get('litellm_settings', None)
|
litellm_settings = config.get('litellm_settings', None)
|
||||||
|
@ -250,12 +271,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
print(f"\033[32m {model.get('model_name', '')}\033[0m")
|
print(f"\033[32m {model.get('model_name', '')}\033[0m")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
## ENVIRONMENT VARIABLES
|
|
||||||
environment_variables = config.get('environment_variables', None)
|
|
||||||
if environment_variables:
|
|
||||||
for key, value in environment_variables.items():
|
|
||||||
os.environ[key] = value
|
|
||||||
|
|
||||||
return router, model_list, server_settings
|
return router, model_list, server_settings
|
||||||
|
|
||||||
def load_config():
|
def load_config():
|
||||||
|
@ -585,29 +600,54 @@ async def chat_completion(request: Request, model: Optional[str] = None):
|
||||||
detail=error_msg
|
detail=error_msg
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
async def generate_key(request: Request):
    """Generate a time-limited proxy API key and persist it to the database.

    Request body (JSON):
        duration: optional string like "30s", "10m", "1h", "7d"
            (defaults to "1h"; a bare number is treated as seconds).

    Returns:
        {"token": "sk-...", "expires": <datetime>}

    Raises:
        HTTPException: 500 if the database write fails.
        ValueError: if the duration string is malformed.
    """
    data = await request.json()

    token = f"sk-{secrets.token_urlsafe(16)}"
    duration_str = data.get("duration", "1h")  # default to 1 hour

    def _duration_in_seconds(duration: str) -> int:
        # Parse "<int><unit>" where unit is s/m/h/d (or absent).
        match = re.match(r"(\d+)([smhd]?)", duration)
        if not match:
            raise ValueError("Invalid duration format")
        value, unit = match.groups()
        # Bug fix: the regex makes the unit optional, but the original
        # if/elif chain raised "Unsupported duration unit" for a bare
        # number like "30"; treat a missing unit as seconds.
        multipliers = {"": 1, "s": 1, "m": 60, "h": 3600, "d": 86400}
        if unit not in multipliers:
            raise ValueError("Unsupported duration unit")
        return int(value) * multipliers[unit]

    duration = _duration_in_seconds(duration=duration_str)
    expires = datetime.utcnow() + timedelta(seconds=duration)

    try:
        from prisma import Client

        db = Client()
        await db.connect()
        # Create a new verification token row keyed by the generated token.
        # NOTE(review): connection is never disconnected here — presumably
        # acceptable for a short-lived request, but worth confirming.
        verification_token_data = {
            "token": token,
            "expires": expires,
        }
        new_verification_token = await db.litellm_verificationtoken.create(
            {**verification_token_data}
        )
        print(f"new_verification_token: {new_verification_token}")
    except Exception:
        # Surface the real failure in the server log, return an opaque 500.
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)

    return {"token": new_verification_token.token, "expires": new_verification_token.expires}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
|
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
|
||||||
async def retrieve_server_log(request: Request):
|
async def retrieve_server_log(request: Request):
|
||||||
|
|
14
litellm/proxy/schema.prisma
Normal file
14
litellm/proxy/schema.prisma
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
// Postgres connection; URL is injected at runtime from the proxy's
// general_settings.database_url via the DATABASE_URL env var.
datasource client {
  provider = "postgresql"
  url      = env("DATABASE_URL")
}

// Emit the Python client (`from prisma import Client`).
generator client {
  provider = "prisma-client-py"
}

// required for token gen
model LiteLLM_VerificationToken {
  token   String   @unique
  expires DateTime
}
|
Loading…
Add table
Add a link
Reference in a new issue