mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
fix(proxy/utils.py): fix db writes on retry
This commit is contained in:
parent
92cc39f00e
commit
66e0c06476
4 changed files with 18 additions and 47 deletions
|
@ -270,7 +270,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
print(f"valid_token from cache: {valid_token}")
|
print(f"valid_token from cache: {valid_token}")
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
## check db
|
## check db
|
||||||
|
print(f"api key: {api_key}")
|
||||||
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
|
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
|
||||||
|
print(f"valid token from prisma: {valid_token}")
|
||||||
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
|
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
|
||||||
elif valid_token is not None:
|
elif valid_token is not None:
|
||||||
print(f"API Key Cache Hit!")
|
print(f"API Key Cache Hit!")
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Optional, List, Any, Literal
|
from typing import Optional, List, Any, Literal
|
||||||
import os, subprocess, hashlib, importlib, asyncio
|
import os, subprocess, hashlib, importlib, asyncio, copy
|
||||||
import litellm, backoff
|
import litellm, backoff
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
@ -67,7 +67,6 @@ class ProxyLogging:
|
||||||
try:
|
try:
|
||||||
self.call_details["data"] = data
|
self.call_details["data"] = data
|
||||||
self.call_details["call_type"] = call_type
|
self.call_details["call_type"] = call_type
|
||||||
|
|
||||||
## check if max parallel requests set
|
## check if max parallel requests set
|
||||||
if user_api_key_dict.max_parallel_requests is not None:
|
if user_api_key_dict.max_parallel_requests is not None:
|
||||||
## if set, check if request allowed
|
## if set, check if request allowed
|
||||||
|
@ -165,19 +164,20 @@ class PrismaClient:
|
||||||
async def get_data(self, token: str, expires: Optional[Any]=None):
|
async def get_data(self, token: str, expires: Optional[Any]=None):
|
||||||
try:
|
try:
|
||||||
# check if plain text or hash
|
# check if plain text or hash
|
||||||
|
hashed_token = token
|
||||||
if token.startswith("sk-"):
|
if token.startswith("sk-"):
|
||||||
token = self.hash_token(token=token)
|
hashed_token = self.hash_token(token=token)
|
||||||
if expires:
|
if expires:
|
||||||
response = await self.db.litellm_verificationtoken.find_first(
|
response = await self.db.litellm_verificationtoken.find_first(
|
||||||
where={
|
where={
|
||||||
"token": token,
|
"token": hashed_token,
|
||||||
"expires": {"gte": expires} # Check if the token is not expired
|
"expires": {"gte": expires} # Check if the token is not expired
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = await self.db.litellm_verificationtoken.find_unique(
|
response = await self.db.litellm_verificationtoken.find_unique(
|
||||||
where={
|
where={
|
||||||
"token": token
|
"token": hashed_token
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
@ -200,18 +200,18 @@ class PrismaClient:
|
||||||
try:
|
try:
|
||||||
token = data["token"]
|
token = data["token"]
|
||||||
hashed_token = self.hash_token(token=token)
|
hashed_token = self.hash_token(token=token)
|
||||||
data["token"] = hashed_token
|
db_data = copy.deepcopy(data)
|
||||||
|
db_data["token"] = hashed_token
|
||||||
|
|
||||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||||
where={
|
where={
|
||||||
'token': hashed_token,
|
'token': hashed_token,
|
||||||
},
|
},
|
||||||
data={
|
data={
|
||||||
"create": {**data}, #type: ignore
|
"create": {**db_data}, #type: ignore
|
||||||
"update": {} # don't do anything if it already exists
|
"update": {} # don't do anything if it already exists
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_verification_token
|
return new_verification_token
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||||
|
@ -235,15 +235,16 @@ class PrismaClient:
|
||||||
if token.startswith("sk-"):
|
if token.startswith("sk-"):
|
||||||
token = self.hash_token(token=token)
|
token = self.hash_token(token=token)
|
||||||
|
|
||||||
data["token"] = token
|
db_data = copy.deepcopy(data)
|
||||||
|
db_data["token"] = token
|
||||||
response = await self.db.litellm_verificationtoken.update(
|
response = await self.db.litellm_verificationtoken.update(
|
||||||
where={
|
where={
|
||||||
"token": token
|
"token": token
|
||||||
},
|
},
|
||||||
data={**data} # type: ignore
|
data={**db_data} # type: ignore
|
||||||
)
|
)
|
||||||
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
|
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
|
||||||
return {"token": token, "data": data}
|
return {"token": token, "data": db_data}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||||
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
|
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
|
||||||
|
|
|
@ -1,33 +0,0 @@
|
||||||
Task exception was never retrieved
|
|
||||||
future: <Task finished name='Task-299' coro=<QueryEngine.aclose() done, defined at /opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py:110> exception=RuntimeError('Event loop is closed')>
|
|
||||||
Traceback (most recent call last):
|
|
||||||
File "/opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py", line 112, in aclose
|
|
||||||
await self._close_session()
|
|
||||||
File "/opt/homebrew/lib/python3.11/site-packages/prisma/engine/query.py", line 116, in _close_session
|
|
||||||
await self.session.close()
|
|
||||||
File "/opt/homebrew/lib/python3.11/site-packages/prisma/_async_http.py", line 35, in close
|
|
||||||
await self.session.aclose()
|
|
||||||
File "/opt/homebrew/lib/python3.11/site-packages/httpx/_client.py", line 1974, in aclose
|
|
||||||
await self._transport.aclose()
|
|
||||||
File "/opt/homebrew/lib/python3.11/site-packages/httpx/_transports/default.py", line 365, in aclose
|
|
||||||
await self._pool.aclose()
|
|
||||||
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 314, in aclose
|
|
||||||
await connection.aclose()
|
|
||||||
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/connection.py", line 166, in aclose
|
|
||||||
await self._connection.aclose()
|
|
||||||
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_async/http11.py", line 241, in aclose
|
|
||||||
await self._network_stream.aclose()
|
|
||||||
File "/opt/homebrew/lib/python3.11/site-packages/httpcore/_backends/anyio.py", line 54, in aclose
|
|
||||||
await self._stream.aclose()
|
|
||||||
File "/opt/homebrew/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 1261, in aclose
|
|
||||||
self._transport.close()
|
|
||||||
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/selector_events.py", line 860, in close
|
|
||||||
self._loop.call_soon(self._call_connection_lost, None)
|
|
||||||
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 761, in call_soon
|
|
||||||
self._check_closed()
|
|
||||||
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 519, in _check_closed
|
|
||||||
raise RuntimeError('Event loop is closed')
|
|
||||||
RuntimeError: Event loop is closed
|
|
||||||
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)
|
|
||||||
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)
|
|
||||||
Giving up get_data(...) after 3 tries (prisma.errors.ClientNotConnectedError: Client is not connected to the query engine, you must call `connect()` before attempting to query data.)
|
|
|
@ -1,4 +1,4 @@
|
||||||
import sys, os, time
|
import sys, os, time, asyncio
|
||||||
import traceback
|
import traceback
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
@ -71,8 +71,8 @@ def test_add_new_key(client):
|
||||||
|
|
||||||
# # Run the test - only runs via pytest
|
# # Run the test - only runs via pytest
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
def test_add_new_key_max_parallel_limit(client):
|
async def test_add_new_key_max_parallel_limit(client):
|
||||||
try:
|
try:
|
||||||
# Your test data
|
# Your test data
|
||||||
test_data = {"duration": "20m", "max_parallel_requests": 1}
|
test_data = {"duration": "20m", "max_parallel_requests": 1}
|
||||||
|
@ -88,6 +88,7 @@ def test_add_new_key_max_parallel_limit(client):
|
||||||
result = response.json()
|
result = response.json()
|
||||||
def _post_data():
|
def _post_data():
|
||||||
json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]}
|
json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]}
|
||||||
|
print(f"bearer token key: {result['key']}")
|
||||||
response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
|
response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
|
||||||
return response
|
return response
|
||||||
def _run_in_parallel():
|
def _run_in_parallel():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue