Merge pull request #3144 from BerriAI/litellm_prometheus_latency_tracking

feat(prometheus_services.py): emit proxy latency for successful llm api requests
This commit is contained in:
Krish Dholakia 2024-04-18 19:10:58 -07:00 committed by GitHub
commit 77a353d484
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 255 additions and 51 deletions

View file

@ -1,5 +1,5 @@
from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx, time
import litellm, backoff
from litellm.proxy._types import (
UserAPIKeyAuth,
@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
)
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm import ModelResponse, EmbeddingResponse, ImageResponse
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
@ -109,10 +110,12 @@ class ProxyLogging:
def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
self.service_logging_obj = ServiceLogging()
litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_tpm_rpm_limiter)
litellm.callbacks.append(self.max_budget_limiter)
litellm.callbacks.append(self.cache_control_check)
litellm.callbacks.append(self.service_logging_obj)
litellm.success_callback.append(self.response_taking_too_long_callback)
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
@ -493,7 +496,9 @@ class ProxyLogging:
else:
raise Exception("Missing SENTRY_DSN from environment")
async def failure_handler(self, original_exception, traceback_str=""):
async def failure_handler(
self, original_exception, duration: float, call_type: str, traceback_str=""
):
"""
Log failed db read/writes
@ -520,6 +525,14 @@ class ProxyLogging:
)
)
if hasattr(self, "service_logging_obj"):
self.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.DB,
duration=duration,
error=error_message,
call_type=call_type,
)
if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception)
@ -842,6 +855,7 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: get_generic_data: {key}, table_name: {table_name}"
)
start_time = time.time()
try:
if table_name == "users":
response = await self.db.litellm_usertable.find_first(
@ -866,11 +880,17 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
traceback_str=error_traceback,
call_type="get_generic_data",
)
)
raise e
@backoff.on_exception(
@ -908,6 +928,7 @@ class PrismaClient:
] = None, # pagination, number of rows to getch when find_all==True
):
args_passed_in = locals()
start_time = time.time()
verbose_proxy_logger.debug(
f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
)
@ -1173,9 +1194,15 @@ class PrismaClient:
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
verbose_proxy_logger.debug(error_traceback)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="get_data",
traceback_str=error_traceback,
)
)
raise e
@ -1198,6 +1225,7 @@ class PrismaClient:
"""
Add a key to the database. If it already exists, do nothing.
"""
start_time = time.time()
try:
verbose_proxy_logger.debug("PrismaClient: insert_data: %s", data)
if table_name == "key":
@ -1315,9 +1343,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception in insert_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="insert_data",
traceback_str=error_traceback,
)
)
raise e
@ -1348,6 +1381,7 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: update_data, table_name: {table_name}"
)
start_time = time.time()
try:
db_data = self.jsonify_object(data=data)
if update_key_values is not None:
@ -1509,9 +1543,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - update_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_data",
traceback_str=error_traceback,
)
)
raise e
@ -1536,6 +1575,7 @@ class PrismaClient:
Ensure user owns that key, unless admin.
"""
start_time = time.time()
try:
if tokens is not None and isinstance(tokens, List):
hashed_tokens = []
@ -1583,9 +1623,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - delete_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="delete_data",
traceback_str=error_traceback,
)
)
raise e
@ -1599,6 +1644,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def connect(self):
start_time = time.time()
try:
verbose_proxy_logger.debug(
"PrismaClient: connect() called Attempting to Connect to DB"
@ -1614,9 +1660,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception connect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="connect",
traceback_str=error_traceback,
)
)
raise e
@ -1630,6 +1681,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def disconnect(self):
start_time = time.time()
try:
await self.db.disconnect()
except Exception as e:
@ -1638,9 +1690,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="disconnect",
traceback_str=error_traceback,
)
)
raise e
@ -1649,16 +1706,35 @@ class PrismaClient:
"""
Health check endpoint for the prisma client
"""
sql_query = """
SELECT 1
FROM "LiteLLM_VerificationToken"
LIMIT 1
"""
start_time = time.time()
try:
sql_query = """
SELECT 1
FROM "LiteLLM_VerificationToken"
LIMIT 1
"""
# Execute the raw query
# The asterisk before `user_id_list` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query)
return response
# Execute the raw query
# The asterisk before `user_id_list` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query)
return response
except Exception as e:
import traceback
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e,
duration=_duration,
call_type="health_check",
traceback_str=error_traceback,
)
)
raise e
class DBClient:
@ -2034,6 +2110,7 @@ async def update_spend(
### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2064,9 +2141,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2074,6 +2156,7 @@ async def update_spend(
### UPDATE END-USER TABLE ###
if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2110,9 +2193,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2120,6 +2208,7 @@ async def update_spend(
### UPDATE KEY TABLE ###
if len(prisma_client.key_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2150,9 +2239,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2165,6 +2259,7 @@ async def update_spend(
)
if len(prisma_client.team_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2200,9 +2295,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2210,6 +2310,7 @@ async def update_spend(
### UPDATE ORG TABLE ###
if len(prisma_client.org_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2240,9 +2341,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2257,6 +2363,7 @@ async def update_spend(
if len(prisma_client.spend_log_transactions) > 0:
for _ in range(n_retry_times + 1):
start_time = time.time()
try:
base_url = os.getenv("SPEND_LOGS_URL", None)
## WRITE TO SEPARATE SERVER ##
@ -2322,9 +2429,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e