feat(proxy_server.py): add sentry logging for db read/writes

Krrish Dholakia 2023-12-08 11:40:19 -08:00
parent 4e6a8d09d0
commit 7aec95ed7c
4 changed files with 208 additions and 118 deletions

litellm/proxy/proxy_server.py

@@ -94,7 +94,8 @@ import litellm
 from litellm.proxy.utils import (
     PrismaClient,
     get_instance_fn,
-    CallHooks
+    CallHooks,
+    ProxyLogging
 )
 import pydantic
 from litellm.proxy._types import *
@@ -198,6 +199,7 @@ use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
 call_hooks = CallHooks()
+proxy_logging_obj: Optional[ProxyLogging] = None
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -300,10 +302,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
 )
 def prisma_setup(database_url: Optional[str]):
-    global prisma_client
-    if database_url:
+    global prisma_client, proxy_logging_obj
+    if database_url is not None and proxy_logging_obj is not None:
         try:
-            prisma_client = PrismaClient(database_url=database_url)
+            prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
         except Exception as e:
             print("Error when initializing prisma, Ensure you run pip install prisma", e)
@@ -839,11 +841,13 @@ async def rate_limit_per_token(request: Request, call_next):
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks
+    global prisma_client, master_key, use_background_health_checks, proxy_logging_obj
     import json

+    ### INITIALIZE GLOBAL LOGGING OBJECT ###
+    proxy_logging_obj = ProxyLogging()
+
+    ### LOAD CONFIG ###
     worker_config = litellm.get_secret("WORKER_CONFIG")
-    print(f"worker_config: {worker_config}")
     print_verbose(f"worker_config: {worker_config}")
     # check if it's a valid file path
     if os.path.isfile(worker_config):
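
Read together, these proxy_server.py hunks wire a single ProxyLogging instance through startup and into the Prisma client. A minimal sketch of that flow, assuming `ProxyLogging` and `PrismaClient` resolve from `litellm.proxy.utils` as in the import hunk above, and assuming `DATABASE_URL` comes from the environment (the diff does not show the config source):

# Sketch of the startup wiring this commit introduces (simplified; the
# FastAPI app, auth, and config parsing are elided).
import os
from typing import Optional
from litellm.proxy.utils import PrismaClient, ProxyLogging

proxy_logging_obj: Optional[ProxyLogging] = None
prisma_client: Optional[PrismaClient] = None

def prisma_setup(database_url: Optional[str]):
    global prisma_client, proxy_logging_obj
    # Both guards matter: no DATABASE_URL means "run without a db", and a
    # missing logging object means startup_event() has not run yet.
    if database_url is not None and proxy_logging_obj is not None:
        prisma_client = PrismaClient(
            database_url=database_url,
            proxy_logging_obj=proxy_logging_obj,
        )

async def startup_event():
    global proxy_logging_obj
    # The logging object must exist before prisma_setup() is called,
    # otherwise the second guard above silently skips db initialization.
    proxy_logging_obj = ProxyLogging()
    prisma_setup(os.environ.get("DATABASE_URL"))  # assumed config source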

litellm/proxy/utils.py

@@ -1,9 +1,72 @@
 from typing import Optional, List, Any, Literal
-import os, subprocess, hashlib, importlib
+import os, subprocess, hashlib, importlib, asyncio
+import litellm
+
+### LOGGING ###
+class ProxyLogging:
+    """
+    Logging for the proxy.
+    Implemented mainly to log successful/failed db reads/writes.
+    Currently just logs these to a provided sentry integration.
+    """
+    def __init__(self):
+        ## INITIALIZE LITELLM CALLBACKS ##
+        self._init_litellm_callbacks()
+
+    def _init_litellm_callbacks(self):
+        callback_list: list = []  # initialized up front so the check below cannot raise NameError
+        if len(litellm.callbacks) > 0:
+            for callback in litellm.callbacks:
+                if callback not in litellm.input_callback:
+                    litellm.input_callback.append(callback)
+                if callback not in litellm.success_callback:
+                    litellm.success_callback.append(callback)
+                if callback not in litellm.failure_callback:
+                    litellm.failure_callback.append(callback)
+                if callback not in litellm._async_success_callback:
+                    litellm._async_success_callback.append(callback)
+                if callback not in litellm._async_failure_callback:
+                    litellm._async_failure_callback.append(callback)
+        if (
+            len(litellm.input_callback) > 0
+            or len(litellm.success_callback) > 0
+            or len(litellm.failure_callback) > 0
+        ) and len(callback_list) == 0:
+            callback_list = list(
+                set(
+                    litellm.input_callback
+                    + litellm.success_callback
+                    + litellm.failure_callback
+                )
+            )
+            litellm.utils.set_callbacks(callback_list=callback_list)
+
+    async def success_handler(self, *args, **kwargs):
+        """
+        Log successful db reads/writes.
+        """
+        pass
+
+    async def failure_handler(self, original_exception):
+        """
+        Log failed db reads/writes.
+        Currently only logs exceptions to sentry.
+        """
+        print(f"reaches failure handler logging - {original_exception}; sentry: {litellm.utils.capture_exception}")
+        if litellm.utils.capture_exception:
+            litellm.utils.capture_exception(error=original_exception)
+
 ### DB CONNECTOR ###
 class PrismaClient:
-    def __init__(self, database_url: str):
+    def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
         print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
         os.environ["DATABASE_URL"] = database_url
         # Save the current working directory
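
The Sentry hop in `failure_handler` above is indirect: `litellm.utils.capture_exception` is only populated when a sentry callback is configured, so the handler degrades to a no-op otherwise. A standalone sketch of that guarded-passthrough pattern, using `sentry_sdk` directly instead of litellm's internal wrapper (an assumption for illustration; the `SENTRY_DSN` check is hypothetical config, not litellm's mechanism):

# Guarded passthrough to Sentry: a no-op unless an integration was
# registered at startup. Mirrors the shape of ProxyLogging.failure_handler.
import os
from typing import Callable, Optional

capture_exception: Optional[Callable] = None

if os.environ.get("SENTRY_DSN"):  # hypothetical: enable only when configured
    import sentry_sdk
    sentry_sdk.init(dsn=os.environ["SENTRY_DSN"])
    capture_exception = sentry_sdk.capture_exception

async def failure_handler(original_exception: Exception):
    if capture_exception:
        capture_exception(original_exception)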
@@ -22,6 +85,9 @@ class PrismaClient:
         from prisma import Client  # type: ignore
         self.db = Client()  # Client to connect to Prisma db

+        ## init logging object
+        self.proxy_logging_obj = proxy_logging_obj
+
     def hash_token(self, token: str):
         # Hash the string using SHA-256
         hashed_token = hashlib.sha256(token.encode()).hexdigest()
@@ -29,42 +95,48 @@ class PrismaClient:
         return hashed_token

     async def get_data(self, token: str, expires: Optional[Any] = None):
-        hashed_token = self.hash_token(token=token)
-        if expires:
-            response = await self.db.litellm_verificationtoken.find_first(
-                where={
-                    "token": hashed_token,
-                    "expires": {"gte": expires}  # Check if the token is not expired
-                }
-            )
-        else:
-            response = await self.db.litellm_verificationtoken.find_unique(
-                where={
-                    "token": hashed_token
-                }
-            )
-        return response
+        try:
+            hashed_token = self.hash_token(token=token)
+            if expires:
+                response = await self.db.litellm_verificationtoken.find_first(
+                    where={
+                        "token": hashed_token,
+                        "expires": {"gte": expires}  # Check if the token is not expired
+                    }
+                )
+            else:
+                response = await self.db.litellm_verificationtoken.find_unique(
+                    where={
+                        "token": hashed_token
+                    }
+                )
+            return response
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))

     async def insert_data(self, data: dict):
         """
         Add a key to the database. If it already exists, do nothing.
         """
-        token = data["token"]
-        hashed_token = self.hash_token(token=token)
-        data["token"] = hashed_token
-        print(f"passed in data: {data}; hashed_token: {hashed_token}")
-
-        new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
-            where={
-                'token': hashed_token,
-            },
-            data={
-                "create": {**data},  # type: ignore
-                "update": {}  # don't do anything if it already exists
-            }
-        )
-
-        return new_verification_token
+        try:
+            token = data["token"]
+            hashed_token = self.hash_token(token=token)
+            data["token"] = hashed_token
+            print(f"passed in data: {data}; hashed_token: {hashed_token}")
+
+            new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
+                where={
+                    'token': hashed_token,
+                },
+                data={
+                    "create": {**data},  # type: ignore
+                    "update": {}  # don't do anything if it already exists
+                }
+            )
+
+            return new_verification_token
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))

     async def update_data(self, token: str, data: dict):
         """
@@ -82,6 +154,7 @@ class PrismaClient:
             print("\033[91m" + f"DB write succeeded" + "\033[0m")
             return {"token": token, "data": data}
         except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
             print()
             print()
             print()
@@ -90,21 +163,31 @@ class PrismaClient:
             print()
             print()

     async def delete_data(self, tokens: List):
         """
         Allow user to delete a key(s)
         """
-        hashed_tokens = [self.hash_token(token=token) for token in tokens]
-        await self.db.litellm_verificationtoken.delete_many(
-            where={"token": {"in": hashed_tokens}}
-        )
-        return {"deleted_keys": tokens}
+        try:
+            hashed_tokens = [self.hash_token(token=token) for token in tokens]
+            await self.db.litellm_verificationtoken.delete_many(
+                where={"token": {"in": hashed_tokens}}
+            )
+            return {"deleted_keys": tokens}
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))

     async def connect(self):
-        await self.db.connect()
+        try:
+            await self.db.connect()
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))

     async def disconnect(self):
-        await self.db.disconnect()
+        try:
+            await self.db.disconnect()
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))

 ### CUSTOM FILE ###
 def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
@@ -142,6 +225,8 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e

+
+
 ### CALL HOOKS ###
 class CallHooks:
     """

(load-test script; file path not shown)

@@ -1,82 +1,82 @@
-import openai, json, time, asyncio
+# import openai, json, time, asyncio

-client = openai.AsyncOpenAI(
-    api_key="sk-1234",
-    base_url="http://0.0.0.0:8000"
-)
+# client = openai.AsyncOpenAI(
+#     api_key="sk-1234",
+#     base_url="http://0.0.0.0:8000"
+# )

-super_fake_messages = [
-    {
-        "role": "user",
-        "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}"
-    },
-    {
-        "content": None,
-        "role": "assistant",
-        "tool_calls": [
-            {
-                "id": "1",
-                "function": {
-                    "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            },
-            {
-                "id": "2",
-                "function": {
-                    "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            },
-            {
-                "id": "3",
-                "function": {
-                    "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            }
-        ]
-    },
-    {
-        "tool_call_id": "1",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
-    },
-    {
-        "tool_call_id": "2",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
-    },
-    {
-        "tool_call_id": "3",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
-    }
-]
+# super_fake_messages = [
+#     {
+#         "role": "user",
+#         "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}"
+#     },
+#     {
+#         "content": None,
+#         "role": "assistant",
+#         "tool_calls": [
+#             {
+#                 "id": "1",
+#                 "function": {
+#                     "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             },
+#             {
+#                 "id": "2",
+#                 "function": {
+#                     "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             },
+#             {
+#                 "id": "3",
+#                 "function": {
+#                     "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             }
+#         ]
+#     },
+#     {
+#         "tool_call_id": "1",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
+#     },
+#     {
+#         "tool_call_id": "2",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
+#     },
+#     {
+#         "tool_call_id": "3",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
+#     }
+# ]

-async def chat_completions():
-    super_fake_response = await client.chat.completions.create(
-        model="gpt-3.5-turbo",
-        messages=super_fake_messages,
-        seed=1337,
-        stream=False
-    ) # get a new response from the model where it can see the function response
-    await asyncio.sleep(1)
-    return super_fake_response
+# async def chat_completions():
+#     super_fake_response = await client.chat.completions.create(
+#         model="gpt-3.5-turbo",
+#         messages=super_fake_messages,
+#         seed=1337,
+#         stream=False
+#     ) # get a new response from the model where it can see the function response
+#     await asyncio.sleep(1)
+#     return super_fake_response

-async def loadtest_fn(n = 2000):
-    global num_task_cancelled_errors, exception_counts, chat_completions
-    start = time.time()
-    tasks = [chat_completions() for _ in range(n)]
-    chat_completions = await asyncio.gather(*tasks)
-    successful_completions = [c for c in chat_completions if c is not None]
-    print(n, time.time() - start, len(successful_completions))
+# async def loadtest_fn(n = 1):
+#     global num_task_cancelled_errors, exception_counts, chat_completions
+#     start = time.time()
+#     tasks = [chat_completions() for _ in range(n)]
+#     chat_completions = await asyncio.gather(*tasks)
+#     successful_completions = [c for c in chat_completions if c is not None]
+#     print(n, time.time() - start, len(successful_completions))

-# print(json.dumps(super_fake_response.model_dump(), indent=4))
-asyncio.run(loadtest_fn())
+# # print(json.dumps(super_fake_response.model_dump(), indent=4))
+# asyncio.run(loadtest_fn())

litellm/utils.py

@@ -1127,6 +1127,7 @@ class Logging:
                 f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging {traceback.format_exc()}"
             )
             pass
+
     async def async_failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
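
The docstring points at a concrete pitfall: a synchronous failure handler cannot call `asyncio.run()` while an event loop is already running, which is why litellm keeps separate async callback lists (such as `_async_failure_callback` seen above) and awaits them from the running loop. A minimal illustration of the failure mode and the fix (handler names here are illustrative, not litellm's API):

import asyncio

async def async_on_failure(exception: Exception):
    # Illustrative async callback; not litellm's actual API.
    print(f"logged asynchronously: {exception}")

async def handle_request():
    try:
        raise ValueError("simulated failure")
    except Exception as e:
        # Inside a running loop, asyncio.run(async_on_failure(e)) would raise
        # "asyncio.run() cannot be called from a running event loop", so the
        # async handler must be awaited (or scheduled as a task) instead:
        await async_on_failure(e)

asyncio.run(handle_request())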