From 7aec95ed7c36d03cfd2334e9516df27857bf0dc6 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 8 Dec 2023 11:40:19 -0800
Subject: [PATCH] feat(proxy_server.py): add sentry logging for db read/writes

---
 litellm/proxy/proxy_server.py            |  16 ++-
 litellm/proxy/utils.py                   | 155 ++++++++++++++++++-----
 litellm/tests/test_proxy_server_spend.py | 154 +++++++++++-----------
 litellm/utils.py                         |   1 +
 4 files changed, 208 insertions(+), 118 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index c6046d0595..66d23331ae 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -94,7 +94,8 @@ import litellm
 from litellm.proxy.utils import (
     PrismaClient,
     get_instance_fn,
-    CallHooks
+    CallHooks,
+    ProxyLogging
 )
 import pydantic
 from litellm.proxy._types import *
@@ -198,6 +199,7 @@ use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
 call_hooks = CallHooks()
+proxy_logging_obj: Optional[ProxyLogging] = None
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -300,10 +302,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         )
 
 def prisma_setup(database_url: Optional[str]):
-    global prisma_client
-    if database_url:
+    global prisma_client, proxy_logging_obj
+    if database_url is not None and proxy_logging_obj is not None:
         try:
-            prisma_client = PrismaClient(database_url=database_url)
+            prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
         except Exception as e:
             print("Error when initializing prisma, Ensure you run pip install prisma", e)
 
@@ -839,11 +841,13 @@ async def rate_limit_per_token(request: Request, call_next):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks
+    global prisma_client, master_key, use_background_health_checks, proxy_logging_obj
     import json
+    ### INITIALIZE GLOBAL LOGGING OBJECT ###
+    proxy_logging_obj = ProxyLogging()
+
     ### LOAD CONFIG ###
     worker_config = litellm.get_secret("WORKER_CONFIG")
-    print(f"worker_config: {worker_config}")
     print_verbose(f"worker_config: {worker_config}")
     # check if it's a valid file path
     if os.path.isfile(worker_config):
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 5cd2292148..0785215f68 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1,9 +1,72 @@
 from typing import Optional, List, Any, Literal
-import os, subprocess, hashlib, importlib
+import os, subprocess, hashlib, importlib, asyncio
+import litellm
+
+### LOGGING ###
+class ProxyLogging:
+    """
+    Logging for proxy.
+
+    Implemented mainly to log successful/failed db read/writes.
+
+    Currently just logs this to a provided sentry integration.
+    """
+
+    def __init__(self,):
+        ## INITIALIZE LITELLM CALLBACKS ##
+        self._init_litellm_callbacks()
+        pass
+
+
+    def _init_litellm_callbacks(self):
+        if len(litellm.callbacks) > 0:
+            for callback in litellm.callbacks:
+                if callback not in litellm.input_callback:
+                    litellm.input_callback.append(callback)
+                if callback not in litellm.success_callback:
+                    litellm.success_callback.append(callback)
+                if callback not in litellm.failure_callback:
+                    litellm.failure_callback.append(callback)
+                if callback not in litellm._async_success_callback:
+                    litellm._async_success_callback.append(callback)
+                if callback not in litellm._async_failure_callback:
+                    litellm._async_failure_callback.append(callback)
+            if (
+                len(litellm.input_callback) > 0
+                or len(litellm.success_callback) > 0
+                or len(litellm.failure_callback) > 0
+            ):
+                callback_list = list(
+                    set(
+                        litellm.input_callback
+                        + litellm.success_callback
+                        + litellm.failure_callback
+                    )
+                )
+                litellm.utils.set_callbacks(
+                    callback_list=callback_list
+                )
+
+    async def success_handler(self, *args, **kwargs):
+        """
+        Log successful db read/writes
+        """
+        pass
+
+    async def failure_handler(self, original_exception):
+        """
+        Log failed db read/writes
+
+        Currently only logs exceptions to sentry
+        """
+        print(f"reaches failure handler logging - {original_exception}; sentry: {litellm.utils.capture_exception}")
+        if litellm.utils.capture_exception:
+            litellm.utils.capture_exception(error=original_exception)
+
 
 ### DB CONNECTOR ###
 class PrismaClient:
-    def __init__(self, database_url: str):
+    def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
         print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
         os.environ["DATABASE_URL"] = database_url
         # Save the current working directory
@@ -22,6 +85,9 @@ class PrismaClient:
         from prisma import Client # type: ignore
         self.db = Client() #Client to connect to Prisma db
 
+        ## init logging object
+        self.proxy_logging_obj = proxy_logging_obj
+
     def hash_token(self, token: str):
         # Hash the string using SHA-256
         hashed_token = hashlib.sha256(token.encode()).hexdigest()
@@ -29,42 +95,48 @@ class PrismaClient:
         return hashed_token
 
     async def get_data(self, token: str, expires: Optional[Any]=None):
-        hashed_token = self.hash_token(token=token)
-        if expires:
-            response = await self.db.litellm_verificationtoken.find_first(
+        try:
+            hashed_token = self.hash_token(token=token)
+            if expires:
+                response = await self.db.litellm_verificationtoken.find_first(
+                    where={
+                        "token": hashed_token,
+                        "expires": {"gte": expires} # Check if the token is not expired
+                    }
+                )
+            else:
+                response = await self.db.litellm_verificationtoken.find_unique(
                 where={
-                    "token": hashed_token,
-                    "expires": {"gte": expires} # Check if the token is not expired
+                    "token": hashed_token
                 }
             )
-        else:
-            response = await self.db.litellm_verificationtoken.find_unique(
-                where={
-                    "token": hashed_token
-                }
-            )
-        return response
+            return response
+        except Exception as e:
+            asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
     async def insert_data(self, data: dict):
         """
         Add a key to the database. If it already exists, do nothing.
""" - token = data["token"] - hashed_token = self.hash_token(token=token) - data["token"] = hashed_token - print(f"passed in data: {data}; hashed_token: {hashed_token}") + try: + token = data["token"] + hashed_token = self.hash_token(token=token) + data["token"] = hashed_token + print(f"passed in data: {data}; hashed_token: {hashed_token}") - new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore - where={ - 'token': hashed_token, - }, - data={ - "create": {**data}, #type: ignore - "update": {} # don't do anything if it already exists - } - ) + new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore + where={ + 'token': hashed_token, + }, + data={ + "create": {**data}, #type: ignore + "update": {} # don't do anything if it already exists + } + ) - return new_verification_token + return new_verification_token + except Exception as e: + asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) async def update_data(self, token: str, data: dict): """ @@ -82,6 +154,7 @@ class PrismaClient: print("\033[91m" + f"DB write succeeded" + "\033[0m") return {"token": token, "data": data} except Exception as e: + asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) print() print() print() @@ -90,21 +163,31 @@ class PrismaClient: print() print() + async def delete_data(self, tokens: List): """ Allow user to delete a key(s) """ - hashed_tokens = [self.hash_token(token=token) for token in tokens] - await self.db.litellm_verificationtoken.delete_many( - where={"token": {"in": hashed_tokens}} - ) - return {"deleted_keys": tokens} + try: + hashed_tokens = [self.hash_token(token=token) for token in tokens] + await self.db.litellm_verificationtoken.delete_many( + where={"token": {"in": hashed_tokens}} + ) + return {"deleted_keys": tokens} + except Exception as e: + asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) async def connect(self): - await self.db.connect() + try: + await self.db.connect() + except Exception as e: + asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) async def disconnect(self): - await self.db.disconnect() + try: + await self.db.disconnect() + except Exception as e: + asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) ### CUSTOM FILE ### def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: @@ -142,6 +225,8 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: except Exception as e: raise e + + ### CALL HOOKS ### class CallHooks: """ diff --git a/litellm/tests/test_proxy_server_spend.py b/litellm/tests/test_proxy_server_spend.py index 0ad897c95e..f64ad89877 100644 --- a/litellm/tests/test_proxy_server_spend.py +++ b/litellm/tests/test_proxy_server_spend.py @@ -1,82 +1,82 @@ -import openai, json, time, asyncio -client = openai.AsyncOpenAI( - api_key="sk-1234", - base_url="http://0.0.0.0:8000" -) +# import openai, json, time, asyncio +# client = openai.AsyncOpenAI( +# api_key="sk-1234", +# base_url="http://0.0.0.0:8000" +# ) -super_fake_messages = [ - { - "role": "user", - "content": f"What's the weather like in San Francisco, Tokyo, and Paris? 
{time.time()}" - }, - { - "content": None, - "role": "assistant", - "tool_calls": [ - { - "id": "1", - "function": { - "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}", - "name": "get_current_weather" - }, - "type": "function" - }, - { - "id": "2", - "function": { - "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}", - "name": "get_current_weather" - }, - "type": "function" - }, - { - "id": "3", - "function": { - "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}", - "name": "get_current_weather" - }, - "type": "function" - } - ] - }, - { - "tool_call_id": "1", - "role": "tool", - "name": "get_current_weather", - "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}" - }, - { - "tool_call_id": "2", - "role": "tool", - "name": "get_current_weather", - "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}" - }, - { - "tool_call_id": "3", - "role": "tool", - "name": "get_current_weather", - "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}" - } -] +# super_fake_messages = [ +# { +# "role": "user", +# "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}" +# }, +# { +# "content": None, +# "role": "assistant", +# "tool_calls": [ +# { +# "id": "1", +# "function": { +# "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}", +# "name": "get_current_weather" +# }, +# "type": "function" +# }, +# { +# "id": "2", +# "function": { +# "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}", +# "name": "get_current_weather" +# }, +# "type": "function" +# }, +# { +# "id": "3", +# "function": { +# "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}", +# "name": "get_current_weather" +# }, +# "type": "function" +# } +# ] +# }, +# { +# "tool_call_id": "1", +# "role": "tool", +# "name": "get_current_weather", +# "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}" +# }, +# { +# "tool_call_id": "2", +# "role": "tool", +# "name": "get_current_weather", +# "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}" +# }, +# { +# "tool_call_id": "3", +# "role": "tool", +# "name": "get_current_weather", +# "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}" +# } +# ] -async def chat_completions(): - super_fake_response = await client.chat.completions.create( - model="gpt-3.5-turbo", - messages=super_fake_messages, - seed=1337, - stream=False - ) # get a new response from the model where it can see the function response - await asyncio.sleep(1) - return super_fake_response +# async def chat_completions(): +# super_fake_response = await client.chat.completions.create( +# model="gpt-3.5-turbo", +# messages=super_fake_messages, +# seed=1337, +# stream=False +# ) # get a new response from the model where it can see the function response +# await asyncio.sleep(1) +# return super_fake_response -async def loadtest_fn(n = 2000): - global num_task_cancelled_errors, exception_counts, chat_completions - start = time.time() - tasks = [chat_completions() for _ in range(n)] - chat_completions = await asyncio.gather(*tasks) - successful_completions = [c for c in chat_completions if c is not None] - print(n, time.time() - start, len(successful_completions)) +# async def loadtest_fn(n = 1): +# global num_task_cancelled_errors, exception_counts, chat_completions +# start = time.time() +# tasks = 
[chat_completions() for _ in range(n)] +# chat_completions = await asyncio.gather(*tasks) +# successful_completions = [c for c in chat_completions if c is not None] +# print(n, time.time() - start, len(successful_completions)) -# print(json.dumps(super_fake_response.model_dump(), indent=4)) +# # print(json.dumps(super_fake_response.model_dump(), indent=4)) -asyncio.run(loadtest_fn()) \ No newline at end of file +# asyncio.run(loadtest_fn()) \ No newline at end of file diff --git a/litellm/utils.py b/litellm/utils.py index cc11e0fbc3..8377051952 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1127,6 +1127,7 @@ class Logging: f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging {traceback.format_exc()}" ) pass + async def async_failure_handler(self, exception, traceback_exception, start_time=None, end_time=None): """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
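
---

Note (reviewer sketch): every PrismaClient method in this patch applies the
same pattern -- wrap the db call in try/except and hand any exception to
ProxyLogging.failure_handler via asyncio.create_task, so the log call runs
fire-and-forget on the event loop and never blocks or re-raises into the
request path. The snippet below is a minimal, self-contained illustration of
that pattern, not litellm code; DummyProxyLogging and flaky_db_read are
made-up stand-ins for ProxyLogging and the prisma queries.

import asyncio
from typing import Optional

class DummyProxyLogging:
    # Stand-in for ProxyLogging; the real failure_handler forwards the
    # exception to sentry via litellm.utils.capture_exception.
    async def failure_handler(self, original_exception: Exception):
        print(f"logged db failure: {original_exception}")

async def flaky_db_read(token: str) -> dict:
    # Stand-in for a prisma query that can fail.
    raise ConnectionError("db unreachable")

async def get_data(token: str, logger: DummyProxyLogging) -> Optional[dict]:
    try:
        return await flaky_db_read(token)
    except Exception as e:
        # Fire-and-forget: schedule the async handler instead of awaiting it,
        # so the caller is never blocked by logging. The exception is
        # swallowed and the caller simply sees None.
        asyncio.create_task(logger.failure_handler(original_exception=e))
        return None

async def main():
    result = await get_data("sk-1234", DummyProxyLogging())
    await asyncio.sleep(0)  # yield once so the logging task gets to run
    print(f"get_data returned: {result}")

asyncio.run(main())

One consequence of this design worth noting: on failure the db helpers return
None instead of raising, so callers of get_data/insert_data/delete_data must
treat a missing response as a possible db error rather than a missing key.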