mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(proxy/utils.py): fix db writes on retry
This commit is contained in:
parent
92cc39f00e
commit
66e0c06476
4 changed files with 18 additions and 47 deletions
|
@ -1,5 +1,5 @@
|
|||
from typing import Optional, List, Any, Literal
|
||||
import os, subprocess, hashlib, importlib, asyncio
|
||||
import os, subprocess, hashlib, importlib, asyncio, copy
|
||||
import litellm, backoff
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching import DualCache
|
||||
|
@ -67,7 +67,6 @@ class ProxyLogging:
|
|||
try:
|
||||
self.call_details["data"] = data
|
||||
self.call_details["call_type"] = call_type
|
||||
|
||||
## check if max parallel requests set
|
||||
if user_api_key_dict.max_parallel_requests is not None:
|
||||
## if set, check if request allowed
|
||||
|
@ -165,19 +164,20 @@ class PrismaClient:
|
|||
async def get_data(self, token: str, expires: Optional[Any]=None):
|
||||
try:
|
||||
# check if plain text or hash
|
||||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
hashed_token = self.hash_token(token=token)
|
||||
if expires:
|
||||
response = await self.db.litellm_verificationtoken.find_first(
|
||||
where={
|
||||
"token": token,
|
||||
"token": hashed_token,
|
||||
"expires": {"gte": expires} # Check if the token is not expired
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": token
|
||||
"token": hashed_token
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
@ -200,18 +200,18 @@ class PrismaClient:
|
|||
try:
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
data["token"] = hashed_token
|
||||
db_data = copy.deepcopy(data)
|
||||
db_data["token"] = hashed_token
|
||||
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
'token': hashed_token,
|
||||
},
|
||||
data={
|
||||
"create": {**data}, #type: ignore
|
||||
"create": {**db_data}, #type: ignore
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
|
||||
return new_verification_token
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
|
@ -235,15 +235,16 @@ class PrismaClient:
|
|||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
|
||||
data["token"] = token
|
||||
db_data = copy.deepcopy(data)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={
|
||||
"token": token
|
||||
},
|
||||
data={**data} # type: ignore
|
||||
data={**db_data} # type: ignore
|
||||
)
|
||||
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
|
||||
return {"token": token, "data": data}
|
||||
return {"token": token, "data": db_data}
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue