fix(proxy/utils.py): fix db writes on retry

This commit is contained in:
Krrish Dholakia 2023-12-11 21:14:12 -08:00
parent 92cc39f00e
commit 66e0c06476
4 changed files with 18 additions and 47 deletions

View file

@ -270,7 +270,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
print(f"valid_token from cache: {valid_token}") print(f"valid_token from cache: {valid_token}")
if valid_token is None: if valid_token is None:
## check db ## check db
print(f"api key: {api_key}")
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow()) valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
print(f"valid token from prisma: {valid_token}")
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None: elif valid_token is not None:
print(f"API Key Cache Hit!") print(f"API Key Cache Hit!")

View file

@ -1,5 +1,5 @@
from typing import Optional, List, Any, Literal from typing import Optional, List, Any, Literal
import os, subprocess, hashlib, importlib, asyncio import os, subprocess, hashlib, importlib, asyncio, copy
import litellm, backoff import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
@ -67,7 +67,6 @@ class ProxyLogging:
try: try:
self.call_details["data"] = data self.call_details["data"] = data
self.call_details["call_type"] = call_type self.call_details["call_type"] = call_type
## check if max parallel requests set ## check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None: if user_api_key_dict.max_parallel_requests is not None:
## if set, check if request allowed ## if set, check if request allowed
@ -165,19 +164,20 @@ class PrismaClient:
async def get_data(self, token: str, expires: Optional[Any]=None): async def get_data(self, token: str, expires: Optional[Any]=None):
try: try:
# check if plain text or hash # check if plain text or hash
hashed_token = token
if token.startswith("sk-"): if token.startswith("sk-"):
token = self.hash_token(token=token) hashed_token = self.hash_token(token=token)
if expires: if expires:
response = await self.db.litellm_verificationtoken.find_first( response = await self.db.litellm_verificationtoken.find_first(
where={ where={
"token": token, "token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired "expires": {"gte": expires} # Check if the token is not expired
} }
) )
else: else:
response = await self.db.litellm_verificationtoken.find_unique( response = await self.db.litellm_verificationtoken.find_unique(
where={ where={
"token": token "token": hashed_token
} }
) )
return response return response
@ -200,18 +200,18 @@ class PrismaClient:
try: try:
token = data["token"] token = data["token"]
hashed_token = self.hash_token(token=token) hashed_token = self.hash_token(token=token)
data["token"] = hashed_token db_data = copy.deepcopy(data)
db_data["token"] = hashed_token
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={ where={
'token': hashed_token, 'token': hashed_token,
}, },
data={ data={
"create": {**data}, #type: ignore "create": {**db_data}, #type: ignore
"update": {} # don't do anything if it already exists "update": {} # don't do anything if it already exists
} }
) )
return new_verification_token return new_verification_token
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
@ -235,15 +235,16 @@ class PrismaClient:
if token.startswith("sk-"): if token.startswith("sk-"):
token = self.hash_token(token=token) token = self.hash_token(token=token)
data["token"] = token db_data = copy.deepcopy(data)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update( response = await self.db.litellm_verificationtoken.update(
where={ where={
"token": token "token": token
}, },
data={**data} # type: ignore data={**db_data} # type: ignore
) )
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m") print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
return {"token": token, "data": data} return {"token": token, "data": db_data}
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m") print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")

View file

@ -1,33 +0,0 @@
Task exception was never retrieved
future: <Task finished name='Task-299' coro=<QueryEngine.aclose() done, defined at /opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py:110> exception=RuntimeError('Event loop is closed')>
Traceback (most recent call last):
File "/opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py", line 112, in aclose
await self._close_session()
File "/opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py", line 116, in _close_session
await self.session.close()
File "/opt/homebrew/lib/python3.11/site-packages/prisma/_async_http.py", line 35, in close
await self.session.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpx/_client.py", line 1974, in aclose
await self._transport.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpx/_transports/default.py", line 365, in aclose
await self._pool.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 314, in aclose
await connection.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/connection.py", line 166, in aclose
await self._connection.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/http11.py", line 241, in aclose
await self._network_stream.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_backends/anyio.py", line 54, in aclose
await self._stream.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 1261, in aclose
self._transport.close()
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/selector_events.py", line 860, in close
self._loop.call_soon(self._call_connection_lost, None)
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 761, in call_soon
self._check_closed()
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 519, in _check_closed
raise RuntimeError('Event loop is closed')
RuntimeError: Event loop is closed
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)

View file

@ -1,4 +1,4 @@
import sys, os, time import sys, os, time, asyncio
import traceback import traceback
from dotenv import load_dotenv from dotenv import load_dotenv
@ -71,8 +71,8 @@ def test_add_new_key(client):
# # Run the test - only runs via pytest # # Run the test - only runs via pytest
@pytest.mark.asyncio
def test_add_new_key_max_parallel_limit(client): async def test_add_new_key_max_parallel_limit(client):
try: try:
# Your test data # Your test data
test_data = {"duration": "20m", "max_parallel_requests": 1} test_data = {"duration": "20m", "max_parallel_requests": 1}
@ -88,6 +88,7 @@ def test_add_new_key_max_parallel_limit(client):
result = response.json() result = response.json()
def _post_data(): def _post_data():
json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]} json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]}
print(f"bearer token key: {result['key']}")
response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"}) response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
return response return response
def _run_in_parallel(): def _run_in_parallel():