From ca40a889876d729405fa5f4e266da2a59d0abcf0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 1 Jan 2024 14:19:46 +0530 Subject: [PATCH] fix(proxy_server.py): check if user email in user db --- litellm/proxy/proxy_server.py | 37 ++++++++++++++++++++++++++++----- litellm/proxy/schema.prisma | 1 + litellm/proxy/utils.py | 39 ++++++++++++++++++++++++++++++++++- 3 files changed, 71 insertions(+), 6 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a1eb0a8a7..b524e08e8 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -725,6 +725,7 @@ async def generate_key_helper_fn( max_budget: Optional[float] = None, token: Optional[str] = None, user_id: Optional[str] = None, + user_email: Optional[str] = None, max_parallel_requests: Optional[int] = None, metadata: Optional[dict] = {}, ): @@ -780,6 +781,7 @@ async def generate_key_helper_fn( "max_parallel_requests": max_parallel_requests, "metadata": metadata_json, "max_budget": max_budget, + "user_email": user_email, } new_verification_token = await prisma_client.insert_data( data=verification_token_data @@ -1727,15 +1729,40 @@ async def user_auth(request: Request): RESEND_API_EMAIL = "my-sending-email" ``` """ + global prisma_client + data = await request.json() # type: ignore user_email = data["user_email"] - import os - import resend + if user_email is None: + raise HTTPException(status_code=400, detail="User email is none") - ## [TODO]: Check if user exists, if so - use an existing key, if not - create new user -> return new key - response = await generate_key_helper_fn( - **{"duration": "1hr", "models": [], "aliases": {}, "config": {}, "spend": 0} # type: ignore + import os + + try: + import resend + except ImportError: + raise Exception( + "Resend package missing. Run `pip install litellm[extra_proxy]` to add missing dependencies." + ) + + if prisma_client is None: # if no db connected, raise an error + raise Exception("No connected db.") + + ### Check if user email in user table + response = await prisma_client.get_generic_data( + key="user_email", value=user_email, db="users" ) + print(f"response: {response}") + ### if so - generate a 24 hr key with that user id + if response is not None: + user_id = response.user_id + response = await generate_key_helper_fn( + **{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id} # type: ignore + ) + else: ### else - create new user + response = await generate_key_helper_fn( + **{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_email": user_email} # type: ignore + ) base_url = os.getenv("LITELLM_HOSTED_UI", "https://dashboard.litellm.ai/") diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 5fa0ea008..7ce05f285 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -11,6 +11,7 @@ model LiteLLM_UserTable { user_id String @unique max_budget Float? spend Float @default(0.0) + user_email String? } // required for token gen diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 589c36ee7..ea73891c4 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -182,6 +182,38 @@ class PrismaClient: db_data[k] = json.dumps(v) return db_data + @backoff.on_exception( + backoff.expo, + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for + on_backoff=on_backoff, # specifying the function to call on backoff + ) + async def get_generic_data( + self, + key: str, + value: Any, + db: Literal["users", "keys"], + ): + """ + Generic implementation of get data + """ + try: + if db == "users": + response = await self.db.litellm_usertable.find_first( + where={key: value} # type: ignore + ) + elif db == "keys": + response = await self.db.litellm_verificationtoken.find_first( # type: ignore + where={key: value} # type: ignore + ) + return response + except Exception as e: + asyncio.create_task( + self.proxy_logging_obj.failure_handler(original_exception=e) + ) + raise e + @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff @@ -255,6 +287,7 @@ class PrismaClient: db_data = self.jsonify_object(data=data) db_data["token"] = hashed_token max_budget = db_data.pop("max_budget", None) + user_email = db_data.pop("user_email", None) new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore where={ "token": hashed_token, @@ -268,7 +301,11 @@ class PrismaClient: new_user_row = await self.db.litellm_usertable.upsert( where={"user_id": data["user_id"]}, data={ - "create": {"user_id": data["user_id"], "max_budget": max_budget}, + "create": { + "user_id": data["user_id"], + "max_budget": max_budget, + "user_email": user_email, + }, "update": {}, # don't do anything if it already exists }, )