mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(proxy_server.py): manage budget at user-level not key-level
https://github.com/BerriAI/litellm/issues/1220
This commit is contained in:
parent
979575a2a6
commit
89ee9fe400
5 changed files with 220 additions and 79 deletions
|
@ -165,26 +165,35 @@ class PrismaClient:
|
|||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def get_data(self, token: str, expires: Optional[Any]=None):
|
||||
async def get_data(self, token: Optional[str]=None, expires: Optional[Any]=None, user_id: Optional[str]=None):
|
||||
try:
|
||||
# check if plain text or hash
|
||||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
if expires:
|
||||
response = await self.db.litellm_verificationtoken.find_first(
|
||||
response = None
|
||||
if token is not None:
|
||||
# check if plain text or hash
|
||||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
if expires:
|
||||
response = await self.db.litellm_verificationtoken.find_first(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"expires": {"gte": expires} # Check if the token is not expired
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"expires": {"gte": expires} # Check if the token is not expired
|
||||
"token": hashed_token
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": hashed_token
|
||||
}
|
||||
)
|
||||
return response
|
||||
return response
|
||||
elif user_id is not None:
|
||||
response = await self.db.litellm_usertable.find_first( # type: ignore
|
||||
where={
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
raise e
|
||||
|
@ -206,6 +215,7 @@ class PrismaClient:
|
|||
hashed_token = self.hash_token(token=token)
|
||||
db_data = self.jsonify_object(data=data)
|
||||
db_data["token"] = hashed_token
|
||||
max_budget = db_data.pop("max_budget", None)
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
'token': hashed_token,
|
||||
|
@ -215,6 +225,16 @@ class PrismaClient:
|
|||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
|
||||
new_user_row = await self.db.litellm_usertable.upsert(
|
||||
where={
|
||||
'user_id': data['user_id']
|
||||
},
|
||||
data={
|
||||
"create": {"user_id": data['user_id'], "max_budget": max_budget},
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
return new_verification_token
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
|
@ -228,26 +248,37 @@ class PrismaClient:
|
|||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def update_data(self, token: str, data: dict):
|
||||
async def update_data(self, token: Optional[str]=None, data: dict={}, user_id: Optional[str]=None):
|
||||
"""
|
||||
Update existing data
|
||||
"""
|
||||
try:
|
||||
print_verbose(f"token: {token}")
|
||||
# check if plain text or hash
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
|
||||
db_data = self.jsonify_object(data=data)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={
|
||||
"token": token
|
||||
},
|
||||
data={**db_data} # type: ignore
|
||||
)
|
||||
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
|
||||
return {"token": token, "data": db_data}
|
||||
if token is not None:
|
||||
print_verbose(f"token: {token}")
|
||||
# check if plain text or hash
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={
|
||||
"token": token # type: ignore
|
||||
},
|
||||
data={**db_data} # type: ignore
|
||||
)
|
||||
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
|
||||
return {"token": token, "data": db_data}
|
||||
elif user_id is not None:
|
||||
"""
|
||||
If data['spend'] + data['user'], update the user table with spend info as well
|
||||
"""
|
||||
update_user_row = await self.db.litellm_usertable.update(
|
||||
where={
|
||||
'user_id': user_id # type: ignore
|
||||
},
|
||||
data={**db_data} # type: ignore
|
||||
)
|
||||
return {"user_id": user_id, "data": db_data}
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
|
||||
|
@ -342,4 +373,16 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
### HELPER FUNCTIONS ###
|
||||
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
|
||||
"""
|
||||
Check if a user_id exists in cache,
|
||||
if not retrieve it.
|
||||
"""
|
||||
cache_key = f"{user_id}_user_api_key_user_id"
|
||||
response = cache.get_cache(key=cache_key)
|
||||
if response is None: # Cache miss
|
||||
user_row = await db.get_data(user_id=user_id)
|
||||
cache_value = user_row.model_dump_json()
|
||||
cache.set_cache(key=cache_key, value=cache_value, ttl=600) # store for 10 minutes
|
||||
return
|
Loading…
Add table
Add a link
Reference in a new issue