diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e51df68da..bba511514 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -134,6 +134,7 @@ server_settings: dict = {} log_file = "api_log.json" worker_config = None master_key = None +prisma_client = None #### HELPER FUNCTIONS #### def print_verbose(print_statement): global user_debug @@ -151,21 +152,36 @@ def usage_telemetry( ).start() async def user_api_key_auth(request: Request): - global master_key + global master_key, prisma_client if master_key is None: return try: api_key = await oauth2_scheme(request=request) if api_key == master_key: return - except: - pass + print(f"prisma_client: {prisma_client}") + if prisma_client: + await prisma_client.connect() + valid_token = await prisma_client.litellm_verificationtoken.find_first( + where={ + "token": api_key, + "expires": {"gte": datetime.utcnow()} # Check if the token is not expired + } + ) + print(f"valid_token: {valid_token}") + if valid_token: + return + else: + raise Exception + except Exception as e: + print(e) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail={"error": "invalid user key"}, ) def add_keys_to_config(key, value): + #### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml # Check if file exists if os.path.exists(user_config_path): # Load existing file @@ -184,6 +200,7 @@ def add_keys_to_config(key, value): def save_params_to_config(data: dict): + #### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml # Check if file exists if os.path.exists(user_config_path): # Load existing file @@ -215,13 +232,14 @@ def save_params_to_config(data: dict): with open(user_config_path, "wb") as f: tomli_w.dump(config, f) -def prisma_setup(database_url: Optional[str]): - if database_url: - subprocess.run(['pip', 'install', 'prisma']) - subprocess.run(['python3', '-m', 'pip', 'install', 'prisma']) - subprocess.run(['prisma', 'db', 'push']) - # Now you can import the Prisma Client - from prisma import Client +def prisma_setup(): + global prisma_client + subprocess.run(['pip', 'install', 'prisma']) + subprocess.run(['python3', '-m', 'pip', 'install', 'prisma']) + subprocess.run(['prisma', 'db', 'push']) + # Now you can import the Prisma Client + from prisma import Client + prisma_client = Client() def load_router_config(router: Optional[litellm.Router], config_file_path: str): @@ -244,6 +262,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): if environment_variables: for key, value in environment_variables.items(): os.environ[key] = value + ### CONNECT TO DATABASE ### + if key == "DATABASE_URL": + prisma_setup() ## GENERAL SERVER SETTINGS (e.g. master key,..) @@ -251,9 +272,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): if general_settings: ### MASTER KEY ### master_key = general_settings.get("master_key", None) - ### CONNECT TO DATABASE ### - database_url = general_settings.get("database_url", None) - prisma_setup(database_url=database_url) ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) @@ -493,7 +511,13 @@ def startup_event(): import json worker_config = json.loads(os.getenv("WORKER_CONFIG")) initialize(**worker_config) - # print(f"\033[32mWorker Initialized\033[0m\n") + +@app.on_event("shutdown") +async def shutdown_event(): + global prisma_client + if prisma_client: + print("Disconnecting from Prisma") + await prisma_client.disconnect() #### API ENDPOINTS #### @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)]) @@ -629,8 +653,7 @@ async def generate_key(request: Request): duration = _duration_in_seconds(duration=duration_str) expires = datetime.utcnow() + timedelta(seconds=duration) try: - from prisma import Client - db = Client() + db = prisma_client await db.connect() # Create a new verification token (you may want to enhance this logic based on your needs) print(dir(db))