fix(proxy/utils.py): fix db writes on retry

This commit is contained in:
Krrish Dholakia 2023-12-11 21:14:12 -08:00
parent 92cc39f00e
commit 66e0c06476
4 changed files with 18 additions and 47 deletions

View file

@ -1,5 +1,5 @@
from typing import Optional, List, Any, Literal
import os, subprocess, hashlib, importlib, asyncio
import os, subprocess, hashlib, importlib, asyncio, copy
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
@ -67,7 +67,6 @@ class ProxyLogging:
try:
self.call_details["data"] = data
self.call_details["call_type"] = call_type
## check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None:
## if set, check if request allowed
@ -165,19 +164,20 @@ class PrismaClient:
async def get_data(self, token: str, expires: Optional[Any]=None):
try:
# check if plain text or hash
hashed_token = token
if token.startswith("sk-"):
token = self.hash_token(token=token)
hashed_token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
where={
"token": token,
"token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": token
"token": hashed_token
}
)
return response
@ -200,18 +200,18 @@ class PrismaClient:
try:
token = data["token"]
hashed_token = self.hash_token(token=token)
data["token"] = hashed_token
db_data = copy.deepcopy(data)
db_data["token"] = hashed_token
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,
},
data={
"create": {**data}, #type: ignore
"create": {**db_data}, #type: ignore
"update": {} # don't do anything if it already exists
}
)
return new_verification_token
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
@ -235,15 +235,16 @@ class PrismaClient:
if token.startswith("sk-"):
token = self.hash_token(token=token)
data["token"] = token
db_data = copy.deepcopy(data)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={
"token": token
},
data={**data} # type: ignore
data={**db_data} # type: ignore
)
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
return {"token": token, "data": data}
return {"token": token, "data": db_data}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")