feat(proxy_server.py): add sentry logging for db read/writes
parent 4e6a8d09d0
commit 7aec95ed7c
4 changed files with 208 additions and 118 deletions
litellm/proxy/proxy_server.py

@@ -94,7 +94,8 @@ import litellm
 from litellm.proxy.utils import (
     PrismaClient,
     get_instance_fn,
-    CallHooks
+    CallHooks,
+    ProxyLogging
 )
 import pydantic
 from litellm.proxy._types import *
@@ -198,6 +199,7 @@ use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
 call_hooks = CallHooks()
+proxy_logging_obj: Optional[ProxyLogging] = None
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -300,10 +302,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         )
 
 def prisma_setup(database_url: Optional[str]):
-    global prisma_client
-    if database_url:
+    global prisma_client, proxy_logging_obj
+    if database_url is not None and proxy_logging_obj is not None:
         try:
-            prisma_client = PrismaClient(database_url=database_url)
+            prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
         except Exception as e:
             print("Error when initializing prisma, Ensure you run pip install prisma", e)
 
@@ -839,11 +841,13 @@ async def rate_limit_per_token(request: Request, call_next):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks
+    global prisma_client, master_key, use_background_health_checks, proxy_logging_obj
     import json
+    ### INITIALIZE GLOBAL LOGGING OBJECT ###
+    proxy_logging_obj = ProxyLogging()
+
+    ### LOAD CONFIG ###
     worker_config = litellm.get_secret("WORKER_CONFIG")
-    print(f"worker_config: {worker_config}")
     print_verbose(f"worker_config: {worker_config}")
     # check if it's a valid file path
     if os.path.isfile(worker_config):
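Taken together, these proxy_server.py hunks create a single ProxyLogging instance during startup and thread it into the Prisma client. A minimal sketch of that ordering, using simplified stand-ins for the proxy's module-level globals (the database URL is illustrative, not from the commit):

    from typing import Optional

    from litellm.proxy.utils import PrismaClient, ProxyLogging

    prisma_client: Optional[PrismaClient] = None
    proxy_logging_obj: Optional[ProxyLogging] = None

    def prisma_setup(database_url: Optional[str]) -> None:
        global prisma_client
        # silently no-ops unless startup_event() has already created proxy_logging_obj
        if database_url is not None and proxy_logging_obj is not None:
            prisma_client = PrismaClient(
                database_url=database_url, proxy_logging_obj=proxy_logging_obj
            )

    async def startup_event() -> None:
        global proxy_logging_obj
        proxy_logging_obj = ProxyLogging()  # must run before prisma_setup()
        prisma_setup(database_url="postgresql://user:pass@localhost:5432/litellm")

Note that the new `is not None` guard makes DB setup order-dependent: if prisma_setup ran before the logging object existed, the Prisma client would silently never be created.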
litellm/proxy/utils.py

@@ -1,9 +1,72 @@
 from typing import Optional, List, Any, Literal
-import os, subprocess, hashlib, importlib
+import os, subprocess, hashlib, importlib, asyncio
+import litellm
+
+### LOGGING ###
+class ProxyLogging:
+    """
+    Logging for proxy.
+
+    Implemented mainly to log successful/failed db read/writes.
+
+    Currently just logs this to a provided sentry integration.
+    """
+
+    def __init__(self,):
+        ## INITIALIZE LITELLM CALLBACKS ##
+        self._init_litellm_callbacks()
+        pass
+
+    def _init_litellm_callbacks(self):
+        if len(litellm.callbacks) > 0:
+            for callback in litellm.callbacks:
+                if callback not in litellm.input_callback:
+                    litellm.input_callback.append(callback)
+                if callback not in litellm.success_callback:
+                    litellm.success_callback.append(callback)
+                if callback not in litellm.failure_callback:
+                    litellm.failure_callback.append(callback)
+                if callback not in litellm._async_success_callback:
+                    litellm._async_success_callback.append(callback)
+                if callback not in litellm._async_failure_callback:
+                    litellm._async_failure_callback.append(callback)
+        if (
+            len(litellm.input_callback) > 0
+            or len(litellm.success_callback) > 0
+            or len(litellm.failure_callback) > 0
+        ) and len(callback_list) == 0:
+            callback_list = list(
+                set(
+                    litellm.input_callback
+                    + litellm.success_callback
+                    + litellm.failure_callback
+                )
+            )
+            litellm.utils.set_callbacks(
+                callback_list=callback_list
+            )
+
+    async def success_handler(self, *args, **kwargs):
+        """
+        Log successful db read/writes
+        """
+        pass
+
+    async def failure_handler(self, original_exception):
+        """
+        Log failed db read/writes
+
+        Currently only logs exceptions to sentry
+        """
+        print(f"reaches failure handler logging - {original_exception}; sentry: {litellm.utils.capture_exception}")
+        if litellm.utils.capture_exception:
+            litellm.utils.capture_exception(error=original_exception)
+
+
 ### DB CONNECTOR ###
 class PrismaClient:
-    def __init__(self, database_url: str):
+    def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
         print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
         os.environ["DATABASE_URL"] = database_url
         # Save the current working directory
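One caveat in `_init_litellm_callbacks` as committed: the `len(callback_list) == 0` guard reads `callback_list` before the name is ever bound, so the method raises `UnboundLocalError` as soon as any input/success/failure callbacks are registered (it only survives by default because the `and` short-circuits while all three lists are empty). A hedged sketch of the presumably intended logic, binding the list first; the inner loop is condensed but equivalent:

    import litellm

    def _init_litellm_callbacks(self) -> None:
        callback_list: list = []  # bind before the guard below reads it
        for callback in litellm.callbacks:
            for cb_list in (
                litellm.input_callback,
                litellm.success_callback,
                litellm.failure_callback,
                litellm._async_success_callback,
                litellm._async_failure_callback,
            ):
                if callback not in cb_list:
                    cb_list.append(callback)
        if (
            len(litellm.input_callback) > 0
            or len(litellm.success_callback) > 0
            or len(litellm.failure_callback) > 0
        ) and len(callback_list) == 0:
            callback_list = list(
                set(
                    litellm.input_callback
                    + litellm.success_callback
                    + litellm.failure_callback
                )
            )
            litellm.utils.set_callbacks(callback_list=callback_list)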
@@ -22,6 +85,9 @@ class PrismaClient:
         from prisma import Client # type: ignore
         self.db = Client()  #Client to connect to Prisma db
 
+        ## init logging object
+        self.proxy_logging_obj = proxy_logging_obj
+
     def hash_token(self, token: str):
         # Hash the string using SHA-256
         hashed_token = hashlib.sha256(token.encode()).hexdigest()
@@ -29,42 +95,48 @@ class PrismaClient:
         return hashed_token
 
     async def get_data(self, token: str, expires: Optional[Any]=None):
-        hashed_token = self.hash_token(token=token)
-        if expires:
-            response = await self.db.litellm_verificationtoken.find_first(
+        try:
+            hashed_token = self.hash_token(token=token)
+            if expires:
+                response = await self.db.litellm_verificationtoken.find_first(
                     where={
                         "token": hashed_token,
                         "expires": {"gte": expires}  # Check if the token is not expired
                     }
                 )
-        else:
-            response = await self.db.litellm_verificationtoken.find_unique(
-                where={
-                    "token": hashed_token
-                }
-            )
-        return response
+            else:
+                response = await self.db.litellm_verificationtoken.find_unique(
+                    where={
+                        "token": hashed_token
+                    }
+                )
+            return response
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
     async def insert_data(self, data: dict):
         """
         Add a key to the database. If it already exists, do nothing.
         """
-        token = data["token"]
-        hashed_token = self.hash_token(token=token)
-        data["token"] = hashed_token
-        print(f"passed in data: {data}; hashed_token: {hashed_token}")
+        try:
+            token = data["token"]
+            hashed_token = self.hash_token(token=token)
+            data["token"] = hashed_token
+            print(f"passed in data: {data}; hashed_token: {hashed_token}")
 
-        new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
-            where={
-                'token': hashed_token,
-            },
-            data={
-                "create": {**data}, #type: ignore
-                "update": {} # don't do anything if it already exists
-            }
-        )
+            new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
+                where={
+                    'token': hashed_token,
+                },
+                data={
+                    "create": {**data}, #type: ignore
+                    "update": {} # don't do anything if it already exists
+                }
+            )
 
-        return new_verification_token
+            return new_verification_token
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
     async def update_data(self, token: str, data: dict):
         """
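Both paths above depend on reads and writes hashing the raw key identically: insert_data stores the SHA-256 digest of the key, and get_data re-hashes the presented key to look it up. A tiny self-contained sketch of that round trip, where a plain dict stands in for the litellm_verificationtoken table:

    import hashlib

    def hash_token(token: str) -> str:
        # same SHA-256 digest PrismaClient uses on both reads and writes
        return hashlib.sha256(token.encode()).hexdigest()

    fake_table: dict = {}  # stand-in for the litellm_verificationtoken table
    raw_key = "sk-1234"

    # write path: store under the hash, never the raw key
    fake_table[hash_token(raw_key)] = {"token": hash_token(raw_key), "spend": 0.0}

    # read path: re-hash the presented key to find the row
    assert hash_token(raw_key) in fake_table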
@@ -82,6 +154,7 @@ class PrismaClient:
             print("\033[91m" + f"DB write succeeded" + "\033[0m")
             return {"token": token, "data": data}
         except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
             print()
             print()
             print()
|
@ -90,21 +163,31 @@ class PrismaClient:
|
||||||
print()
|
print()
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
async def delete_data(self, tokens: List):
|
async def delete_data(self, tokens: List):
|
||||||
"""
|
"""
|
||||||
Allow user to delete a key(s)
|
Allow user to delete a key(s)
|
||||||
"""
|
"""
|
||||||
hashed_tokens = [self.hash_token(token=token) for token in tokens]
|
try:
|
||||||
await self.db.litellm_verificationtoken.delete_many(
|
hashed_tokens = [self.hash_token(token=token) for token in tokens]
|
||||||
where={"token": {"in": hashed_tokens}}
|
await self.db.litellm_verificationtoken.delete_many(
|
||||||
)
|
where={"token": {"in": hashed_tokens}}
|
||||||
return {"deleted_keys": tokens}
|
)
|
||||||
|
return {"deleted_keys": tokens}
|
||||||
|
except Exception as e:
|
||||||
|
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
await self.db.connect()
|
try:
|
||||||
|
await self.db.connect()
|
||||||
|
except Exception as e:
|
||||||
|
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
await self.db.disconnect()
|
try:
|
||||||
|
await self.db.disconnect()
|
||||||
|
except Exception as e:
|
||||||
|
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||||
|
|
||||||
### CUSTOM FILE ###
|
### CUSTOM FILE ###
|
||||||
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
||||||
|
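With these changes, get_data, insert_data, delete_data, connect, and disconnect all swallow their exceptions and hand them to `failure_handler` via `asyncio.create_task`, so a failed call resolves to `None` instead of raising. A minimal sketch of that fire-and-forget pattern and its caveats, assuming a running event loop; the names here are simplified stand-ins:

    import asyncio

    async def failure_handler(original_exception: Exception) -> None:
        # stand-in for ProxyLogging.failure_handler (forwards to sentry)
        print(f"db error logged: {original_exception}")

    async def get_data(token: str):
        try:
            raise ConnectionError("db unreachable")  # simulate a failed read
        except Exception as e:
            # schedule logging without awaiting it; the exception is swallowed
            asyncio.create_task(failure_handler(original_exception=e))
        # falls through, so the caller receives None

    async def main() -> None:
        result = await get_data("sk-1234")
        assert result is None       # failures surface as None, not exceptions
        await asyncio.sleep(0)      # yield so the logging task actually runs

    asyncio.run(main())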
@@ -142,6 +225,8 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e
 
+
+
 ### CALL HOOKS ###
 class CallHooks:
     """
@@ -1,82 +1,82 @@
-import openai, json, time, asyncio
-client = openai.AsyncOpenAI(
-    api_key="sk-1234",
-    base_url="http://0.0.0.0:8000"
-)
-
-super_fake_messages = [
-    {
-        "role": "user",
-        "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}"
-    },
-    {
-        "content": None,
-        "role": "assistant",
-        "tool_calls": [
-            {
-                "id": "1",
-                "function": {
-                    "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            },
-            {
-                "id": "2",
-                "function": {
-                    "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            },
-            {
-                "id": "3",
-                "function": {
-                    "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
-                    "name": "get_current_weather"
-                },
-                "type": "function"
-            }
-        ]
-    },
-    {
-        "tool_call_id": "1",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
-    },
-    {
-        "tool_call_id": "2",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
-    },
-    {
-        "tool_call_id": "3",
-        "role": "tool",
-        "name": "get_current_weather",
-        "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
-    }
-]
-
-async def chat_completions():
-    super_fake_response = await client.chat.completions.create(
-        model="gpt-3.5-turbo",
-        messages=super_fake_messages,
-        seed=1337,
-        stream=False
-    ) # get a new response from the model where it can see the function response
-    await asyncio.sleep(1)
-    return super_fake_response
-
-async def loadtest_fn(n = 2000):
-    global num_task_cancelled_errors, exception_counts, chat_completions
-    start = time.time()
-    tasks = [chat_completions() for _ in range(n)]
-    chat_completions = await asyncio.gather(*tasks)
-    successful_completions = [c for c in chat_completions if c is not None]
-    print(n, time.time() - start, len(successful_completions))
-
-# print(json.dumps(super_fake_response.model_dump(), indent=4))
-
-asyncio.run(loadtest_fn())
+# import openai, json, time, asyncio
+# client = openai.AsyncOpenAI(
+#     api_key="sk-1234",
+#     base_url="http://0.0.0.0:8000"
+# )
+
+# super_fake_messages = [
+#     {
+#         "role": "user",
+#         "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}"
+#     },
+#     {
+#         "content": None,
+#         "role": "assistant",
+#         "tool_calls": [
+#             {
+#                 "id": "1",
+#                 "function": {
+#                     "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             },
+#             {
+#                 "id": "2",
+#                 "function": {
+#                     "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             },
+#             {
+#                 "id": "3",
+#                 "function": {
+#                     "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             }
+#         ]
+#     },
+#     {
+#         "tool_call_id": "1",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
+#     },
+#     {
+#         "tool_call_id": "2",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
+#     },
+#     {
+#         "tool_call_id": "3",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
+#     }
+# ]
+
+# async def chat_completions():
+#     super_fake_response = await client.chat.completions.create(
+#         model="gpt-3.5-turbo",
+#         messages=super_fake_messages,
+#         seed=1337,
+#         stream=False
+#     ) # get a new response from the model where it can see the function response
+#     await asyncio.sleep(1)
+#     return super_fake_response
+
+# async def loadtest_fn(n = 1):
+#     global num_task_cancelled_errors, exception_counts, chat_completions
+#     start = time.time()
+#     tasks = [chat_completions() for _ in range(n)]
+#     chat_completions = await asyncio.gather(*tasks)
+#     successful_completions = [c for c in chat_completions if c is not None]
+#     print(n, time.time() - start, len(successful_completions))
+
+# # print(json.dumps(super_fake_response.model_dump(), indent=4))
+
+# asyncio.run(loadtest_fn())
litellm/utils.py

@@ -1127,6 +1127,7 @@ class Logging:
                     f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging {traceback.format_exc()}"
                 )
                 pass
 
     async def async_failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.