fix(proxy_server.py): handle initializing prisma / db connection just once

This commit is contained in:
Krrish Dholakia 2023-11-18 16:45:12 -08:00
parent 4a364bcbc0
commit 7a669a36d2

View file

@ -1,4 +1,4 @@
import sys, os, platform, time, copy, re
import sys, os, platform, time, copy, re, asyncio
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta
@ -161,7 +161,6 @@ async def user_api_key_auth(request: Request):
return
print(f"prisma_client: {prisma_client}")
if prisma_client:
await prisma_client.connect()
valid_token = await prisma_client.litellm_verificationtoken.find_first(
where={
"token": api_key,
@ -174,7 +173,7 @@ async def user_api_key_auth(request: Request):
else:
raise Exception
except Exception as e:
print(e)
print(f"An exception occurred - {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={"error": "invalid user key"},
@ -291,6 +290,49 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
return router, model_list, server_settings
async def generate_key_helper_fn(duration_str: str):
    """Generate a new ``sk-`` API key and store it through ``prisma_client``.

    Args:
        duration_str: Lifetime of the key written as ``<integer><unit>``,
            where the unit is ``s`` (seconds), ``m`` (minutes), ``h`` (hours)
            or ``d`` (days), e.g. ``"30m"`` or ``"1h"``.

    Returns:
        A dict with the stored ``token`` and its ``expires`` timestamp.

    Raises:
        ValueError: If ``duration_str`` is not a valid duration.
        HTTPException: With status 500 if creating the database record fails.
    """
    # Seconds represented by each supported unit suffix.
    unit_seconds = {"s": 1, "m": 60, "h": 3600, "d": 86400}

    def _duration_in_seconds(duration: str):
        """Convert a duration string such as ``"30m"`` into seconds."""
        # fullmatch (rather than match) rejects trailing text: with a prefix
        # match, "1month" would silently be read as one minute.
        match = re.fullmatch(r"(\d+)([smhd]?)", duration)
        if not match:
            raise ValueError("Invalid duration format")
        value, unit = match.groups()
        if unit not in unit_seconds:
            # Reached when the unit suffix is omitted entirely (e.g. "30").
            raise ValueError("Unsupported duration unit")
        return int(value) * unit_seconds[unit]

    token = f"sk-{secrets.token_urlsafe(16)}"
    duration = _duration_in_seconds(duration=duration_str)
    expires = datetime.utcnow() + timedelta(seconds=duration)

    try:
        db = prisma_client
        # Create a new verification token (you may want to enhance this logic based on your needs)
        new_verification_token = await db.litellm_verificationtoken.create(
            {"token": token, "expires": expires}
        )
        print(f"new_verification_token: {new_verification_token}")
    except Exception as e:
        # Any failure while writing the record is reported as a server error;
        # the traceback is printed so the underlying cause is not lost.
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from e
    return {"token": new_verification_token.token, "expires": new_verification_token.expires}
async def generate_key_cli_task(duration_str):
    """Run ``generate_key_helper_fn`` as an asyncio task and wait for it.

    The helper's result is not returned; this coroutine only waits for the
    task to finish.
    """
    await asyncio.create_task(
        generate_key_helper_fn(duration_str=duration_str)
    )
def load_config():
#### DEPRECATED ####
try:
@ -412,7 +454,7 @@ def initialize(
add_function_to_prompt,
headers,
save,
config
config,
):
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings
generate_feedback_box()
@ -504,13 +546,15 @@ def litellm_completion(*args, **kwargs):
if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(data_generator(response), media_type='text/event-stream')
return response
# Startup hook registered with @app.on_event("startup"): parses the
# WORKER_CONFIG environment variable as JSON, passes the result to
# initialize() as keyword arguments, then awaits prisma_client.connect()
# when prisma_client is set.
# NOTE(review): two `startup_event` headers appear back to back and the body
# has no indentation -- this looks like a rendered diff hunk (the sync header
# followed by its async replacement); confirm against the actual file.
@app.on_event("startup")
def startup_event():
async def startup_event():
global prisma_client
import json
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
initialize(**worker_config)
if prisma_client:
await prisma_client.connect()
@app.on_event("shutdown")
async def shutdown_event():
@ -625,51 +669,12 @@ async def chat_completion(request: Request, model: Optional[str] = None):
)
# POST /key/generate handler, guarded by the user_api_key_auth dependency.
# Reads the JSON request body, takes its "duration" field (default "1h") and
# responds with a dict holding the new "token" and its "expires" value.
# NOTE(review): this span carries two handler headers (generate_key and
# generate_key_fn) and two implementations: an inline one that builds the
# token and calls db.connect() itself, and a final pair of lines that
# delegate to generate_key_helper_fn. It looks like a rendered diff hunk with
# removed and added lines interleaved; confirm which lines are live.
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
async def generate_key(request: Request):
async def generate_key_fn(request: Request):
data = await request.json()
token = f"sk-{secrets.token_urlsafe(16)}"
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
def _duration_in_seconds(duration: str):
match = re.match(r"(\d+)([smhd]?)", duration)
if not match:
raise ValueError("Invalid duration format")
value, unit = match.groups()
value = int(value)
if unit == "s":
return value
elif unit == "m":
return value * 60
elif unit == "h":
return value * 3600
elif unit == "d":
return value * 86400
else:
raise ValueError("Unsupported duration unit")
duration = _duration_in_seconds(duration=duration_str)
expires = datetime.utcnow() + timedelta(seconds=duration)
try:
db = prisma_client
await db.connect()
# Create a new verification token (you may want to enhance this logic based on your needs)
print(dir(db))
verification_token_data = {
"token": token,
"expires": expires
}
new_verification_token = await db.litellm_verificationtoken.create(
{**verification_token_data}
)
print(f"new_verification_token: {new_verification_token}")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
# Delegating implementation: key creation is handled by generate_key_helper_fn.
response = await generate_key_helper_fn(duration_str=duration_str)
return {"token": response["token"], "expires": response["expires"]}
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])