fix(proxy/utils.py): add call type and duration to proxy_logging failure calls

this is for tracking failed db requests on prometheus
This commit is contained in:
Krrish Dholakia 2024-04-18 16:24:36 -07:00
parent 48d3b563d8
commit d61250109e

View file

@ -1,5 +1,5 @@
from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx, time
import litellm, backoff
from litellm.proxy._types import (
UserAPIKeyAuth,
@ -18,7 +18,7 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
)
from litellm._service_logger import ServiceLogging
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm import ModelResponse, EmbeddingResponse, ImageResponse
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
@ -458,7 +458,9 @@ class ProxyLogging:
else:
raise Exception("Missing SENTRY_DSN from environment")
async def failure_handler(self, original_exception, traceback_str=""):
async def failure_handler(
self, original_exception, duration: float, call_type: str, traceback_str=""
):
"""
Log failed db read/writes
@ -483,6 +485,11 @@ class ProxyLogging:
)
)
if hasattr(self, "service_logging_obj"):
self.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.DB,
)
if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception)
@ -803,6 +810,7 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: get_generic_data: {key}, table_name: {table_name}"
)
start_time = time.time()
try:
if table_name == "users":
response = await self.db.litellm_usertable.find_first(
@ -827,11 +835,17 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
traceback_str=error_traceback,
call_type="get_generic_data",
)
)
raise e
@backoff.on_exception(
@ -869,6 +883,7 @@ class PrismaClient:
] = None, # pagination, number of rows to getch when find_all==True
):
args_passed_in = locals()
start_time = time.time()
verbose_proxy_logger.debug(
f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
)
@ -1122,9 +1137,15 @@ class PrismaClient:
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
verbose_proxy_logger.debug(error_traceback)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="get_data",
traceback_str=error_traceback,
)
)
raise e
@ -1147,6 +1168,7 @@ class PrismaClient:
"""
Add a key to the database. If it already exists, do nothing.
"""
start_time = time.time()
try:
verbose_proxy_logger.debug("PrismaClient: insert_data: %s", data)
if table_name == "key":
@ -1264,9 +1286,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception in insert_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="insert_data",
traceback_str=error_traceback,
)
)
raise e
@ -1297,6 +1324,7 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: update_data, table_name: {table_name}"
)
start_time = time.time()
try:
db_data = self.jsonify_object(data=data)
if update_key_values is not None:
@ -1458,9 +1486,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - update_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_data",
traceback_str=error_traceback,
)
)
raise e
@ -1485,6 +1518,7 @@ class PrismaClient:
Ensure user owns that key, unless admin.
"""
start_time = time.time()
try:
if tokens is not None and isinstance(tokens, List):
hashed_tokens = []
@ -1532,9 +1566,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - delete_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="delete_data",
traceback_str=error_traceback,
)
)
raise e
@ -1548,6 +1587,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def connect(self):
start_time = time.time()
try:
verbose_proxy_logger.debug(
"PrismaClient: connect() called Attempting to Connect to DB"
@ -1563,9 +1603,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception connect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="connect",
traceback_str=error_traceback,
)
)
raise e
@ -1579,6 +1624,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def disconnect(self):
start_time = time.time()
try:
await self.db.disconnect()
except Exception as e:
@ -1587,9 +1633,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="disconnect",
traceback_str=error_traceback,
)
)
raise e
@ -1598,16 +1649,35 @@ class PrismaClient:
"""
Health check endpoint for the prisma client
"""
sql_query = """
SELECT 1
FROM "LiteLLM_VerificationToken"
LIMIT 1
"""
start_time = time.time()
try:
sql_query = """
SELECT 1
FROM "LiteLLM_VerificationToken"
LIMIT 1
"""
# Execute the raw query
# The asterisk before `user_id_list` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query)
return response
# Execute the raw query
# The asterisk before `user_id_list` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query)
return response
except Exception as e:
import traceback
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e,
duration=_duration,
call_type="health_check",
traceback_str=error_traceback,
)
)
raise e
class DBClient:
@ -1983,6 +2053,7 @@ async def update_spend(
### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2013,9 +2084,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2023,6 +2099,7 @@ async def update_spend(
### UPDATE END-USER TABLE ###
if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2059,9 +2136,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2069,6 +2151,7 @@ async def update_spend(
### UPDATE KEY TABLE ###
if len(prisma_client.key_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2099,9 +2182,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2114,6 +2202,7 @@ async def update_spend(
)
if len(prisma_client.team_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2149,9 +2238,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2159,6 +2253,7 @@ async def update_spend(
### UPDATE ORG TABLE ###
if len(prisma_client.org_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2189,9 +2284,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2206,6 +2306,7 @@ async def update_spend(
if len(prisma_client.spend_log_transactions) > 0:
for _ in range(n_retry_times + 1):
start_time = time.time()
try:
base_url = os.getenv("SPEND_LOGS_URL", None)
## WRITE TO SEPARATE SERVER ##
@ -2271,9 +2372,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e