feat(prometheus_services.py): emit proxy latency for successful llm api requests

Uses a Prometheus histogram to record the latency.
This commit is contained in:
Krrish Dholakia 2024-04-18 16:04:35 -07:00
parent df70e75ee1
commit 0f95a824c4
6 changed files with 87 additions and 20 deletions

View file

@ -1,9 +1,12 @@
import litellm import litellm, traceback
from litellm.proxy._types import UserAPIKeyAuth
from .types.services import ServiceTypes, ServiceLoggerPayload from .types.services import ServiceTypes, ServiceLoggerPayload
from .integrations.prometheus_services import PrometheusServicesLogger from .integrations.prometheus_services import PrometheusServicesLogger
from .integrations.custom_logger import CustomLogger
from datetime import timedelta
class ServiceLogging: class ServiceLogging(CustomLogger):
""" """
Separate class used for monitoring health of litellm-adjacent services (redis/postgres). Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
""" """
@ -14,7 +17,6 @@ class ServiceLogging:
self.mock_testing_async_success_hook = 0 self.mock_testing_async_success_hook = 0
self.mock_testing_sync_failure_hook = 0 self.mock_testing_sync_failure_hook = 0
self.mock_testing_async_failure_hook = 0 self.mock_testing_async_failure_hook = 0
if "prometheus_system" in litellm.service_callback: if "prometheus_system" in litellm.service_callback:
self.prometheusServicesLogger = PrometheusServicesLogger() self.prometheusServicesLogger = PrometheusServicesLogger()
@ -34,7 +36,9 @@ class ServiceLogging:
if self.mock_testing: if self.mock_testing:
self.mock_testing_sync_failure_hook += 1 self.mock_testing_sync_failure_hook += 1
async def async_service_success_hook(self, service: ServiceTypes, duration: float): async def async_service_success_hook(
self, service: ServiceTypes, duration: float, call_type: str
):
""" """
- For counting if the redis, postgres call is successful - For counting if the redis, postgres call is successful
""" """
@ -42,7 +46,11 @@ class ServiceLogging:
self.mock_testing_async_success_hook += 1 self.mock_testing_async_success_hook += 1
payload = ServiceLoggerPayload( payload = ServiceLoggerPayload(
is_error=False, error=None, service=service, duration=duration is_error=False,
error=None,
service=service,
duration=duration,
call_type=call_type,
) )
for callback in litellm.service_callback: for callback in litellm.service_callback:
if callback == "prometheus_system": if callback == "prometheus_system":
@ -51,7 +59,7 @@ class ServiceLogging:
) )
async def async_service_failure_hook( async def async_service_failure_hook(
self, service: ServiceTypes, duration: float, error: Exception self, service: ServiceTypes, duration: float, error: Exception, call_type: str
): ):
""" """
- For counting if the redis, postgres call is unsuccessful - For counting if the redis, postgres call is unsuccessful
@ -60,7 +68,11 @@ class ServiceLogging:
self.mock_testing_async_failure_hook += 1 self.mock_testing_async_failure_hook += 1
payload = ServiceLoggerPayload( payload = ServiceLoggerPayload(
is_error=True, error=str(error), service=service, duration=duration is_error=True,
error=str(error),
service=service,
duration=duration,
call_type=call_type,
) )
for callback in litellm.service_callback: for callback in litellm.service_callback:
if callback == "prometheus_system": if callback == "prometheus_system":
@ -69,3 +81,37 @@ class ServiceLogging:
await self.prometheusServicesLogger.async_service_failure_hook( await self.prometheusServicesLogger.async_service_failure_hook(
payload=payload payload=payload
) )
async def async_post_call_failure_hook(
    self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
    """
    Hook to track failed litellm-service calls.

    Thin pass-through: delegates entirely to the parent class's
    async_post_call_failure_hook (presumably CustomLogger — the base
    class is not visible in this view; confirm) and returns its result.

    Args:
        original_exception: the exception raised by the failed call.
        user_api_key_dict: auth/metadata for the API key that made the call.
    """
    return await super().async_post_call_failure_hook(
        original_exception, user_api_key_dict
    )
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Hook to track latency for litellm proxy llm api calls.

    Normalizes the request duration (``end_time - start_time``) to seconds
    and forwards it to ``async_service_success_hook`` under the LITELLM
    service type, tagged with the request's ``call_type``.

    Args:
        kwargs: logging payload; must contain a "call_type" key.
        response_obj: the llm api response (unused here).
        start_time: request start, a datetime or float timestamp.
        end_time: request end, same type as start_time.

    Raises:
        Exception: if the computed duration is neither a float nor a
            timedelta, or if the downstream success hook fails.
    """
    _duration = end_time - start_time
    if isinstance(_duration, timedelta):
        # datetime - datetime yields a timedelta; histogram wants seconds
        _duration = _duration.total_seconds()
    elif not isinstance(_duration, float):
        # invalid _duration value (e.g. mixed datetime/float inputs)
        raise Exception(
            "Duration={} is not a float or timedelta object. type={}".format(
                _duration, type(_duration)
            )
        )
    # NOTE: original wrapped this in `try: ... except Exception as e: raise e`,
    # which only truncated the traceback; letting exceptions propagate
    # unwrapped preserves the full context for callers.
    await self.async_service_success_hook(
        service=ServiceTypes.LITELLM,
        duration=_duration,
        call_type=kwargs["call_type"],
    )

View file

@ -13,7 +13,6 @@ import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any, BinaryIO from typing import Optional, Literal, List, Union, Any, BinaryIO
from openai._models import BaseModel as OpenAIObject from openai._models import BaseModel as OpenAIObject
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm._service_logger import ServiceLogging
from litellm.types.services import ServiceLoggerPayload, ServiceTypes from litellm.types.services import ServiceLoggerPayload, ServiceTypes
import traceback import traceback
@ -132,6 +131,7 @@ class RedisCache(BaseCache):
**kwargs, **kwargs,
): ):
from ._redis import get_redis_client, get_redis_connection_pool from ._redis import get_redis_client, get_redis_connection_pool
from litellm._service_logger import ServiceLogging
import redis import redis
redis_kwargs = {} redis_kwargs = {}
@ -216,7 +216,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_success_hook( self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_scan_iter",
) )
) # DO NOT SLOW DOWN CALL B/C OF THIS ) # DO NOT SLOW DOWN CALL B/C OF THIS
return keys return keys
@ -227,7 +229,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_failure_hook( self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_scan_iter",
) )
) )
raise e raise e
@ -359,6 +364,7 @@ class RedisCache(BaseCache):
self.service_logger_obj.async_service_success_hook( self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, service=ServiceTypes.REDIS,
duration=_duration, duration=_duration,
call_type="async_increment",
) )
) )
return result return result
@ -368,7 +374,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_failure_hook( self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_increment",
) )
) )
verbose_logger.error( verbose_logger.error(
@ -497,7 +506,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_success_hook( self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_batch_get_cache",
) )
) )
@ -519,7 +530,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time _duration = end_time - start_time
asyncio.create_task( asyncio.create_task(
self.service_logger_obj.async_service_failure_hook( self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_batch_get_cache",
) )
) )
print_verbose(f"Error occurred in pipeline read - {str(e)}") print_verbose(f"Error occurred in pipeline read - {str(e)}")

View file

@ -30,6 +30,7 @@ class PrometheusServicesLogger:
raise Exception( raise Exception(
"Missing prometheus_client. Run `pip install prometheus-client`" "Missing prometheus_client. Run `pip install prometheus-client`"
) )
print("INITIALIZES PROMETHEUS SERVICE LOGGER!")
self.Histogram = Histogram self.Histogram = Histogram
self.Counter = Counter self.Counter = Counter
@ -151,6 +152,7 @@ class PrometheusServicesLogger:
if self.mock_testing: if self.mock_testing:
self.mock_testing_success_calls += 1 self.mock_testing_success_calls += 1
print(f"LOGS SUCCESSFUL CALL TO PROMETHEUS - payload={payload}")
if payload.service.value in self.payload_to_prometheus_map: if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value] prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects: for obj in prom_objects:

View file

@ -31,12 +31,12 @@ litellm_settings:
upperbound_key_generate_params: upperbound_key_generate_params:
max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
router_settings: # router_settings:
routing_strategy: usage-based-routing-v2 # routing_strategy: usage-based-routing-v2
redis_host: os.environ/REDIS_HOST # redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD # redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT # redis_port: os.environ/REDIS_PORT
enable_pre_call_checks: True # enable_pre_call_checks: True
general_settings: general_settings:
master_key: sk-1234 master_key: sk-1234

View file

@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import ( from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler, _PROXY_MaxParallelRequestsHandler,
) )
from litellm._service_logger import ServiceLogging
from litellm import ModelResponse, EmbeddingResponse, ImageResponse from litellm import ModelResponse, EmbeddingResponse, ImageResponse
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
@ -80,10 +81,12 @@ class ProxyLogging:
def _init_litellm_callbacks(self): def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!") print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
self.service_logging_obj = ServiceLogging()
litellm.callbacks.append(self.max_parallel_request_limiter) litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_tpm_rpm_limiter) litellm.callbacks.append(self.max_tpm_rpm_limiter)
litellm.callbacks.append(self.max_budget_limiter) litellm.callbacks.append(self.max_budget_limiter)
litellm.callbacks.append(self.cache_control_check) litellm.callbacks.append(self.cache_control_check)
litellm.callbacks.append(self.service_logging_obj)
litellm.success_callback.append(self.response_taking_too_long_callback) litellm.success_callback.append(self.response_taking_too_long_callback)
for callback in litellm.callbacks: for callback in litellm.callbacks:
if callback not in litellm.input_callback: if callback not in litellm.input_callback:

View file

@ -5,11 +5,12 @@ from typing import Optional
class ServiceTypes(enum.Enum): class ServiceTypes(enum.Enum):
""" """
Enum for litellm-adjacent services (redis/postgres/etc.) Enum for litellm + litellm-adjacent services (redis/postgres/etc.)
""" """
REDIS = "redis" REDIS = "redis"
DB = "postgres" DB = "postgres"
LITELLM = "self"
class ServiceLoggerPayload(BaseModel): class ServiceLoggerPayload(BaseModel):
@ -21,6 +22,7 @@ class ServiceLoggerPayload(BaseModel):
error: Optional[str] = Field(None, description="what was the error") error: Optional[str] = Field(None, description="what was the error")
service: ServiceTypes = Field(description="who is this for? - postgres/redis") service: ServiceTypes = Field(description="who is this for? - postgres/redis")
duration: float = Field(description="How long did the request take?") duration: float = Field(description="How long did the request take?")
call_type: str = Field(description="The call of the service, being made")
def to_json(self, **kwargs): def to_json(self, **kwargs):
try: try: