mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat(proxy_server.py): add sentry logging for db read/writes
This commit is contained in:
parent
4e6a8d09d0
commit
7aec95ed7c
4 changed files with 208 additions and 118 deletions
|
@ -1,9 +1,72 @@
|
|||
from typing import Optional, List, Any, Literal
|
||||
import os, subprocess, hashlib, importlib
|
||||
import os, subprocess, hashlib, importlib, asyncio
|
||||
import litellm
|
||||
|
||||
### LOGGING ###
|
||||
class ProxyLogging:
|
||||
"""
|
||||
Logging for proxy.
|
||||
|
||||
Implemented mainly to log successful/failed db read/writes.
|
||||
|
||||
Currently just logs this to a provided sentry integration.
|
||||
"""
|
||||
|
||||
def __init__(self,):
|
||||
## INITIALIZE LITELLM CALLBACKS ##
|
||||
self._init_litellm_callbacks()
|
||||
pass
|
||||
|
||||
|
||||
def _init_litellm_callbacks(self):
|
||||
if len(litellm.callbacks) > 0:
|
||||
for callback in litellm.callbacks:
|
||||
if callback not in litellm.input_callback:
|
||||
litellm.input_callback.append(callback)
|
||||
if callback not in litellm.success_callback:
|
||||
litellm.success_callback.append(callback)
|
||||
if callback not in litellm.failure_callback:
|
||||
litellm.failure_callback.append(callback)
|
||||
if callback not in litellm._async_success_callback:
|
||||
litellm._async_success_callback.append(callback)
|
||||
if callback not in litellm._async_failure_callback:
|
||||
litellm._async_failure_callback.append(callback)
|
||||
if (
|
||||
len(litellm.input_callback) > 0
|
||||
or len(litellm.success_callback) > 0
|
||||
or len(litellm.failure_callback) > 0
|
||||
) and len(callback_list) == 0:
|
||||
callback_list = list(
|
||||
set(
|
||||
litellm.input_callback
|
||||
+ litellm.success_callback
|
||||
+ litellm.failure_callback
|
||||
)
|
||||
)
|
||||
litellm.utils.set_callbacks(
|
||||
callback_list=callback_list
|
||||
)
|
||||
|
||||
async def success_handler(self, *args, **kwargs):
|
||||
"""
|
||||
Log successful db read/writes
|
||||
"""
|
||||
pass
|
||||
|
||||
async def failure_handler(self, original_exception):
|
||||
"""
|
||||
Log failed db read/writes
|
||||
|
||||
Currently only logs exceptions to sentry
|
||||
"""
|
||||
print(f"reaches failure handler logging - {original_exception}; sentry: {litellm.utils.capture_exception}")
|
||||
if litellm.utils.capture_exception:
|
||||
litellm.utils.capture_exception(error=original_exception)
|
||||
|
||||
|
||||
### DB CONNECTOR ###
|
||||
class PrismaClient:
|
||||
def __init__(self, database_url: str):
|
||||
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
||||
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
# Save the current working directory
|
||||
|
@ -22,6 +85,9 @@ class PrismaClient:
|
|||
from prisma import Client # type: ignore
|
||||
self.db = Client() #Client to connect to Prisma db
|
||||
|
||||
## init logging object
|
||||
self.proxy_logging_obj = proxy_logging_obj
|
||||
|
||||
def hash_token(self, token: str):
|
||||
# Hash the string using SHA-256
|
||||
hashed_token = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
@ -29,42 +95,48 @@ class PrismaClient:
|
|||
return hashed_token
|
||||
|
||||
async def get_data(self, token: str, expires: Optional[Any]=None):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
if expires:
|
||||
response = await self.db.litellm_verificationtoken.find_first(
|
||||
try:
|
||||
hashed_token = self.hash_token(token=token)
|
||||
if expires:
|
||||
response = await self.db.litellm_verificationtoken.find_first(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"expires": {"gte": expires} # Check if the token is not expired
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"expires": {"gte": expires} # Check if the token is not expired
|
||||
"token": hashed_token
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": hashed_token
|
||||
}
|
||||
)
|
||||
return response
|
||||
return response
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
|
||||
async def insert_data(self, data: dict):
|
||||
"""
|
||||
Add a key to the database. If it already exists, do nothing.
|
||||
"""
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
data["token"] = hashed_token
|
||||
print(f"passed in data: {data}; hashed_token: {hashed_token}")
|
||||
try:
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
data["token"] = hashed_token
|
||||
print(f"passed in data: {data}; hashed_token: {hashed_token}")
|
||||
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
'token': hashed_token,
|
||||
},
|
||||
data={
|
||||
"create": {**data}, #type: ignore
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
'token': hashed_token,
|
||||
},
|
||||
data={
|
||||
"create": {**data}, #type: ignore
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
|
||||
return new_verification_token
|
||||
return new_verification_token
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
|
||||
async def update_data(self, token: str, data: dict):
|
||||
"""
|
||||
|
@ -82,6 +154,7 @@ class PrismaClient:
|
|||
print("\033[91m" + f"DB write succeeded" + "\033[0m")
|
||||
return {"token": token, "data": data}
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
print()
|
||||
print()
|
||||
print()
|
||||
|
@ -90,21 +163,31 @@ class PrismaClient:
|
|||
print()
|
||||
print()
|
||||
|
||||
|
||||
async def delete_data(self, tokens: List):
|
||||
"""
|
||||
Allow user to delete a key(s)
|
||||
"""
|
||||
hashed_tokens = [self.hash_token(token=token) for token in tokens]
|
||||
await self.db.litellm_verificationtoken.delete_many(
|
||||
where={"token": {"in": hashed_tokens}}
|
||||
)
|
||||
return {"deleted_keys": tokens}
|
||||
try:
|
||||
hashed_tokens = [self.hash_token(token=token) for token in tokens]
|
||||
await self.db.litellm_verificationtoken.delete_many(
|
||||
where={"token": {"in": hashed_tokens}}
|
||||
)
|
||||
return {"deleted_keys": tokens}
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
|
||||
async def connect(self):
|
||||
await self.db.connect()
|
||||
try:
|
||||
await self.db.connect()
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
|
||||
async def disconnect(self):
|
||||
await self.db.disconnect()
|
||||
try:
|
||||
await self.db.disconnect()
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
|
||||
### CUSTOM FILE ###
|
||||
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
||||
|
@ -142,6 +225,8 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
### CALL HOOKS ###
|
||||
class CallHooks:
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue