Merge pull request #3144 from BerriAI/litellm_prometheus_latency_tracking

feat(prometheus_services.py): emit proxy latency for successful llm api requests
This commit is contained in:
Krish Dholakia 2024-04-18 19:10:58 -07:00 committed by GitHub
commit 77a353d484
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 255 additions and 51 deletions

View file

@ -14,6 +14,7 @@ model_list:
model: gpt-3.5-turbo model: gpt-3.5-turbo
litellm_settings: litellm_settings:
success_callback: ["prometheus"] success_callback: ["prometheus"]
failure_callback: ["prometheus"]
``` ```
Start the proxy Start the proxy
@ -70,3 +71,4 @@ litellm_settings:
|----------------------|--------------------------------------| |----------------------|--------------------------------------|
| `litellm_redis_latency` | histogram latency for redis calls | | `litellm_redis_latency` | histogram latency for redis calls |
| `litellm_redis_fails` | Number of failed redis calls | | `litellm_redis_fails` | Number of failed redis calls |
| `litellm_self_latency` | Histogram latency for successful litellm api calls |

View file

@ -1,9 +1,13 @@
import litellm import litellm, traceback
from litellm.proxy._types import UserAPIKeyAuth
from .types.services import ServiceTypes, ServiceLoggerPayload from .types.services import ServiceTypes, ServiceLoggerPayload
from .integrations.prometheus_services import PrometheusServicesLogger from .integrations.prometheus_services import PrometheusServicesLogger
from .integrations.custom_logger import CustomLogger
from datetime import timedelta
from typing import Union
class ServiceLogging: class ServiceLogging(CustomLogger):
""" """
Separate class used for monitoring health of litellm-adjacent services (redis/postgres). Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
""" """
@ -14,7 +18,6 @@ class ServiceLogging:
self.mock_testing_async_success_hook = 0 self.mock_testing_async_success_hook = 0
self.mock_testing_sync_failure_hook = 0 self.mock_testing_sync_failure_hook = 0
self.mock_testing_async_failure_hook = 0 self.mock_testing_async_failure_hook = 0
if "prometheus_system" in litellm.service_callback: if "prometheus_system" in litellm.service_callback:
self.prometheusServicesLogger = PrometheusServicesLogger() self.prometheusServicesLogger = PrometheusServicesLogger()
@ -34,7 +37,9 @@ class ServiceLogging:
if self.mock_testing: if self.mock_testing:
self.mock_testing_sync_failure_hook += 1 self.mock_testing_sync_failure_hook += 1
async def async_service_success_hook(self, service: ServiceTypes, duration: float): async def async_service_success_hook(
self, service: ServiceTypes, duration: float, call_type: str
):
""" """
- For counting if the redis, postgres call is successful - For counting if the redis, postgres call is successful
""" """
@ -42,7 +47,11 @@ class ServiceLogging:
self.mock_testing_async_success_hook += 1 self.mock_testing_async_success_hook += 1
payload = ServiceLoggerPayload( payload = ServiceLoggerPayload(
is_error=False, error=None, service=service, duration=duration is_error=False,
error=None,
service=service,
duration=duration,
call_type=call_type,
) )
for callback in litellm.service_callback: for callback in litellm.service_callback:
if callback == "prometheus_system": if callback == "prometheus_system":
@ -51,7 +60,11 @@ class ServiceLogging:
) )
async def async_service_failure_hook( async def async_service_failure_hook(
self, service: ServiceTypes, duration: float, error: Exception self,
service: ServiceTypes,
duration: float,
error: Union[str, Exception],
call_type: str,
): ):
""" """
- For counting if the redis, postgres call is unsuccessful - For counting if the redis, postgres call is unsuccessful
@ -59,8 +72,18 @@ class ServiceLogging:
if self.mock_testing: if self.mock_testing:
self.mock_testing_async_failure_hook += 1 self.mock_testing_async_failure_hook += 1
error_message = ""
if isinstance(error, Exception):
error_message = str(error)
elif isinstance(error, str):
error_message = error
payload = ServiceLoggerPayload( payload = ServiceLoggerPayload(
is_error=True, error=str(error), service=service, duration=duration is_error=True,
error=error_message,
service=service,
duration=duration,
call_type=call_type,
) )
for callback in litellm.service_callback: for callback in litellm.service_callback:
if callback == "prometheus_system": if callback == "prometheus_system":
@ -69,3 +92,37 @@ class ServiceLogging:
await self.prometheusServicesLogger.async_service_failure_hook( await self.prometheusServicesLogger.async_service_failure_hook(
payload=payload payload=payload
) )
async def async_post_call_failure_hook(
    self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
    """
    Hook to track failed litellm-service calls

    Adds no behaviour of its own: both arguments are forwarded unchanged to
    the parent class implementation and its result is returned as-is.
    """
    parent_hook = super().async_post_call_failure_hook
    outcome = await parent_hook(original_exception, user_api_key_dict)
    return outcome
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Hook to track latency for litellm proxy llm api calls
"""
try:
_duration = end_time - start_time
if isinstance(_duration, timedelta):
_duration = _duration.total_seconds()
elif isinstance(_duration, float):
pass
else:
raise Exception(
"Duration={} is not a float or timedelta object. type={}".format(
_duration, type(_duration)
)
) # invalid _duration value
await self.async_service_success_hook(
service=ServiceTypes.LITELLM,
duration=_duration,
call_type=kwargs["call_type"],
)
except Exception as e:
raise e

View file

@ -13,7 +13,6 @@ import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any, BinaryIO from typing import Optional, Literal, List, Union, Any, BinaryIO
from openai._models import BaseModel as OpenAIObject from openai._models import BaseModel as OpenAIObject
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm._service_logger import ServiceLogging
from litellm.types.services import ServiceLoggerPayload, ServiceTypes from litellm.types.services import ServiceLoggerPayload, ServiceTypes
import traceback import traceback
@ -132,6 +131,7 @@ class RedisCache(BaseCache):
**kwargs, **kwargs,
): ):
from ._redis import get_redis_client, get_redis_connection_pool from ._redis import get_redis_client, get_redis_connection_pool
from litellm._service_logger import ServiceLogging
import redis import redis
redis_kwargs = {} redis_kwargs = {}
@ -216,7 +216,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_success_hook( self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_scan_iter",
) )
) # DO NOT SLOW DOWN CALL B/C OF THIS ) # DO NOT SLOW DOWN CALL B/C OF THIS
return keys return keys
@ -227,7 +229,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_failure_hook( self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_scan_iter",
) )
) )
raise e raise e
@ -267,7 +272,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_success_hook( self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_set_cache",
) )
) )
except Exception as e: except Exception as e:
@ -275,7 +282,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_failure_hook( self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_set_cache",
) )
) )
# NON blocking - notify users Redis is throwing an exception # NON blocking - notify users Redis is throwing an exception
@ -316,7 +326,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_success_hook( self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_set_cache_pipeline",
) )
) )
return results return results
@ -326,7 +338,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_failure_hook( self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_set_cache_pipeline",
) )
) )
@ -359,6 +374,7 @@ class RedisCache(BaseCache):
self.service_logger_obj.async_service_success_hook( self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, service=ServiceTypes.REDIS,
duration=_duration, duration=_duration,
call_type="async_increment",
) )
) )
return result return result
@ -368,7 +384,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_failure_hook( self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_increment",
) )
) )
verbose_logger.error( verbose_logger.error(
@ -459,7 +478,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_success_hook( self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_get_cache",
) )
) )
return response return response
@ -469,7 +490,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_failure_hook( self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_get_cache",
) )
) )
# NON blocking - notify users Redis is throwing an exception # NON blocking - notify users Redis is throwing an exception
@ -497,7 +521,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_success_hook( self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_batch_get_cache",
) )
) )
@ -519,7 +545,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_failure_hook( self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_batch_get_cache",
) )
) )
print_verbose(f"Error occurred in pipeline read - {str(e)}") print_verbose(f"Error occurred in pipeline read - {str(e)}")

View file

@ -30,6 +30,7 @@ class PrometheusServicesLogger:
raise Exception( raise Exception(
"Missing prometheus_client. Run `pip install prometheus-client`" "Missing prometheus_client. Run `pip install prometheus-client`"
) )
print("INITIALIZES PROMETHEUS SERVICE LOGGER!")
self.Histogram = Histogram self.Histogram = Histogram
self.Counter = Counter self.Counter = Counter
@ -151,6 +152,7 @@ class PrometheusServicesLogger:
if self.mock_testing: if self.mock_testing:
self.mock_testing_success_calls += 1 self.mock_testing_success_calls += 1
print(f"LOGS SUCCESSFUL CALL TO PROMETHEUS - payload={payload}")
if payload.service.value in self.payload_to_prometheus_map: if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value] prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects: for obj in prom_objects:

View file

@ -31,12 +31,12 @@ litellm_settings:
upperbound_key_generate_params: upperbound_key_generate_params:
max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
router_settings: # router_settings:
routing_strategy: usage-based-routing-v2 # routing_strategy: usage-based-routing-v2
redis_host: os.environ/REDIS_HOST # redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD # redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT # redis_port: os.environ/REDIS_PORT
enable_pre_call_checks: True # enable_pre_call_checks: True
general_settings: general_settings:
master_key: sk-1234 master_key: sk-1234

View file

@ -1,5 +1,5 @@
from typing import Optional, List, Any, Literal, Union from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx, time
import litellm, backoff import litellm, backoff
from litellm.proxy._types import ( from litellm.proxy._types import (
UserAPIKeyAuth, UserAPIKeyAuth,
@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import ( from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler, _PROXY_MaxParallelRequestsHandler,
) )
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm import ModelResponse, EmbeddingResponse, ImageResponse from litellm import ModelResponse, EmbeddingResponse, ImageResponse
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
@ -109,10 +110,12 @@ class ProxyLogging:
def _init_litellm_callbacks(self): def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!") print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
self.service_logging_obj = ServiceLogging()
litellm.callbacks.append(self.max_parallel_request_limiter) litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_tpm_rpm_limiter) litellm.callbacks.append(self.max_tpm_rpm_limiter)
litellm.callbacks.append(self.max_budget_limiter) litellm.callbacks.append(self.max_budget_limiter)
litellm.callbacks.append(self.cache_control_check) litellm.callbacks.append(self.cache_control_check)
litellm.callbacks.append(self.service_logging_obj)
litellm.success_callback.append(self.response_taking_too_long_callback) litellm.success_callback.append(self.response_taking_too_long_callback)
for callback in litellm.callbacks: for callback in litellm.callbacks:
if callback not in litellm.input_callback: if callback not in litellm.input_callback:
@ -493,7 +496,9 @@ class ProxyLogging:
else: else:
raise Exception("Missing SENTRY_DSN from environment") raise Exception("Missing SENTRY_DSN from environment")
async def failure_handler(self, original_exception, traceback_str=""): async def failure_handler(
self, original_exception, duration: float, call_type: str, traceback_str=""
):
""" """
Log failed db read/writes Log failed db read/writes
@ -520,6 +525,14 @@ class ProxyLogging:
) )
) )
if hasattr(self, "service_logging_obj"):
self.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.DB,
duration=duration,
error=error_message,
call_type=call_type,
)
if litellm.utils.capture_exception: if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception) litellm.utils.capture_exception(error=original_exception)
@ -842,6 +855,7 @@ class PrismaClient:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"PrismaClient: get_generic_data: {key}, table_name: {table_name}" f"PrismaClient: get_generic_data: {key}, table_name: {table_name}"
) )
start_time = time.time()
try: try:
if table_name == "users": if table_name == "users":
response = await self.db.litellm_usertable.find_first( response = await self.db.litellm_usertable.find_first(
@ -866,11 +880,17 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}" error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
traceback_str=error_traceback,
call_type="get_generic_data",
) )
) )
raise e raise e
@backoff.on_exception( @backoff.on_exception(
@ -908,6 +928,7 @@ class PrismaClient:
] = None, # pagination, number of rows to getch when find_all==True ] = None, # pagination, number of rows to getch when find_all==True
): ):
args_passed_in = locals() args_passed_in = locals()
start_time = time.time()
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"PrismaClient: get_data - args_passed_in: {args_passed_in}" f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
) )
@ -1173,9 +1194,15 @@ class PrismaClient:
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
verbose_proxy_logger.debug(error_traceback) verbose_proxy_logger.debug(error_traceback)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="get_data",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1198,6 +1225,7 @@ class PrismaClient:
""" """
Add a key to the database. If it already exists, do nothing. Add a key to the database. If it already exists, do nothing.
""" """
start_time = time.time()
try: try:
verbose_proxy_logger.debug("PrismaClient: insert_data: %s", data) verbose_proxy_logger.debug("PrismaClient: insert_data: %s", data)
if table_name == "key": if table_name == "key":
@ -1315,9 +1343,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception in insert_data: {str(e)}" error_msg = f"LiteLLM Prisma Client Exception in insert_data: {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="insert_data",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1348,6 +1381,7 @@ class PrismaClient:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"PrismaClient: update_data, table_name: {table_name}" f"PrismaClient: update_data, table_name: {table_name}"
) )
start_time = time.time()
try: try:
db_data = self.jsonify_object(data=data) db_data = self.jsonify_object(data=data)
if update_key_values is not None: if update_key_values is not None:
@ -1509,9 +1543,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - update_data: {str(e)}" error_msg = f"LiteLLM Prisma Client Exception - update_data: {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_data",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1536,6 +1575,7 @@ class PrismaClient:
Ensure user owns that key, unless admin. Ensure user owns that key, unless admin.
""" """
start_time = time.time()
try: try:
if tokens is not None and isinstance(tokens, List): if tokens is not None and isinstance(tokens, List):
hashed_tokens = [] hashed_tokens = []
@ -1583,9 +1623,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - delete_data: {str(e)}" error_msg = f"LiteLLM Prisma Client Exception - delete_data: {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="delete_data",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1599,6 +1644,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def connect(self): async def connect(self):
start_time = time.time()
try: try:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"PrismaClient: connect() called Attempting to Connect to DB" "PrismaClient: connect() called Attempting to Connect to DB"
@ -1614,9 +1660,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception connect(): {str(e)}" error_msg = f"LiteLLM Prisma Client Exception connect(): {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="connect",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1630,6 +1681,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def disconnect(self): async def disconnect(self):
start_time = time.time()
try: try:
await self.db.disconnect() await self.db.disconnect()
except Exception as e: except Exception as e:
@ -1638,9 +1690,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}" error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler( self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="disconnect",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -1649,16 +1706,35 @@ class PrismaClient:
""" """
Health check endpoint for the prisma client Health check endpoint for the prisma client
""" """
sql_query = """ start_time = time.time()
SELECT 1 try:
FROM "LiteLLM_VerificationToken" sql_query = """
LIMIT 1 SELECT 1
""" FROM "LiteLLM_VerificationToken"
LIMIT 1
"""
# Execute the raw query # Execute the raw query
# The asterisk before `user_id_list` unpacks the list into separate arguments # The asterisk before `user_id_list` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query) response = await self.db.query_raw(sql_query)
return response return response
except Exception as e:
import traceback
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e,
duration=_duration,
call_type="health_check",
traceback_str=error_traceback,
)
)
raise e
class DBClient: class DBClient:
@ -2034,6 +2110,7 @@ async def update_spend(
### UPDATE USER TABLE ### ### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0: if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time()
try: try:
async with prisma_client.db.tx( async with prisma_client.db.tx(
timeout=timedelta(seconds=60) timeout=timedelta(seconds=60)
@ -2064,9 +2141,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -2074,6 +2156,7 @@ async def update_spend(
### UPDATE END-USER TABLE ### ### UPDATE END-USER TABLE ###
if len(prisma_client.end_user_list_transactons.keys()) > 0: if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time()
try: try:
async with prisma_client.db.tx( async with prisma_client.db.tx(
timeout=timedelta(seconds=60) timeout=timedelta(seconds=60)
@ -2110,9 +2193,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -2120,6 +2208,7 @@ async def update_spend(
### UPDATE KEY TABLE ### ### UPDATE KEY TABLE ###
if len(prisma_client.key_list_transactons.keys()) > 0: if len(prisma_client.key_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time()
try: try:
async with prisma_client.db.tx( async with prisma_client.db.tx(
timeout=timedelta(seconds=60) timeout=timedelta(seconds=60)
@ -2150,9 +2239,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -2165,6 +2259,7 @@ async def update_spend(
) )
if len(prisma_client.team_list_transactons.keys()) > 0: if len(prisma_client.team_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time()
try: try:
async with prisma_client.db.tx( async with prisma_client.db.tx(
timeout=timedelta(seconds=60) timeout=timedelta(seconds=60)
@ -2200,9 +2295,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -2210,6 +2310,7 @@ async def update_spend(
### UPDATE ORG TABLE ### ### UPDATE ORG TABLE ###
if len(prisma_client.org_list_transactons.keys()) > 0: if len(prisma_client.org_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time()
try: try:
async with prisma_client.db.tx( async with prisma_client.db.tx(
timeout=timedelta(seconds=60) timeout=timedelta(seconds=60)
@ -2240,9 +2341,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e
@ -2257,6 +2363,7 @@ async def update_spend(
if len(prisma_client.spend_log_transactions) > 0: if len(prisma_client.spend_log_transactions) > 0:
for _ in range(n_retry_times + 1): for _ in range(n_retry_times + 1):
start_time = time.time()
try: try:
base_url = os.getenv("SPEND_LOGS_URL", None) base_url = os.getenv("SPEND_LOGS_URL", None)
## WRITE TO SEPARATE SERVER ## ## WRITE TO SEPARATE SERVER ##
@ -2322,9 +2429,14 @@ async def update_spend(
) )
print_verbose(error_msg) print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.failure_handler( proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
) )
) )
raise e raise e

View file

@ -412,7 +412,7 @@ async def test_cost_tracking_with_caching():
""" """
from litellm import Cache from litellm import Cache
litellm.set_verbose = False litellm.set_verbose = True
litellm.cache = Cache( litellm.cache = Cache(
type="redis", type="redis",
host=os.environ["REDIS_HOST"], host=os.environ["REDIS_HOST"],

View file

@ -5,11 +5,12 @@ from typing import Optional
class ServiceTypes(enum.Enum): class ServiceTypes(enum.Enum):
""" """
Enum for litellm-adjacent services (redis/postgres/etc.) Enum for litellm + litellm-adjacent services (redis/postgres/etc.)
""" """
REDIS = "redis" REDIS = "redis"
DB = "postgres" DB = "postgres"
LITELLM = "self"
class ServiceLoggerPayload(BaseModel): class ServiceLoggerPayload(BaseModel):
@ -21,6 +22,7 @@ class ServiceLoggerPayload(BaseModel):
error: Optional[str] = Field(None, description="what was the error") error: Optional[str] = Field(None, description="what was the error")
service: ServiceTypes = Field(description="who is this for? - postgres/redis") service: ServiceTypes = Field(description="who is this for? - postgres/redis")
duration: float = Field(description="How long did the request take?") duration: float = Field(description="How long did the request take?")
call_type: str = Field(description="The call of the service, being made")
def to_json(self, **kwargs): def to_json(self, **kwargs):
try: try: