forked from phoenix/litellm-mirror
fix(proxy_server.py): check if user email in user db
This commit is contained in:
parent
7623c1a846
commit
ca40a88987
3 changed files with 71 additions and 6 deletions
|
@ -725,6 +725,7 @@ async def generate_key_helper_fn(
|
||||||
max_budget: Optional[float] = None,
|
max_budget: Optional[float] = None,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
|
user_email: Optional[str] = None,
|
||||||
max_parallel_requests: Optional[int] = None,
|
max_parallel_requests: Optional[int] = None,
|
||||||
metadata: Optional[dict] = {},
|
metadata: Optional[dict] = {},
|
||||||
):
|
):
|
||||||
|
@ -780,6 +781,7 @@ async def generate_key_helper_fn(
|
||||||
"max_parallel_requests": max_parallel_requests,
|
"max_parallel_requests": max_parallel_requests,
|
||||||
"metadata": metadata_json,
|
"metadata": metadata_json,
|
||||||
"max_budget": max_budget,
|
"max_budget": max_budget,
|
||||||
|
"user_email": user_email,
|
||||||
}
|
}
|
||||||
new_verification_token = await prisma_client.insert_data(
|
new_verification_token = await prisma_client.insert_data(
|
||||||
data=verification_token_data
|
data=verification_token_data
|
||||||
|
@ -1727,15 +1729,40 @@ async def user_auth(request: Request):
|
||||||
RESEND_API_EMAIL = "my-sending-email"
|
RESEND_API_EMAIL = "my-sending-email"
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
global prisma_client
|
||||||
|
|
||||||
data = await request.json() # type: ignore
|
data = await request.json() # type: ignore
|
||||||
user_email = data["user_email"]
|
user_email = data["user_email"]
|
||||||
import os
|
if user_email is None:
|
||||||
import resend
|
raise HTTPException(status_code=400, detail="User email is none")
|
||||||
|
|
||||||
## [TODO]: Check if user exists, if so - use an existing key, if not - create new user -> return new key
|
import os
|
||||||
response = await generate_key_helper_fn(
|
|
||||||
**{"duration": "1hr", "models": [], "aliases": {}, "config": {}, "spend": 0} # type: ignore
|
try:
|
||||||
|
import resend
|
||||||
|
except ImportError:
|
||||||
|
raise Exception(
|
||||||
|
"Resend package missing. Run `pip install litellm[extra_proxy]` to add missing dependencies."
|
||||||
|
)
|
||||||
|
|
||||||
|
if prisma_client is None: # if no db connected, raise an error
|
||||||
|
raise Exception("No connected db.")
|
||||||
|
|
||||||
|
### Check if user email in user table
|
||||||
|
response = await prisma_client.get_generic_data(
|
||||||
|
key="user_email", value=user_email, db="users"
|
||||||
)
|
)
|
||||||
|
print(f"response: {response}")
|
||||||
|
### if so - generate a 24 hr key with that user id
|
||||||
|
if response is not None:
|
||||||
|
user_id = response.user_id
|
||||||
|
response = await generate_key_helper_fn(
|
||||||
|
**{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id} # type: ignore
|
||||||
|
)
|
||||||
|
else: ### else - create new user
|
||||||
|
response = await generate_key_helper_fn(
|
||||||
|
**{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_email": user_email} # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
base_url = os.getenv("LITELLM_HOSTED_UI", "https://dashboard.litellm.ai/")
|
base_url = os.getenv("LITELLM_HOSTED_UI", "https://dashboard.litellm.ai/")
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ model LiteLLM_UserTable {
|
||||||
user_id String @unique
|
user_id String @unique
|
||||||
max_budget Float?
|
max_budget Float?
|
||||||
spend Float @default(0.0)
|
spend Float @default(0.0)
|
||||||
|
user_email String?
|
||||||
}
|
}
|
||||||
|
|
||||||
// required for token gen
|
// required for token gen
|
||||||
|
|
|
@ -182,6 +182,38 @@ class PrismaClient:
|
||||||
db_data[k] = json.dumps(v)
|
db_data[k] = json.dumps(v)
|
||||||
return db_data
|
return db_data
|
||||||
|
|
||||||
|
@backoff.on_exception(
|
||||||
|
backoff.expo,
|
||||||
|
Exception, # base exception to catch for the backoff
|
||||||
|
max_tries=3, # maximum number of retries
|
||||||
|
max_time=10, # maximum total time to retry for
|
||||||
|
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||||
|
)
|
||||||
|
async def get_generic_data(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Any,
|
||||||
|
db: Literal["users", "keys"],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generic implementation of get data
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if db == "users":
|
||||||
|
response = await self.db.litellm_usertable.find_first(
|
||||||
|
where={key: value} # type: ignore
|
||||||
|
)
|
||||||
|
elif db == "keys":
|
||||||
|
response = await self.db.litellm_verificationtoken.find_first( # type: ignore
|
||||||
|
where={key: value} # type: ignore
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
asyncio.create_task(
|
||||||
|
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
@backoff.on_exception(
|
@backoff.on_exception(
|
||||||
backoff.expo,
|
backoff.expo,
|
||||||
Exception, # base exception to catch for the backoff
|
Exception, # base exception to catch for the backoff
|
||||||
|
@ -255,6 +287,7 @@ class PrismaClient:
|
||||||
db_data = self.jsonify_object(data=data)
|
db_data = self.jsonify_object(data=data)
|
||||||
db_data["token"] = hashed_token
|
db_data["token"] = hashed_token
|
||||||
max_budget = db_data.pop("max_budget", None)
|
max_budget = db_data.pop("max_budget", None)
|
||||||
|
user_email = db_data.pop("user_email", None)
|
||||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||||
where={
|
where={
|
||||||
"token": hashed_token,
|
"token": hashed_token,
|
||||||
|
@ -268,7 +301,11 @@ class PrismaClient:
|
||||||
new_user_row = await self.db.litellm_usertable.upsert(
|
new_user_row = await self.db.litellm_usertable.upsert(
|
||||||
where={"user_id": data["user_id"]},
|
where={"user_id": data["user_id"]},
|
||||||
data={
|
data={
|
||||||
"create": {"user_id": data["user_id"], "max_budget": max_budget},
|
"create": {
|
||||||
|
"user_id": data["user_id"],
|
||||||
|
"max_budget": max_budget,
|
||||||
|
"user_email": user_email,
|
||||||
|
},
|
||||||
"update": {}, # don't do anything if it already exists
|
"update": {}, # don't do anything if it already exists
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue