fix(proxy/utils.py): add call type and duration to proxy_logging failure calls

this is for tracking failed db requests on prometheus
This commit is contained in:
Krrish Dholakia 2024-04-18 16:24:36 -07:00
parent 48d3b563d8
commit d61250109e

View file

@ -1,5 +1,5 @@
from typing import Optional, List, Any, Literal, Union from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx, time
import litellm, backoff import litellm, backoff
from litellm.proxy._types import ( from litellm.proxy._types import (
UserAPIKeyAuth, UserAPIKeyAuth,
@ -18,7 +18,7 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import ( from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler, _PROXY_MaxParallelRequestsHandler,
) )
from litellm._service_logger import ServiceLogging from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm import ModelResponse, EmbeddingResponse, ImageResponse from litellm import ModelResponse, EmbeddingResponse, ImageResponse
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
@ -458,7 +458,9 @@ class ProxyLogging:
else: else:
raise Exception("Missing SENTRY_DSN from environment") raise Exception("Missing SENTRY_DSN from environment")
async def failure_handler(self, original_exception, traceback_str=""): async def failure_handler(
self, original_exception, duration: float, call_type: str, traceback_str=""
):
""" """
Log failed db read/writes Log failed db read/writes
@ -483,6 +485,11 @@ class ProxyLogging:
) )
) )
if hasattr(self, "service_logging_obj"):
self.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.DB,
)
if litellm.utils.capture_exception: if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception) litellm.utils.capture_exception(error=original_exception)
@ -803,6 +810,7 @@ class PrismaClient:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"PrismaClient: get_generic_data: {key}, table_name: {table_name}" f"PrismaClient: get_generic_data: {key}, table_name: {table_name}"
) )
start_time = time.time()
try: try:
if table_name == "users": if table_name == "users":
response = await self.db.litellm_usertable.find_first( response = await self.db.litellm_usertable.find_first(
@ -827,11 +835,17 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}" error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
traceback_str=error_traceback,
call_type="get_generic_data",
) )
) )
raise e raise e
@backoff.on_exception( @backoff.on_exception(
@ -869,6 +883,7 @@ class PrismaClient:
] = None, # pagination, number of rows to getch when find_all==True ] = None, # pagination, number of rows to getch when find_all==True
): ):
args_passed_in = locals() args_passed_in = locals()
start_time = time.time()
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"PrismaClient: get_data - args_passed_in: {args_passed_in}" f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
) )
@ -1122,9 +1137,15 @@ class PrismaClient:
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
verbose_proxy_logger.debug(error_traceback) verbose_proxy_logger.debug(error_traceback)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="get_data",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1147,6 +1168,7 @@ class PrismaClient:
""" """
Add a key to the database. If it already exists, do nothing. Add a key to the database. If it already exists, do nothing.
""" """
start_time = time.time()
try: try:
verbose_proxy_logger.debug("PrismaClient: insert_data: %s", data) verbose_proxy_logger.debug("PrismaClient: insert_data: %s", data)
if table_name == "key": if table_name == "key":
@ -1264,9 +1286,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception in insert_data: {str(e)}" error_msg = f"LiteLLM Prisma Client Exception in insert_data: {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="insert_data",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1297,6 +1324,7 @@ class PrismaClient:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"PrismaClient: update_data, table_name: {table_name}" f"PrismaClient: update_data, table_name: {table_name}"
) )
start_time = time.time()
try: try:
db_data = self.jsonify_object(data=data) db_data = self.jsonify_object(data=data)
if update_key_values is not None: if update_key_values is not None:
@ -1458,9 +1486,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - update_data: {str(e)}" error_msg = f"LiteLLM Prisma Client Exception - update_data: {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_data",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1485,6 +1518,7 @@ class PrismaClient:
Ensure user owns that key, unless admin. Ensure user owns that key, unless admin.
""" """
start_time = time.time()
try: try:
if tokens is not None and isinstance(tokens, List): if tokens is not None and isinstance(tokens, List):
hashed_tokens = [] hashed_tokens = []
@ -1532,9 +1566,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - delete_data: {str(e)}" error_msg = f"LiteLLM Prisma Client Exception - delete_data: {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="delete_data",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1548,6 +1587,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def connect(self): async def connect(self):
start_time = time.time()
try: try:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"PrismaClient: connect() called Attempting to Connect to DB" "PrismaClient: connect() called Attempting to Connect to DB"
@ -1563,9 +1603,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception connect(): {str(e)}" error_msg = f"LiteLLM Prisma Client Exception connect(): {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="connect",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1579,6 +1624,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def disconnect(self): async def disconnect(self):
start_time = time.time()
try: try:
await self.db.disconnect() await self.db.disconnect()
except Exception as e: except Exception as e:
@ -1587,9 +1633,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}" error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="disconnect",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1598,6 +1649,8 @@ class PrismaClient:
""" """
Health check endpoint for the prisma client Health check endpoint for the prisma client
""" """
start_time = time.time()
try:
sql_query = """ sql_query = """
SELECT 1 SELECT 1
FROM "LiteLLM_VerificationToken" FROM "LiteLLM_VerificationToken"
@ -1608,6 +1661,23 @@ class PrismaClient:
# The asterisk before `user_id_list` unpacks the list into separate arguments # The asterisk before `user_id_list` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query) response = await self.db.query_raw(sql_query)
return response return response
except Exception as e:
import traceback
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e,
duration=_duration,
call_type="health_check",
traceback_str=error_traceback,
)
)
raise e
class DBClient: class DBClient:
@ -1983,6 +2053,7 @@ async def update_spend(
### UPDATE USER TABLE ### ### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0: if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time()
try: try:
async with prisma_client.db.tx( async with prisma_client.db.tx(
timeout=timedelta(seconds=60) timeout=timedelta(seconds=60)
@ -2013,9 +2084,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -2023,6 +2099,7 @@ async def update_spend(
### UPDATE END-USER TABLE ### ### UPDATE END-USER TABLE ###
if len(prisma_client.end_user_list_transactons.keys()) > 0: if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time()
try: try:
async with prisma_client.db.tx( async with prisma_client.db.tx(
timeout=timedelta(seconds=60) timeout=timedelta(seconds=60)
@ -2059,9 +2136,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -2069,6 +2151,7 @@ async def update_spend(
### UPDATE KEY TABLE ### ### UPDATE KEY TABLE ###
if len(prisma_client.key_list_transactons.keys()) > 0: if len(prisma_client.key_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time()
try: try:
async with prisma_client.db.tx( async with prisma_client.db.tx(
timeout=timedelta(seconds=60) timeout=timedelta(seconds=60)
@ -2099,9 +2182,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -2114,6 +2202,7 @@ async def update_spend(
) )
if len(prisma_client.team_list_transactons.keys()) > 0: if len(prisma_client.team_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time()
try: try:
async with prisma_client.db.tx( async with prisma_client.db.tx(
timeout=timedelta(seconds=60) timeout=timedelta(seconds=60)
@ -2149,9 +2238,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -2159,6 +2253,7 @@ async def update_spend(
### UPDATE ORG TABLE ### ### UPDATE ORG TABLE ###
if len(prisma_client.org_list_transactons.keys()) > 0: if len(prisma_client.org_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time()
try: try:
async with prisma_client.db.tx( async with prisma_client.db.tx(
timeout=timedelta(seconds=60) timeout=timedelta(seconds=60)
@ -2189,9 +2284,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -2206,6 +2306,7 @@ async def update_spend(
if len(prisma_client.spend_log_transactions) > 0: if len(prisma_client.spend_log_transactions) > 0:
for _ in range(n_retry_times + 1): for _ in range(n_retry_times + 1):
start_time = time.time()
try: try:
base_url = os.getenv("SPEND_LOGS_URL", None) base_url = os.getenv("SPEND_LOGS_URL", None)
## WRITE TO SEPARATE SERVER ## ## WRITE TO SEPARATE SERVER ##
@ -2271,9 +2372,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e