mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00

feat(proxy_server.py): add sentry logging for db read/writes

parent 4e6a8d09d0
commit 7aec95ed7c

4 changed files with 208 additions and 118 deletions
@@ -94,7 +94,8 @@ import litellm
 from litellm.proxy.utils import (
     PrismaClient,
     get_instance_fn,
-    CallHooks
+    CallHooks,
+    ProxyLogging
 )
 import pydantic
 from litellm.proxy._types import *
@@ -198,6 +199,7 @@ use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
 call_hooks = CallHooks()
+proxy_logging_obj: Optional[ProxyLogging] = None
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -300,10 +302,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         )
 
 def prisma_setup(database_url: Optional[str]):
-    global prisma_client
-    if database_url:
+    global prisma_client, proxy_logging_obj
+    if database_url is not None and proxy_logging_obj is not None:
         try:
-            prisma_client = PrismaClient(database_url=database_url)
+            prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
         except Exception as e:
             print("Error when initializing prisma, Ensure you run pip install prisma", e)
 
@@ -839,11 +841,13 @@ async def rate_limit_per_token(request: Request, call_next):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks
+    global prisma_client, master_key, use_background_health_checks, proxy_logging_obj
     import json
+    ### INITIALIZE GLOBAL LOGGING OBJECT ###
+    proxy_logging_obj = ProxyLogging()
 
     ### LOAD CONFIG ###
     worker_config = litellm.get_secret("WORKER_CONFIG")
-    print(f"worker_config: {worker_config}")
+    print_verbose(f"worker_config: {worker_config}")
     # check if it's a valid file path
     if os.path.isfile(worker_config):
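Note on ordering: prisma_setup() is a no-op unless proxy_logging_obj has already been constructed, so startup_event() has to create the ProxyLogging instance before any database setup runs. A minimal sketch of the resulting startup flow (names taken from this diff; config loading elided):

    import os
    from typing import Optional

    proxy_logging_obj: Optional[ProxyLogging] = None
    prisma_client: Optional[PrismaClient] = None

    async def startup_event():
        global prisma_client, proxy_logging_obj
        proxy_logging_obj = ProxyLogging()            # must come first
        prisma_setup(os.environ.get("DATABASE_URL"))  # sets prisma_client when a DB is configured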
@@ -1,9 +1,72 @@
 from typing import Optional, List, Any, Literal
-import os, subprocess, hashlib, importlib
+import os, subprocess, hashlib, importlib, asyncio
 import litellm
+
+### LOGGING ###
+class ProxyLogging:
+    """
+    Logging for proxy.
+
+    Implemented mainly to log successful/failed db read/writes.
+
+    Currently just logs this to a provided sentry integration.
+    """
+
+    def __init__(self):
+        ## INITIALIZE LITELLM CALLBACKS ##
+        self._init_litellm_callbacks()
+
+    def _init_litellm_callbacks(self):
+        # register any generic litellm.callbacks with each specific callback list
+        if len(litellm.callbacks) > 0:
+            for callback in litellm.callbacks:
+                if callback not in litellm.input_callback:
+                    litellm.input_callback.append(callback)
+                if callback not in litellm.success_callback:
+                    litellm.success_callback.append(callback)
+                if callback not in litellm.failure_callback:
+                    litellm.failure_callback.append(callback)
+                if callback not in litellm._async_success_callback:
+                    litellm._async_success_callback.append(callback)
+                if callback not in litellm._async_failure_callback:
+                    litellm._async_failure_callback.append(callback)
+        if (
+            len(litellm.input_callback) > 0
+            or len(litellm.success_callback) > 0
+            or len(litellm.failure_callback) > 0
+        ):
+            callback_list = list(
+                set(
+                    litellm.input_callback
+                    + litellm.success_callback
+                    + litellm.failure_callback
+                )
+            )
+            litellm.utils.set_callbacks(callback_list=callback_list)
+
+    async def success_handler(self, *args, **kwargs):
+        """
+        Log successful db read/writes
+        """
+        pass
+
+    async def failure_handler(self, original_exception):
+        """
+        Log failed db read/writes
+
+        Currently only logs exceptions to sentry
+        """
+        print(f"reaches failure handler logging - {original_exception}; sentry: {litellm.utils.capture_exception}")
+        if litellm.utils.capture_exception:
+            litellm.utils.capture_exception(error=original_exception)
+
+
+### DB CONNECTOR ###
 class PrismaClient:
-    def __init__(self, database_url: str):
+    def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
         print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
         os.environ["DATABASE_URL"] = database_url
         # Save the current working directory
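For context on how a failure actually reaches Sentry: failure_handler checks litellm.utils.capture_exception, which litellm's set_callbacks assigns when "sentry" is among the configured callbacks. A minimal sketch of the fire-and-forget error path (demo_failure_path is illustrative; it assumes the ProxyLogging class above is importable):

    import asyncio

    async def demo_failure_path():
        logging_obj = ProxyLogging()
        try:
            raise RuntimeError("simulated db write failure")
        except Exception as e:
            # schedule the sentry log without blocking the caller
            task = asyncio.create_task(
                logging_obj.failure_handler(original_exception=e)
            )
            await task  # awaited here only so the demo runs to completion

    asyncio.run(demo_failure_path())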
@@ -22,6 +85,9 @@ class PrismaClient:
         from prisma import Client  # type: ignore
         self.db = Client()  # Client to connect to Prisma db
 
+        ## init logging object
+        self.proxy_logging_obj = proxy_logging_obj
+
     def hash_token(self, token: str):
         # Hash the string using SHA-256
         hashed_token = hashlib.sha256(token.encode()).hexdigest()
@@ -29,6 +95,7 @@ class PrismaClient:
         return hashed_token
 
     async def get_data(self, token: str, expires: Optional[Any]=None):
-        hashed_token = self.hash_token(token=token)
-        if expires:
-            response = await self.db.litellm_verificationtoken.find_first(
+        try:
+            hashed_token = self.hash_token(token=token)
+            if expires:
+                response = await self.db.litellm_verificationtoken.find_first(
@@ -44,11 +111,14 @@ class PrismaClient:
-                }
-            )
-        return response
+                    }
+                )
+            return response
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
     async def insert_data(self, data: dict):
         """
         Add a key to the database. If it already exists, do nothing.
         """
-        token = data["token"]
-        hashed_token = self.hash_token(token=token)
-        data["token"] = hashed_token
+        try:
+            token = data["token"]
+            hashed_token = self.hash_token(token=token)
+            data["token"] = hashed_token
@@ -65,6 +135,8 @@ class PrismaClient:
-        )
+            )
 
-        return new_verification_token
+            return new_verification_token
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
     async def update_data(self, token: str, data: dict):
         """
@@ -82,6 +154,7 @@ class PrismaClient:
             print("\033[91m" + f"DB write succeeded" + "\033[0m")
             return {"token": token, "data": data}
         except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
             print()
             print()
             print()
@@ -90,21 +163,31 @@ class PrismaClient:
             print()
             print()
 
+
     async def delete_data(self, tokens: List):
         """
         Allow user to delete a key(s)
         """
-        hashed_tokens = [self.hash_token(token=token) for token in tokens]
-        await self.db.litellm_verificationtoken.delete_many(
-            where={"token": {"in": hashed_tokens}}
-        )
-        return {"deleted_keys": tokens}
+        try:
+            hashed_tokens = [self.hash_token(token=token) for token in tokens]
+            await self.db.litellm_verificationtoken.delete_many(
+                where={"token": {"in": hashed_tokens}}
+            )
+            return {"deleted_keys": tokens}
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
     async def connect(self):
-        await self.db.connect()
+        try:
+            await self.db.connect()
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
     async def disconnect(self):
-        await self.db.disconnect()
+        try:
+            await self.db.disconnect()
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
 ### CUSTOM FILE ###
 def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
@@ -142,6 +225,8 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e
 
+
+
 ### CALL HOOKS ###
 class CallHooks:
     """
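Design note on the pattern repeated above: each PrismaClient method swallows its exception after scheduling failure_handler, so a failed read/write returns None instead of raising. The proxy keeps serving, but callers must treat None as a failure. asyncio.create_task() also requires a running event loop, which holds here since every method is async. A sketch of a caller handling the None case (lookup_key is a hypothetical helper, not part of this diff):

    async def lookup_key(prisma_client: PrismaClient, token: str):
        valid_token = await prisma_client.get_data(token=token)
        if valid_token is None:
            # either the key does not exist or the DB read failed;
            # any failure was already reported to sentry in the background
            raise Exception("invalid or unverifiable key")
        return valid_token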
@@ -1,82 +1,82 @@
-import openai, json, time, asyncio
-client = openai.AsyncOpenAI(
-    api_key="sk-1234",
-    base_url="http://0.0.0.0:8000"
-)
+# import openai, json, time, asyncio
+# client = openai.AsyncOpenAI(
+#     api_key="sk-1234",
+#     base_url="http://0.0.0.0:8000"
+# )
 
-super_fake_messages = [
-    {
-        "role": "user",
-        "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}"
-    },
-    {
-        "content": None,
-        "role": "assistant",
-        "tool_calls": [
-            {
-                "id": "1",
-                "function": {
-                    "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            },
-            {
-                "id": "2",
-                "function": {
-                    "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            },
-            {
-                "id": "3",
-                "function": {
-                    "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            }
-        ]
-    },
-    {
-        "tool_call_id": "1",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
-    },
-    {
-        "tool_call_id": "2",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
-    },
-    {
-        "tool_call_id": "3",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
-    }
-]
+# super_fake_messages = [
+#     {
+#         "role": "user",
+#         "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}"
+#     },
+#     {
+#         "content": None,
+#         "role": "assistant",
+#         "tool_calls": [
+#             {
+#                 "id": "1",
+#                 "function": {
+#                     "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             },
+#             {
+#                 "id": "2",
+#                 "function": {
+#                     "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             },
+#             {
+#                 "id": "3",
+#                 "function": {
+#                     "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             }
+#         ]
+#     },
+#     {
+#         "tool_call_id": "1",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
+#     },
+#     {
+#         "tool_call_id": "2",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
+#     },
+#     {
+#         "tool_call_id": "3",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
+#     }
+# ]
 
-async def chat_completions():
-    super_fake_response = await client.chat.completions.create(
-        model="gpt-3.5-turbo",
-        messages=super_fake_messages,
-        seed=1337,
-        stream=False
-    ) # get a new response from the model where it can see the function response
-    await asyncio.sleep(1)
-    return super_fake_response
+# async def chat_completions():
+#     super_fake_response = await client.chat.completions.create(
+#         model="gpt-3.5-turbo",
+#         messages=super_fake_messages,
+#         seed=1337,
+#         stream=False
+#     ) # get a new response from the model where it can see the function response
+#     await asyncio.sleep(1)
+#     return super_fake_response
 
-async def loadtest_fn(n = 2000):
-    global num_task_cancelled_errors, exception_counts, chat_completions
-    start = time.time()
-    tasks = [chat_completions() for _ in range(n)]
-    chat_completions = await asyncio.gather(*tasks)
-    successful_completions = [c for c in chat_completions if c is not None]
-    print(n, time.time() - start, len(successful_completions))
+# async def loadtest_fn(n = 1):
+#     global num_task_cancelled_errors, exception_counts, chat_completions
+#     start = time.time()
+#     tasks = [chat_completions() for _ in range(n)]
+#     chat_completions = await asyncio.gather(*tasks)
+#     successful_completions = [c for c in chat_completions if c is not None]
+#     print(n, time.time() - start, len(successful_completions))
 
-# print(json.dumps(super_fake_response.model_dump(), indent=4))
+# # print(json.dumps(super_fake_response.model_dump(), indent=4))
 
-asyncio.run(loadtest_fn())
+# asyncio.run(loadtest_fn())
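One caveat in the (now commented-out) load test: loadtest_fn rebinds the global chat_completions coroutine function to the result of asyncio.gather(), so a second invocation would crash with a TypeError. A corrected sketch, collecting exceptions instead of letting them cancel the batch:

    import asyncio, time

    async def loadtest_fn(n: int = 2000):
        start = time.time()
        tasks = [chat_completions() for _ in range(n)]
        responses = await asyncio.gather(*tasks, return_exceptions=True)
        successful = [r for r in responses if not isinstance(r, Exception)]
        print(n, time.time() - start, len(successful))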
@@ -1127,6 +1127,7 @@ class Logging:
                     f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging {traceback.format_exc()}"
                 )
+                pass
 
     async def async_failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
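async_failure_handler is the dispatcher for callbacks registered in litellm._async_failure_callback, which ProxyLogging._init_litellm_callbacks populates above. A sketch of such a callback; the (kwargs, response_obj, start_time, end_time) signature mirrors litellm's synchronous custom callbacks and is an assumption here:

    async def log_db_failure(kwargs, response_obj, start_time, end_time):
        # runs on the event loop, so async integrations can be awaited safely
        print(f"async failure callback fired: {kwargs.get('exception')}")

    litellm._async_failure_callback.append(log_db_failure)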