feat(proxy_server.py): add sentry logging for db read/writes

This commit is contained in:
Krrish Dholakia 2023-12-08 11:40:19 -08:00
parent 4e6a8d09d0
commit 7aec95ed7c
4 changed files with 208 additions and 118 deletions

View file

@ -1,9 +1,72 @@
from typing import Optional, List, Any, Literal
import os, subprocess, hashlib, importlib
import os, subprocess, hashlib, importlib, asyncio
import litellm
### LOGGING ###
class ProxyLogging:
"""
Logging for proxy.
Implemented mainly to log successful/failed db read/writes.
Currently just logs this to a provided sentry integration.
"""
def __init__(self,):
## INITIALIZE LITELLM CALLBACKS ##
self._init_litellm_callbacks()
pass
def _init_litellm_callbacks(self):
if len(litellm.callbacks) > 0:
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
if callback not in litellm.success_callback:
litellm.success_callback.append(callback)
if callback not in litellm.failure_callback:
litellm.failure_callback.append(callback)
if callback not in litellm._async_success_callback:
litellm._async_success_callback.append(callback)
if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback)
if (
len(litellm.input_callback) > 0
or len(litellm.success_callback) > 0
or len(litellm.failure_callback) > 0
) and len(callback_list) == 0:
callback_list = list(
set(
litellm.input_callback
+ litellm.success_callback
+ litellm.failure_callback
)
)
litellm.utils.set_callbacks(
callback_list=callback_list
)
async def success_handler(self, *args, **kwargs):
"""
Log successful db read/writes
"""
pass
async def failure_handler(self, original_exception):
"""
Log failed db read/writes
Currently only logs exceptions to sentry
"""
print(f"reaches failure handler logging - {original_exception}; sentry: {litellm.utils.capture_exception}")
if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception)
### DB CONNECTOR ###
class PrismaClient:
def __init__(self, database_url: str):
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
os.environ["DATABASE_URL"] = database_url
# Save the current working directory
@ -22,6 +85,9 @@ class PrismaClient:
from prisma import Client # type: ignore
self.db = Client() #Client to connect to Prisma db
## init logging object
self.proxy_logging_obj = proxy_logging_obj
def hash_token(self, token: str):
# Hash the string using SHA-256
hashed_token = hashlib.sha256(token.encode()).hexdigest()
@ -29,42 +95,48 @@ class PrismaClient:
return hashed_token
async def get_data(self, token: str, expires: Optional[Any]=None):
hashed_token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
try:
hashed_token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
where={
"token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired
"token": hashed_token
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": hashed_token
}
)
return response
return response
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
async def insert_data(self, data: dict):
"""
Add a key to the database. If it already exists, do nothing.
"""
token = data["token"]
hashed_token = self.hash_token(token=token)
data["token"] = hashed_token
print(f"passed in data: {data}; hashed_token: {hashed_token}")
try:
token = data["token"]
hashed_token = self.hash_token(token=token)
data["token"] = hashed_token
print(f"passed in data: {data}; hashed_token: {hashed_token}")
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,
},
data={
"create": {**data}, #type: ignore
"update": {} # don't do anything if it already exists
}
)
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,
},
data={
"create": {**data}, #type: ignore
"update": {} # don't do anything if it already exists
}
)
return new_verification_token
return new_verification_token
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
async def update_data(self, token: str, data: dict):
"""
@ -82,6 +154,7 @@ class PrismaClient:
print("\033[91m" + f"DB write succeeded" + "\033[0m")
return {"token": token, "data": data}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
print()
print()
print()
@ -90,21 +163,31 @@ class PrismaClient:
print()
print()
async def delete_data(self, tokens: List):
"""
Allow user to delete a key(s)
"""
hashed_tokens = [self.hash_token(token=token) for token in tokens]
await self.db.litellm_verificationtoken.delete_many(
where={"token": {"in": hashed_tokens}}
)
return {"deleted_keys": tokens}
try:
hashed_tokens = [self.hash_token(token=token) for token in tokens]
await self.db.litellm_verificationtoken.delete_many(
where={"token": {"in": hashed_tokens}}
)
return {"deleted_keys": tokens}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
async def connect(self):
await self.db.connect()
try:
await self.db.connect()
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
async def disconnect(self):
await self.db.disconnect()
try:
await self.db.disconnect()
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
@ -142,6 +225,8 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
except Exception as e:
raise e
### CALL HOOKS ###
class CallHooks:
"""