Merge pull request #3144 from BerriAI/litellm_prometheus_latency_tracking

feat(prometheus_services.py): emit proxy latency for successful llm api requests
This commit is contained in:
Krish Dholakia 2024-04-18 19:10:58 -07:00 committed by GitHub
commit 77a353d484
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 255 additions and 51 deletions

View file

@ -14,6 +14,7 @@ model_list:
model: gpt-3.5-turbo
litellm_settings:
success_callback: ["prometheus"]
failure_callback: ["prometheus"]
```
Start the proxy
@ -70,3 +71,4 @@ litellm_settings:
|----------------------|--------------------------------------|
| `litellm_redis_latency` | histogram latency for redis calls |
| `litellm_redis_fails` | Number of failed redis calls |
| `litellm_self_latency` | Histogram latency for successful LiteLLM API calls |

View file

@ -1,9 +1,13 @@
import litellm
import litellm, traceback
from litellm.proxy._types import UserAPIKeyAuth
from .types.services import ServiceTypes, ServiceLoggerPayload
from .integrations.prometheus_services import PrometheusServicesLogger
from .integrations.custom_logger import CustomLogger
from datetime import timedelta
from typing import Union
class ServiceLogging:
class ServiceLogging(CustomLogger):
"""
Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
"""
@ -14,7 +18,6 @@ class ServiceLogging:
self.mock_testing_async_success_hook = 0
self.mock_testing_sync_failure_hook = 0
self.mock_testing_async_failure_hook = 0
if "prometheus_system" in litellm.service_callback:
self.prometheusServicesLogger = PrometheusServicesLogger()
@ -34,7 +37,9 @@ class ServiceLogging:
if self.mock_testing:
self.mock_testing_sync_failure_hook += 1
async def async_service_success_hook(self, service: ServiceTypes, duration: float):
async def async_service_success_hook(
self, service: ServiceTypes, duration: float, call_type: str
):
"""
- For counting if the redis, postgres call is successful
"""
@ -42,7 +47,11 @@ class ServiceLogging:
self.mock_testing_async_success_hook += 1
payload = ServiceLoggerPayload(
is_error=False, error=None, service=service, duration=duration
is_error=False,
error=None,
service=service,
duration=duration,
call_type=call_type,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
@ -51,7 +60,11 @@ class ServiceLogging:
)
async def async_service_failure_hook(
self, service: ServiceTypes, duration: float, error: Exception
self,
service: ServiceTypes,
duration: float,
error: Union[str, Exception],
call_type: str,
):
"""
- For counting if the redis, postgres call is unsuccessful
@ -59,8 +72,18 @@ class ServiceLogging:
if self.mock_testing:
self.mock_testing_async_failure_hook += 1
error_message = ""
if isinstance(error, Exception):
error_message = str(error)
elif isinstance(error, str):
error_message = error
payload = ServiceLoggerPayload(
is_error=True, error=str(error), service=service, duration=duration
is_error=True,
error=error_message,
service=service,
duration=duration,
call_type=call_type,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
@ -69,3 +92,37 @@ class ServiceLogging:
await self.prometheusServicesLogger.async_service_failure_hook(
payload=payload
)
async def async_post_call_failure_hook(
    self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
    """
    Hook to track failed litellm-service calls

    Currently a pure pass-through: both arguments are handed, unchanged, to
    the parent class's `async_post_call_failure_hook`, and whatever that
    returns is returned to the caller.
    """
    parent_hook = super().async_post_call_failure_hook
    return await parent_hook(original_exception, user_api_key_dict)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Hook to track latency for litellm proxy llm api calls

    Computes `end_time - start_time`, converts it to seconds and forwards it
    to `async_service_success_hook` under `ServiceTypes.LITELLM`.

    Args:
        kwargs: call metadata; must contain the key "call_type".
        response_obj: not used by this hook.
        start_time: start of the call; `end_time - start_time` must yield a
            `timedelta` or a plain number of seconds.
        end_time: end of the call (same kind of value as `start_time`).

    Raises:
        KeyError: if `kwargs` has no "call_type" entry.
        Exception: if the computed duration is neither a `timedelta` nor a
            number.
    """
    _duration = end_time - start_time
    if isinstance(_duration, timedelta):
        _duration = _duration.total_seconds()
    elif isinstance(_duration, (int, float)):
        # Integer timestamps produce an int difference; normalise it so the
        # payload always carries a float number of seconds.
        _duration = float(_duration)
    else:
        raise Exception(
            "Duration={} is not a float or timedelta object. type={}".format(
                _duration, type(_duration)
            )
        )  # invalid _duration value

    # No try/except here: the previous `except Exception as e: raise e`
    # wrapper only re-raised, so errors simply propagate to the caller.
    await self.async_service_success_hook(
        service=ServiceTypes.LITELLM,
        duration=_duration,
        call_type=kwargs["call_type"],
    )

View file

@ -13,7 +13,6 @@ import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any, BinaryIO
from openai._models import BaseModel as OpenAIObject
from litellm._logging import verbose_logger
from litellm._service_logger import ServiceLogging
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
import traceback
@ -132,6 +131,7 @@ class RedisCache(BaseCache):
**kwargs,
):
from ._redis import get_redis_client, get_redis_connection_pool
from litellm._service_logger import ServiceLogging
import redis
redis_kwargs = {}
@ -216,7 +216,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_scan_iter",
)
) # DO NOT SLOW DOWN CALL B/C OF THIS
return keys
@ -227,7 +229,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_scan_iter",
)
)
raise e
@ -267,7 +272,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_set_cache",
)
)
except Exception as e:
@ -275,7 +282,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_set_cache",
)
)
# NON blocking - notify users Redis is throwing an exception
@ -316,7 +326,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_set_cache_pipeline",
)
)
return results
@ -326,7 +338,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_set_cache_pipeline",
)
)
@ -359,6 +374,7 @@ class RedisCache(BaseCache):
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_increment",
)
)
return result
@ -368,7 +384,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_increment",
)
)
verbose_logger.error(
@ -459,7 +478,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_get_cache",
)
)
return response
@ -469,7 +490,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_get_cache",
)
)
# NON blocking - notify users Redis is throwing an exception
@ -497,7 +521,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_batch_get_cache",
)
)
@ -519,7 +545,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_batch_get_cache",
)
)
print_verbose(f"Error occurred in pipeline read - {str(e)}")

View file

@ -30,6 +30,7 @@ class PrometheusServicesLogger:
raise Exception(
"Missing prometheus_client. Run `pip install prometheus-client`"
)
print("INITIALIZES PROMETHEUS SERVICE LOGGER!")
self.Histogram = Histogram
self.Counter = Counter
@ -151,6 +152,7 @@ class PrometheusServicesLogger:
if self.mock_testing:
self.mock_testing_success_calls += 1
print(f"LOGS SUCCESSFUL CALL TO PROMETHEUS - payload={payload}")
if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:

View file

@ -31,12 +31,12 @@ litellm_settings:
upperbound_key_generate_params:
max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
router_settings:
routing_strategy: usage-based-routing-v2
redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
enable_pre_call_checks: True
# router_settings:
# routing_strategy: usage-based-routing-v2
# redis_host: os.environ/REDIS_HOST
# redis_password: os.environ/REDIS_PASSWORD
# redis_port: os.environ/REDIS_PORT
# enable_pre_call_checks: True
general_settings:
master_key: sk-1234

View file

@ -1,5 +1,5 @@
from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx, time
import litellm, backoff
from litellm.proxy._types import (
UserAPIKeyAuth,
@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
)
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm import ModelResponse, EmbeddingResponse, ImageResponse
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
@ -109,10 +110,12 @@ class ProxyLogging:
def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
self.service_logging_obj = ServiceLogging()
litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_tpm_rpm_limiter)
litellm.callbacks.append(self.max_budget_limiter)
litellm.callbacks.append(self.cache_control_check)
litellm.callbacks.append(self.service_logging_obj)
litellm.success_callback.append(self.response_taking_too_long_callback)
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
@ -493,7 +496,9 @@ class ProxyLogging:
else:
raise Exception("Missing SENTRY_DSN from environment")
async def failure_handler(self, original_exception, traceback_str=""):
async def failure_handler(
self, original_exception, duration: float, call_type: str, traceback_str=""
):
"""
Log failed db read/writes
@ -520,6 +525,14 @@ class ProxyLogging:
)
)
if hasattr(self, "service_logging_obj"):
self.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.DB,
duration=duration,
error=error_message,
call_type=call_type,
)
if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception)
@ -842,6 +855,7 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: get_generic_data: {key}, table_name: {table_name}"
)
start_time = time.time()
try:
if table_name == "users":
response = await self.db.litellm_usertable.find_first(
@ -866,11 +880,17 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
traceback_str=error_traceback,
call_type="get_generic_data",
)
)
raise e
@backoff.on_exception(
@ -908,6 +928,7 @@ class PrismaClient:
] = None, # pagination, number of rows to getch when find_all==True
):
args_passed_in = locals()
start_time = time.time()
verbose_proxy_logger.debug(
f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
)
@ -1173,9 +1194,15 @@ class PrismaClient:
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
verbose_proxy_logger.debug(error_traceback)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="get_data",
traceback_str=error_traceback,
)
)
raise e
@ -1198,6 +1225,7 @@ class PrismaClient:
"""
Add a key to the database. If it already exists, do nothing.
"""
start_time = time.time()
try:
verbose_proxy_logger.debug("PrismaClient: insert_data: %s", data)
if table_name == "key":
@ -1315,9 +1343,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception in insert_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="insert_data",
traceback_str=error_traceback,
)
)
raise e
@ -1348,6 +1381,7 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: update_data, table_name: {table_name}"
)
start_time = time.time()
try:
db_data = self.jsonify_object(data=data)
if update_key_values is not None:
@ -1509,9 +1543,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - update_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_data",
traceback_str=error_traceback,
)
)
raise e
@ -1536,6 +1575,7 @@ class PrismaClient:
Ensure user owns that key, unless admin.
"""
start_time = time.time()
try:
if tokens is not None and isinstance(tokens, List):
hashed_tokens = []
@ -1583,9 +1623,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - delete_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="delete_data",
traceback_str=error_traceback,
)
)
raise e
@ -1599,6 +1644,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def connect(self):
start_time = time.time()
try:
verbose_proxy_logger.debug(
"PrismaClient: connect() called Attempting to Connect to DB"
@ -1614,9 +1660,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception connect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="connect",
traceback_str=error_traceback,
)
)
raise e
@ -1630,6 +1681,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def disconnect(self):
start_time = time.time()
try:
await self.db.disconnect()
except Exception as e:
@ -1638,9 +1690,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="disconnect",
traceback_str=error_traceback,
)
)
raise e
@ -1649,6 +1706,8 @@ class PrismaClient:
"""
Health check endpoint for the prisma client
"""
start_time = time.time()
try:
sql_query = """
SELECT 1
FROM "LiteLLM_VerificationToken"
@ -1659,6 +1718,23 @@ class PrismaClient:
# The asterisk before `user_id_list` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query)
return response
except Exception as e:
import traceback
# Name the failing operation correctly: this handler reports
# call_type="health_check", not disconnect().
error_msg = f"LiteLLM Prisma Client Exception health_check(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e,
duration=_duration,
call_type="health_check",
traceback_str=error_traceback,
)
)
raise e
class DBClient:
@ -2034,6 +2110,7 @@ async def update_spend(
### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2064,9 +2141,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2074,6 +2156,7 @@ async def update_spend(
### UPDATE END-USER TABLE ###
if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2110,9 +2193,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2120,6 +2208,7 @@ async def update_spend(
### UPDATE KEY TABLE ###
if len(prisma_client.key_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2150,9 +2239,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2165,6 +2259,7 @@ async def update_spend(
)
if len(prisma_client.team_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2200,9 +2295,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2210,6 +2310,7 @@ async def update_spend(
### UPDATE ORG TABLE ###
if len(prisma_client.org_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2240,9 +2341,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2257,6 +2363,7 @@ async def update_spend(
if len(prisma_client.spend_log_transactions) > 0:
for _ in range(n_retry_times + 1):
start_time = time.time()
try:
base_url = os.getenv("SPEND_LOGS_URL", None)
## WRITE TO SEPARATE SERVER ##
@ -2322,9 +2429,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e

View file

@ -412,7 +412,7 @@ async def test_cost_tracking_with_caching():
"""
from litellm import Cache
litellm.set_verbose = False
litellm.set_verbose = True
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],

View file

@ -5,11 +5,12 @@ from typing import Optional
class ServiceTypes(enum.Enum):
"""
Enum for litellm-adjacent services (redis/postgres/etc.)
Enum for litellm + litellm-adjacent services (redis/postgres/etc.)
"""
REDIS = "redis"
DB = "postgres"
LITELLM = "self"
class ServiceLoggerPayload(BaseModel):
@ -21,6 +22,7 @@ class ServiceLoggerPayload(BaseModel):
error: Optional[str] = Field(None, description="what was the error")
service: ServiceTypes = Field(description="who is this for? - postgres/redis")
duration: float = Field(description="How long did the request take?")
call_type: str = Field(description="The call of the service, being made")
def to_json(self, **kwargs):
try: