feat(proxy_server.py): add sentry logging for db read/writes

Krrish Dholakia 2023-12-08 11:40:19 -08:00
parent 4e6a8d09d0
commit 7aec95ed7c
4 changed files with 208 additions and 118 deletions
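The pattern applied throughout this diff: each Prisma db call is wrapped in try/except, and on failure the exception is handed to a fire-and-forget asyncio task that forwards it to the sentry integration via litellm.utils.capture_exception. A minimal standalone sketch of that pattern, assuming sentry_sdk is installed and SENTRY_DSN is set (the sentry_sdk wiring here is an assumption; the proxy itself goes through litellm's callback machinery):

import asyncio
import sentry_sdk

sentry_sdk.init()  # reads SENTRY_DSN from the environment; no-ops without it

class ProxyLoggingSketch:
    """Stand-in for the ProxyLogging class added in this commit."""
    async def failure_handler(self, original_exception: Exception):
        # forward the failed db read/write to sentry
        sentry_sdk.capture_exception(original_exception)

async def db_write(logger: ProxyLoggingSketch):
    try:
        raise ConnectionError("db unreachable")  # stand-in for a prisma call
    except Exception as e:
        # fire-and-forget: the request path never blocks on logging
        task = asyncio.create_task(logger.failure_handler(original_exception=e))
        await task  # awaited here only so this demo finishes cleanly

asyncio.run(db_write(ProxyLoggingSketch()))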

litellm/proxy/proxy_server.py

@@ -94,7 +94,8 @@ import litellm
 from litellm.proxy.utils import (
     PrismaClient,
     get_instance_fn,
-    CallHooks
+    CallHooks,
+    ProxyLogging
 )
 import pydantic
 from litellm.proxy._types import *
@@ -198,6 +199,7 @@ use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
 call_hooks = CallHooks()
+proxy_logging_obj: Optional[ProxyLogging] = None
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -300,10 +302,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
 )
 def prisma_setup(database_url: Optional[str]):
-    global prisma_client
-    if database_url:
+    global prisma_client, proxy_logging_obj
+    if database_url is not None and proxy_logging_obj is not None:
         try:
-            prisma_client = PrismaClient(database_url=database_url)
+            prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
         except Exception as e:
             print("Error when initializing prisma, Ensure you run pip install prisma", e)
@@ -839,11 +841,13 @@ async def rate_limit_per_token(request: Request, call_next):
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks
+    global prisma_client, master_key, use_background_health_checks, proxy_logging_obj
     import json
+
+    ### INITIALIZE GLOBAL LOGGING OBJECT ###
+    proxy_logging_obj = ProxyLogging()
 
     ### LOAD CONFIG ###
     worker_config = litellm.get_secret("WORKER_CONFIG")
-    print(f"worker_config: {worker_config}")
+    print_verbose(f"worker_config: {worker_config}")
     # check if it's a valid file path
     if os.path.isfile(worker_config):

litellm/proxy/utils.py

@@ -1,9 +1,72 @@
 from typing import Optional, List, Any, Literal
-import os, subprocess, hashlib, importlib
+import os, subprocess, hashlib, importlib, asyncio
 import litellm
+
+### LOGGING ###
+class ProxyLogging:
+    """
+    Logging for proxy.
+
+    Implemented mainly to log successful/failed db read/writes.
+    Currently just logs this to a provided sentry integration.
+    """
+
+    def __init__(self):
+        ## INITIALIZE LITELLM CALLBACKS ##
+        self._init_litellm_callbacks()
+
+    def _init_litellm_callbacks(self):
+        callback_list: list = []  # defined up front so the length check below cannot raise NameError
+        if len(litellm.callbacks) > 0:
+            for callback in litellm.callbacks:
+                if callback not in litellm.input_callback:
+                    litellm.input_callback.append(callback)
+                if callback not in litellm.success_callback:
+                    litellm.success_callback.append(callback)
+                if callback not in litellm.failure_callback:
+                    litellm.failure_callback.append(callback)
+                if callback not in litellm._async_success_callback:
+                    litellm._async_success_callback.append(callback)
+                if callback not in litellm._async_failure_callback:
+                    litellm._async_failure_callback.append(callback)
+            if (
+                len(litellm.input_callback) > 0
+                or len(litellm.success_callback) > 0
+                or len(litellm.failure_callback) > 0
+            ) and len(callback_list) == 0:
+                callback_list = list(
+                    set(
+                        litellm.input_callback
+                        + litellm.success_callback
+                        + litellm.failure_callback
+                    )
+                )
+                litellm.utils.set_callbacks(
+                    callback_list=callback_list
+                )
+
+    async def success_handler(self, *args, **kwargs):
+        """
+        Log successful db read/writes
+        """
+        pass
+
+    async def failure_handler(self, original_exception):
+        """
+        Log failed db read/writes
+
+        Currently only logs exceptions to sentry
+        """
+        print(f"reaches failure handler logging - {original_exception}; sentry: {litellm.utils.capture_exception}")
+        if litellm.utils.capture_exception:
+            litellm.utils.capture_exception(error=original_exception)
+
+### DB CONNECTOR ###
 class PrismaClient:
-    def __init__(self, database_url: str):
+    def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
         print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
         os.environ["DATABASE_URL"] = database_url
         # Save the current working directory
@@ -22,6 +85,9 @@ class PrismaClient:
         from prisma import Client  # type: ignore
         self.db = Client()  # Client to connect to Prisma db
+
+        ## init logging object
+        self.proxy_logging_obj = proxy_logging_obj
 
     def hash_token(self, token: str):
         # Hash the string using SHA-256
         hashed_token = hashlib.sha256(token.encode()).hexdigest()
@@ -29,42 +95,48 @@ class PrismaClient:
         return hashed_token
 
     async def get_data(self, token: str, expires: Optional[Any]=None):
-        hashed_token = self.hash_token(token=token)
-        if expires:
-            response = await self.db.litellm_verificationtoken.find_first(
-                where={
-                    "token": hashed_token,
-                    "expires": {"gte": expires}  # Check if the token is not expired
-                }
-            )
-        else:
-            response = await self.db.litellm_verificationtoken.find_unique(
-                where={
-                    "token": hashed_token
-                }
-            )
-        return response
+        try:
+            hashed_token = self.hash_token(token=token)
+            if expires:
+                response = await self.db.litellm_verificationtoken.find_first(
+                    where={
+                        "token": hashed_token,
+                        "expires": {"gte": expires}  # Check if the token is not expired
+                    }
+                )
+            else:
+                response = await self.db.litellm_verificationtoken.find_unique(
+                    where={
+                        "token": hashed_token
+                    }
+                )
+            return response
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
     async def insert_data(self, data: dict):
         """
         Add a key to the database. If it already exists, do nothing.
         """
-        token = data["token"]
-        hashed_token = self.hash_token(token=token)
-        data["token"] = hashed_token
-        print(f"passed in data: {data}; hashed_token: {hashed_token}")
-        new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
-            where={
-                'token': hashed_token,
-            },
-            data={
-                "create": {**data},  # type: ignore
-                "update": {}  # don't do anything if it already exists
-            }
-        )
-        return new_verification_token
+        try:
+            token = data["token"]
+            hashed_token = self.hash_token(token=token)
+            data["token"] = hashed_token
+            print(f"passed in data: {data}; hashed_token: {hashed_token}")
+            new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
+                where={
+                    'token': hashed_token,
+                },
+                data={
+                    "create": {**data},  # type: ignore
+                    "update": {}  # don't do anything if it already exists
+                }
+            )
+            return new_verification_token
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
     async def update_data(self, token: str, data: dict):
         """
@@ -82,6 +154,7 @@ class PrismaClient:
             print("\033[91m" + f"DB write succeeded" + "\033[0m")
             return {"token": token, "data": data}
         except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
             print()
             print()
             print()
@@ -90,21 +163,31 @@ class PrismaClient:
             print()
             print()
 
     async def delete_data(self, tokens: List):
         """
         Allow user to delete a key(s)
         """
-        hashed_tokens = [self.hash_token(token=token) for token in tokens]
-        await self.db.litellm_verificationtoken.delete_many(
-            where={"token": {"in": hashed_tokens}}
-        )
-        return {"deleted_keys": tokens}
+        try:
+            hashed_tokens = [self.hash_token(token=token) for token in tokens]
+            await self.db.litellm_verificationtoken.delete_many(
+                where={"token": {"in": hashed_tokens}}
+            )
+            return {"deleted_keys": tokens}
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
     async def connect(self):
-        await self.db.connect()
+        try:
+            await self.db.connect()
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
     async def disconnect(self):
-        await self.db.disconnect()
+        try:
+            await self.db.disconnect()
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
 ### CUSTOM FILE ###
 def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
@@ -142,6 +225,8 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e
 
+### CALL HOOKS ###
+
 class CallHooks:
     """


@@ -1,82 +1,82 @@
-import openai, json, time, asyncio
-client = openai.AsyncOpenAI(
-    api_key="sk-1234",
-    base_url="http://0.0.0.0:8000"
-)
+# import openai, json, time, asyncio
+# client = openai.AsyncOpenAI(
+#     api_key="sk-1234",
+#     base_url="http://0.0.0.0:8000"
+# )
-super_fake_messages = [
-    {
-        "role": "user",
-        "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}"
-    },
-    {
-        "content": None,
-        "role": "assistant",
-        "tool_calls": [
-            {
-                "id": "1",
-                "function": {
-                    "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            },
-            {
-                "id": "2",
-                "function": {
-                    "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            },
-            {
-                "id": "3",
-                "function": {
-                    "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            }
-        ]
-    },
-    {
-        "tool_call_id": "1",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
-    },
-    {
-        "tool_call_id": "2",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
-    },
-    {
-        "tool_call_id": "3",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
-    }
-]
+# super_fake_messages = [
+#     {
+#         "role": "user",
+#         "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}"
+#     },
+#     {
+#         "content": None,
+#         "role": "assistant",
+#         "tool_calls": [
+#             {
+#                 "id": "1",
+#                 "function": {
+#                     "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             },
+#             {
+#                 "id": "2",
+#                 "function": {
+#                     "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             },
+#             {
+#                 "id": "3",
+#                 "function": {
+#                     "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             }
+#         ]
+#     },
+#     {
+#         "tool_call_id": "1",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
+#     },
+#     {
+#         "tool_call_id": "2",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
+#     },
+#     {
+#         "tool_call_id": "3",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
+#     }
+# ]
-async def chat_completions():
-    super_fake_response = await client.chat.completions.create(
-        model="gpt-3.5-turbo",
-        messages=super_fake_messages,
-        seed=1337,
-        stream=False
-    )  # get a new response from the model where it can see the function response
-    await asyncio.sleep(1)
-    return super_fake_response
+# async def chat_completions():
+#     super_fake_response = await client.chat.completions.create(
+#         model="gpt-3.5-turbo",
+#         messages=super_fake_messages,
+#         seed=1337,
+#         stream=False
+#     )  # get a new response from the model where it can see the function response
+#     await asyncio.sleep(1)
+#     return super_fake_response
-async def loadtest_fn(n = 2000):
-    global num_task_cancelled_errors, exception_counts, chat_completions
-    start = time.time()
-    tasks = [chat_completions() for _ in range(n)]
-    chat_completions = await asyncio.gather(*tasks)
-    successful_completions = [c for c in chat_completions if c is not None]
-    print(n, time.time() - start, len(successful_completions))
+# async def loadtest_fn(n = 1):
+#     global num_task_cancelled_errors, exception_counts, chat_completions
+#     start = time.time()
+#     tasks = [chat_completions() for _ in range(n)]
+#     chat_completions = await asyncio.gather(*tasks)
+#     successful_completions = [c for c in chat_completions if c is not None]
+#     print(n, time.time() - start, len(successful_completions))
 
-# print(json.dumps(super_fake_response.model_dump(), indent=4))
+# # print(json.dumps(super_fake_response.model_dump(), indent=4))
 
-asyncio.run(loadtest_fn())
+# asyncio.run(loadtest_fn())

litellm/utils.py

@@ -1127,6 +1127,7 @@ class Logging:
                     f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging {traceback.format_exc()}"
                 )
                 pass
+
     async def async_failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
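The docstring above is the hook this commit's ProxyLogging relies on: a synchronous failure callback cannot await, so integrations that need async work get their own handler that runs on the live event loop. An illustrative sketch of that split (the names here are hypothetical, not litellm APIs):

import asyncio

async def async_sentry_callback(exception: Exception):
    await asyncio.sleep(0)  # stand-in for an async HTTP call to a logging backend
    print(f"logged asynchronously: {exception}")

def sync_callback(exception: Exception):
    print(f"logged synchronously: {exception}")

async def async_failure_handler(exception: Exception):
    # sync callbacks run inline; async ones are awaited on the running loop,
    # avoiding the event-loop issues the docstring above mentions
    sync_callback(exception)
    await async_sentry_callback(exception)

asyncio.run(async_failure_handler(ValueError("bad request")))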