fix(proxy/utils.py): fix db writes on retry

This commit is contained in:
Krrish Dholakia 2023-12-11 21:14:12 -08:00
parent 92cc39f00e
commit 66e0c06476
4 changed files with 18 additions and 47 deletions

View file

@@ -270,7 +270,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
print(f"valid_token from cache: {valid_token}")
if valid_token is None:
## check db
print(f"api key: {api_key}")
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
print(f"valid token from prisma: {valid_token}")
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:
print(f"API Key Cache Hit!")

View file

@@ -1,5 +1,5 @@
from typing import Optional, List, Any, Literal
import os, subprocess, hashlib, importlib, asyncio
import os, subprocess, hashlib, importlib, asyncio, copy
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
@@ -67,7 +67,6 @@ class ProxyLogging:
try:
self.call_details["data"] = data
self.call_details["call_type"] = call_type
## check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None:
## if set, check if request allowed
@@ -165,19 +164,20 @@ class PrismaClient:
async def get_data(self, token: str, expires: Optional[Any]=None):
try:
# check if plain text or hash
hashed_token = token
if token.startswith("sk-"):
token = self.hash_token(token=token)
hashed_token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
where={
"token": token,
"token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": token
"token": hashed_token
}
)
return response
@@ -200,18 +200,18 @@ class PrismaClient:
try:
token = data["token"]
hashed_token = self.hash_token(token=token)
data["token"] = hashed_token
db_data = copy.deepcopy(data)
db_data["token"] = hashed_token
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,
},
data={
"create": {**data}, #type: ignore
"create": {**db_data}, #type: ignore
"update": {} # don't do anything if it already exists
}
)
return new_verification_token
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
@@ -235,15 +235,16 @@ class PrismaClient:
if token.startswith("sk-"):
token = self.hash_token(token=token)
data["token"] = token
db_data = copy.deepcopy(data)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={
"token": token
},
data={**data} # type: ignore
data={**db_data} # type: ignore
)
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
return {"token": token, "data": data}
return {"token": token, "data": db_data}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")

View file

@@ -1,33 +0,0 @@
Task exception was never retrieved
future: <Task finished name='Task-299' coro=<QueryEngine.aclose() done, defined at /opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py:110> exception=RuntimeError('Event loop is closed')>
Traceback (most recent call last):
File "/opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py", line 112, in aclose
await self._close_session()
File "/opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py", line 116, in _close_session
await self.session.close()
File "/opt/homebrew/lib/python3.11/site-packages/prisma/_async_http.py", line 35, in close
await self.session.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpx/_client.py", line 1974, in aclose
await self._transport.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpx/_transports/default.py", line 365, in aclose
await self._pool.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 314, in aclose
await connection.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/connection.py", line 166, in aclose
await self._connection.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/http11.py", line 241, in aclose
await self._network_stream.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_backends/anyio.py", line 54, in aclose
await self._stream.aclose()
File "/opt/homebrew/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 1261, in aclose
self._transport.close()
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/selector_events.py", line 860, in close
self._loop.call_soon(self._call_connection_lost, None)
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 761, in call_soon
self._check_closed()
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 519, in _check_closed
raise RuntimeError('Event loop is closed')
RuntimeError: Event loop is closed
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)

View file

@@ -1,4 +1,4 @@
import sys, os, time
import sys, os, time, asyncio
import traceback
from dotenv import load_dotenv
@@ -71,8 +71,8 @@ def test_add_new_key(client):
# # Run the test - only runs via pytest
def test_add_new_key_max_parallel_limit(client):
@pytest.mark.asyncio
async def test_add_new_key_max_parallel_limit(client):
try:
# Your test data
test_data = {"duration": "20m", "max_parallel_requests": 1}
@@ -88,6 +88,7 @@ def test_add_new_key_max_parallel_limit(client):
result = response.json()
def _post_data():
json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]}
print(f"bearer token key: {result['key']}")
response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
return response
def _run_in_parallel():