fix(proxy_server.py): handle initializing prisma / db connection just once

This commit is contained in:
Krrish Dholakia 2023-11-18 16:45:12 -08:00
parent 4a364bcbc0
commit 7a669a36d2

View file

@ -1,4 +1,4 @@
import sys, os, platform, time, copy, re import sys, os, platform, time, copy, re, asyncio
import threading, ast import threading, ast
import shutil, random, traceback, requests import shutil, random, traceback, requests
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -161,7 +161,6 @@ async def user_api_key_auth(request: Request):
return return
print(f"prisma_client: {prisma_client}") print(f"prisma_client: {prisma_client}")
if prisma_client: if prisma_client:
await prisma_client.connect()
valid_token = await prisma_client.litellm_verificationtoken.find_first( valid_token = await prisma_client.litellm_verificationtoken.find_first(
where={ where={
"token": api_key, "token": api_key,
@ -174,7 +173,7 @@ async def user_api_key_auth(request: Request):
else: else:
raise Exception raise Exception
except Exception as e: except Exception as e:
print(e) print(f"An exception occurred - {e}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail={"error": "invalid user key"}, detail={"error": "invalid user key"},
@ -291,6 +290,49 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
return router, model_list, server_settings return router, model_list, server_settings
# Number of seconds represented by each supported duration suffix.
_SECONDS_PER_UNIT = {"s": 1, "m": 60, "h": 3600, "d": 86400}


def _duration_in_seconds(duration: str) -> int:
    """Convert a duration such as "30s", "15m", "1h" or "7d" into seconds.

    Args:
        duration: Digits followed by exactly one unit letter
            (s = seconds, m = minutes, h = hours, d = days).

    Returns:
        The duration expressed in seconds.

    Raises:
        ValueError: If the string is not digits plus an optional unit letter
            ("Invalid duration format"), or if the unit letter is missing
            ("Unsupported duration unit").
    """
    # fullmatch, not match: a prefix-only match would silently read
    # "1month" as "1m" (one minute) and "1h30m" as one hour.
    match = re.fullmatch(r"(\d+)([smhd]?)", duration)
    if not match:
        raise ValueError("Invalid duration format")
    value, unit = match.groups()
    if unit not in _SECONDS_PER_UNIT:
        raise ValueError("Unsupported duration unit")
    return int(value) * _SECONDS_PER_UNIT[unit]


async def generate_key_helper_fn(duration_str: str):
    """Create a new verification token row and return the token with its expiry.

    Args:
        duration_str: Lifetime of the key, e.g. "30s", "15m", "1h", "7d".

    Returns:
        A dict with the keys "token" and "expires", read back from the
        record returned by the database create call.

    Raises:
        ValueError: If ``duration_str`` cannot be parsed.
        HTTPException: With status 500 if creating the record fails for any
            reason (the traceback is printed first).
    """
    # Validate the duration before generating a token or touching the database.
    duration = _duration_in_seconds(duration=duration_str)
    token = f"sk-{secrets.token_urlsafe(16)}"
    expires = datetime.utcnow() + timedelta(seconds=duration)
    try:
        # Uses the module-level prisma_client; this function does not open
        # a connection itself.
        new_verification_token = await prisma_client.litellm_verificationtoken.create(
            {"token": token, "expires": expires}
        )
        print(f"new_verification_token: {new_verification_token}")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from e
    return {"token": new_verification_token.token, "expires": new_verification_token.expires}
async def generate_key_cli_task(duration_str):
    """Schedule key generation as an asyncio task and wait for it to finish.

    Args:
        duration_str: Lifetime of the key, passed through unchanged to
            ``generate_key_helper_fn``.

    The helper's return value is not propagated; this coroutine returns None.
    """
    await asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
def load_config(): def load_config():
#### DEPRECATED #### #### DEPRECATED ####
try: try:
@ -412,7 +454,7 @@ def initialize(
add_function_to_prompt, add_function_to_prompt,
headers, headers,
save, save,
config config,
): ):
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings
generate_feedback_box() generate_feedback_box()
@ -504,13 +546,15 @@ def litellm_completion(*args, **kwargs):
if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(data_generator(response), media_type='text/event-stream') return StreamingResponse(data_generator(response), media_type='text/event-stream')
return response return response
@app.on_event("startup") @app.on_event("startup")
def startup_event(): async def startup_event():
global prisma_client
import json import json
worker_config = json.loads(os.getenv("WORKER_CONFIG")) worker_config = json.loads(os.getenv("WORKER_CONFIG"))
initialize(**worker_config) initialize(**worker_config)
if prisma_client:
await prisma_client.connect()
@app.on_event("shutdown") @app.on_event("shutdown")
async def shutdown_event(): async def shutdown_event():
@ -625,51 +669,12 @@ async def chat_completion(request: Request, model: Optional[str] = None):
) )
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)]) @router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
async def generate_key(request: Request): async def generate_key_fn(request: Request):
data = await request.json() data = await request.json()
token = f"sk-{secrets.token_urlsafe(16)}"
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
response = await generate_key_helper_fn(duration_str=duration_str)
def _duration_in_seconds(duration: str): return {"token": response["token"], "expires": response["expires"]}
match = re.match(r"(\d+)([smhd]?)", duration)
if not match:
raise ValueError("Invalid duration format")
value, unit = match.groups()
value = int(value)
if unit == "s":
return value
elif unit == "m":
return value * 60
elif unit == "h":
return value * 3600
elif unit == "d":
return value * 86400
else:
raise ValueError("Unsupported duration unit")
duration = _duration_in_seconds(duration=duration_str)
expires = datetime.utcnow() + timedelta(seconds=duration)
try:
db = prisma_client
await db.connect()
# Create a new verification token (you may want to enhance this logic based on your needs)
print(dir(db))
verification_token_data = {
"token": token,
"expires": expires
}
new_verification_token = await db.litellm_verificationtoken.create(
{**verification_token_data}
)
print(f"new_verification_token: {new_verification_token}")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)]) @router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])