mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix(proxy/utils.py): fix db writes on retry
This commit is contained in:
parent
92cc39f00e
commit
66e0c06476
4 changed files with 18 additions and 47 deletions
|
@ -270,7 +270,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
print(f"valid_token from cache: {valid_token}")
|
||||
if valid_token is None:
|
||||
## check db
|
||||
print(f"api key: {api_key}")
|
||||
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
|
||||
print(f"valid token from prisma: {valid_token}")
|
||||
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
|
||||
elif valid_token is not None:
|
||||
print(f"API Key Cache Hit!")
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Optional, List, Any, Literal
|
||||
import os, subprocess, hashlib, importlib, asyncio
|
||||
import os, subprocess, hashlib, importlib, asyncio, copy
|
||||
import litellm, backoff
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching import DualCache
|
||||
|
@ -67,7 +67,6 @@ class ProxyLogging:
|
|||
try:
|
||||
self.call_details["data"] = data
|
||||
self.call_details["call_type"] = call_type
|
||||
|
||||
## check if max parallel requests set
|
||||
if user_api_key_dict.max_parallel_requests is not None:
|
||||
## if set, check if request allowed
|
||||
|
@ -165,19 +164,20 @@ class PrismaClient:
|
|||
async def get_data(self, token: str, expires: Optional[Any]=None):
|
||||
try:
|
||||
# check if plain text or hash
|
||||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
hashed_token = self.hash_token(token=token)
|
||||
if expires:
|
||||
response = await self.db.litellm_verificationtoken.find_first(
|
||||
where={
|
||||
"token": token,
|
||||
"token": hashed_token,
|
||||
"expires": {"gte": expires} # Check if the token is not expired
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": token
|
||||
"token": hashed_token
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
@ -200,18 +200,18 @@ class PrismaClient:
|
|||
try:
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
data["token"] = hashed_token
|
||||
db_data = copy.deepcopy(data)
|
||||
db_data["token"] = hashed_token
|
||||
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
'token': hashed_token,
|
||||
},
|
||||
data={
|
||||
"create": {**data}, #type: ignore
|
||||
"create": {**db_data}, #type: ignore
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
|
||||
return new_verification_token
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
|
@ -235,15 +235,16 @@ class PrismaClient:
|
|||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
|
||||
data["token"] = token
|
||||
db_data = copy.deepcopy(data)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={
|
||||
"token": token
|
||||
},
|
||||
data={**data} # type: ignore
|
||||
data={**db_data} # type: ignore
|
||||
)
|
||||
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
|
||||
return {"token": token, "data": data}
|
||||
return {"token": token, "data": db_data}
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
Task exception was never retrieved
|
||||
future: <Task finished name='Task-299' coro=<QueryEngine.aclose() done, defined at /opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py:110> exception=RuntimeError('Event loop is closed')>
|
||||
Traceback (most recent call last):
|
||||
File "/opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py", line 112, in aclose
|
||||
await self._close_session()
|
||||
File "/opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py", line 116, in _close_session
|
||||
await self.session.close()
|
||||
File "/opt/homebrew/lib/python3.11/site-packages/prisma/_async_http.py", line 35, in close
|
||||
await self.session.aclose()
|
||||
File "/opt/homebrew/lib/python3.11/site-packages/httpx/_client.py", line 1974, in aclose
|
||||
await self._transport.aclose()
|
||||
File "/opt/homebrew/lib/python3.11/site-packages/httpx/_transports/default.py", line 365, in aclose
|
||||
await self._pool.aclose()
|
||||
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 314, in aclose
|
||||
await connection.aclose()
|
||||
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/connection.py", line 166, in aclose
|
||||
await self._connection.aclose()
|
||||
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/http11.py", line 241, in aclose
|
||||
await self._network_stream.aclose()
|
||||
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_backends/anyio.py", line 54, in aclose
|
||||
await self._stream.aclose()
|
||||
File "/opt/homebrew/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 1261, in aclose
|
||||
self._transport.close()
|
||||
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/selector_events.py", line 860, in close
|
||||
self._loop.call_soon(self._call_connection_lost, None)
|
||||
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 761, in call_soon
|
||||
self._check_closed()
|
||||
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 519, in _check_closed
|
||||
raise RuntimeError('Event loop is closed')
|
||||
RuntimeError: Event loop is closed
|
||||
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)
|
||||
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)
|
||||
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)
|
|
@ -1,4 +1,4 @@
|
|||
import sys, os, time
|
||||
import sys, os, time, asyncio
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
@ -71,8 +71,8 @@ def test_add_new_key(client):
|
|||
|
||||
# # Run the test - only runs via pytest
|
||||
|
||||
|
||||
def test_add_new_key_max_parallel_limit(client):
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_new_key_max_parallel_limit(client):
|
||||
try:
|
||||
# Your test data
|
||||
test_data = {"duration": "20m", "max_parallel_requests": 1}
|
||||
|
@ -88,6 +88,7 @@ def test_add_new_key_max_parallel_limit(client):
|
|||
result = response.json()
|
||||
def _post_data():
|
||||
json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]}
|
||||
print(f"bearer token key: {result['key']}")
|
||||
response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
|
||||
return response
|
||||
def _run_in_parallel():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue