From 7a669a36d2689c7f7890bc9c93e04ff3c2641299 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 18 Nov 2023 16:45:12 -0800 Subject: [PATCH] fix(proxy_server.py): handle initializing prisma / db connection just once --- litellm/proxy/proxy_server.py | 103 ++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 49 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index bba511514..14ca79fca 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1,4 +1,4 @@ -import sys, os, platform, time, copy, re +import sys, os, platform, time, copy, re, asyncio import threading, ast import shutil, random, traceback, requests from datetime import datetime, timedelta @@ -161,7 +161,6 @@ async def user_api_key_auth(request: Request): return print(f"prisma_client: {prisma_client}") if prisma_client: - await prisma_client.connect() valid_token = await prisma_client.litellm_verificationtoken.find_first( where={ "token": api_key, @@ -174,7 +173,7 @@ async def user_api_key_auth(request: Request): else: raise Exception except Exception as e: - print(e) + print(f"An exception occurred - {e}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail={"error": "invalid user key"}, @@ -291,6 +290,49 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): return router, model_list, server_settings +async def generate_key_helper_fn(duration_str: str): + token = f"sk-{secrets.token_urlsafe(16)}" + def _duration_in_seconds(duration: str): + match = re.match(r"(\d+)([smhd]?)", duration) + if not match: + raise ValueError("Invalid duration format") + + value, unit = match.groups() + value = int(value) + + if unit == "s": + return value + elif unit == "m": + return value * 60 + elif unit == "h": + return value * 3600 + elif unit == "d": + return value * 86400 + else: + raise ValueError("Unsupported duration unit") + + duration = _duration_in_seconds(duration=duration_str) + expires = datetime.utcnow() + timedelta(seconds=duration) + try: + db = prisma_client + # Create a new verification token (you may want to enhance this logic based on your needs) + verification_token_data = { + "token": token, + "expires": expires + } + new_verification_token = await db.litellm_verificationtoken.create( + {**verification_token_data} + ) + print(f"new_verification_token: {new_verification_token}") + except Exception as e: + traceback.print_exc() + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + return {"token": new_verification_token.token, "expires": new_verification_token.expires} + +async def generate_key_cli_task(duration_str): + task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str)) + await task + def load_config(): #### DEPRECATED #### try: @@ -412,7 +454,7 @@ def initialize( add_function_to_prompt, headers, save, - config + config, ): global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings generate_feedback_box() @@ -504,13 +546,15 @@ def litellm_completion(*args, **kwargs): if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses return StreamingResponse(data_generator(response), media_type='text/event-stream') return response - - + @app.on_event("startup") -def startup_event(): +async def startup_event(): + global prisma_client import json worker_config = json.loads(os.getenv("WORKER_CONFIG")) initialize(**worker_config) + if prisma_client: + await prisma_client.connect() @app.on_event("shutdown") async def shutdown_event(): @@ -625,51 +669,12 @@ async def chat_completion(request: Request, model: Optional[str] = None): ) @router.post("/key/generate", dependencies=[Depends(user_api_key_auth)]) -async def generate_key(request: Request): +async def generate_key_fn(request: Request): data = await request.json() - token = f"sk-{secrets.token_urlsafe(16)}" duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided - - def _duration_in_seconds(duration: str): - match = re.match(r"(\d+)([smhd]?)", duration) - if not match: - raise ValueError("Invalid duration format") - - value, unit = match.groups() - value = int(value) - - if unit == "s": - return value - elif unit == "m": - return value * 60 - elif unit == "h": - return value * 3600 - elif unit == "d": - return value * 86400 - else: - raise ValueError("Unsupported duration unit") - - duration = _duration_in_seconds(duration=duration_str) - expires = datetime.utcnow() + timedelta(seconds=duration) - try: - db = prisma_client - await db.connect() - # Create a new verification token (you may want to enhance this logic based on your needs) - print(dir(db)) - verification_token_data = { - "token": token, - "expires": expires - } - new_verification_token = await db.litellm_verificationtoken.create( - {**verification_token_data} - ) - print(f"new_verification_token: {new_verification_token}") - except Exception as e: - traceback.print_exc() - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) - - return {"token": new_verification_token.token, "expires": new_verification_token.expires} + response = await generate_key_helper_fn(duration_str=duration_str) + return {"token": response["token"], "expires": response["expires"]} @router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])