diff --git a/litellm/__init__.py b/litellm/__init__.py
index 091605148..055dd7424 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -19,6 +19,7 @@ if set_verbose == True:
 input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
+service_callback: List[Union[str, Callable]] = []
 callbacks: List[Callable] = []
 _async_input_callback: List[Callable] = (
     []
diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py
new file mode 100644
index 000000000..814ec011a
--- /dev/null
+++ b/litellm/_service_logger.py
@@ -0,0 +1,71 @@
+import litellm
+from .types.services import ServiceTypes, ServiceLoggerPayload
+from .integrations.prometheus_services import PrometheusServicesLogger
+
+
+class ServiceLogging:
+    """
+    Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
+    """
+
+    def __init__(self, mock_testing: bool = False) -> None:
+        self.mock_testing = mock_testing
+        self.mock_testing_sync_success_hook = 0
+        self.mock_testing_async_success_hook = 0
+        self.mock_testing_sync_failure_hook = 0
+        self.mock_testing_async_failure_hook = 0
+
+        if "prometheus_system" in litellm.service_callback:
+            self.prometheusServicesLogger = PrometheusServicesLogger()
+
+    def service_success_hook(self, service: ServiceTypes, duration: float):
+        """
+        [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
+        """
+        if self.mock_testing:
+            self.mock_testing_sync_success_hook += 1
+
+    def service_failure_hook(
+        self, service: ServiceTypes, duration: float, error: Exception
+    ):
+        """
+        [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
+        """
+        if self.mock_testing:
+            self.mock_testing_sync_failure_hook += 1
+
+    async def async_service_success_hook(self, service: ServiceTypes, duration: float):
+        """
+        - For counting whether the redis/postgres call was successful
+        """
+        if self.mock_testing:
+            self.mock_testing_async_success_hook += 1
+
+        payload = ServiceLoggerPayload(
+            is_error=False, error=None, service=service, duration=duration
+        )
+        for callback in litellm.service_callback:
+            if callback == "prometheus_system":
+                await self.prometheusServicesLogger.async_service_success_hook(
+                    payload=payload
+                )
+
+    async def async_service_failure_hook(
+        self, service: ServiceTypes, duration: float, error: Exception
+    ):
+        """
+        - For counting whether the redis/postgres call was unsuccessful
+        """
+        if self.mock_testing:
+            self.mock_testing_async_failure_hook += 1
+
+        payload = ServiceLoggerPayload(
+            is_error=True, error=str(error), service=service, duration=duration
+        )
+        for callback in litellm.service_callback:
+            if callback == "prometheus_system":
+                if getattr(self, "prometheusServicesLogger", None) is None:
+                    # lazy-init if the callback was registered after this object was built
+                    self.prometheusServicesLogger = PrometheusServicesLogger()
+                await self.prometheusServicesLogger.async_service_failure_hook(
+                    payload=payload
+                )
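
The hooks above are meant to be fired from call sites that time a service operation themselves. A minimal usage sketch, assuming only the names introduced in this patch (the `timed_redis_call` wrapper is illustrative, not part of the PR):

    import asyncio
    import time

    import litellm
    from litellm._service_logger import ServiceLogging
    from litellm.types.services import ServiceTypes


    async def timed_redis_call():
        litellm.service_callback = []  # no prometheus backend needed for this sketch
        service_logging = ServiceLogging(mock_testing=True)
        start_time = time.time()
        try:
            await asyncio.sleep(0.01)  # stand-in for a redis/postgres operation
            await service_logging.async_service_success_hook(
                service=ServiceTypes.REDIS, duration=time.time() - start_time
            )
        except Exception as e:
            await service_logging.async_service_failure_hook(
                service=ServiceTypes.REDIS, duration=time.time() - start_time, error=e
            )
            raise


    asyncio.run(timed_redis_call())
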
diff --git a/litellm/caching.py b/litellm/caching.py
index cdb98d790..dfc99367b 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -13,6 +13,8 @@ import json, traceback, ast, hashlib
 from typing import Optional, Literal, List, Union, Any, BinaryIO
 from openai._models import BaseModel as OpenAIObject
 from litellm._logging import verbose_logger
+from litellm._service_logger import ServiceLogging
+from litellm.types.services import ServiceLoggerPayload, ServiceTypes
 import traceback


@@ -163,6 +165,9 @@ class RedisCache(BaseCache):
         except Exception as e:
             pass

+        ### HEALTH MONITORING OBJECT ###
+        self.service_logger_obj = ServiceLogging()
+
     def init_async_client(self):
         from ._redis import get_redis_async_client
@@ -194,17 +199,59 @@ class RedisCache(BaseCache):
         )

     async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
-        keys = []
-        _redis_client = self.init_async_client()
-        async with _redis_client as redis_client:
-            async for key in redis_client.scan_iter(match=pattern + "*", count=count):
-                keys.append(key)
-                if len(keys) >= count:
-                    break
-        return keys
+        start_time = time.time()
+        try:
+            keys = []
+            _redis_client = self.init_async_client()
+            async with _redis_client as redis_client:
+                async for key in redis_client.scan_iter(
+                    match=pattern + "*", count=count
+                ):
+                    keys.append(key)
+                    if len(keys) >= count:
+                        break
+
+            ## LOGGING ##
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_success_hook(
+                    service=ServiceTypes.REDIS, duration=_duration
+                )
+            )  # DO NOT SLOW DOWN CALL B/C OF THIS
+            return keys
+        except Exception as e:
+            # NON blocking - notify users Redis is throwing an exception
+            ## LOGGING ##
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_failure_hook(
+                    service=ServiceTypes.REDIS, duration=_duration, error=e
+                )
+            )
+            raise e

     async def async_set_cache(self, key, value, **kwargs):
-        _redis_client = self.init_async_client()
+        start_time = time.time()
+        try:
+            _redis_client = self.init_async_client()
+        except Exception as e:
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_failure_hook(
+                    service=ServiceTypes.REDIS, duration=_duration, error=e
+                )
+            )
+            # NON blocking - notify users Redis is throwing an exception
+            verbose_logger.error(
+                "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
+                str(e),
+                value,
+            )
+            traceback.print_exc()
+            raise e  # client is unusable below - surface the error instead of a NameError
+        key = self.check_and_fix_namespace(key=key)
         async with _redis_client as redis_client:
             ttl = kwargs.get("ttl", None)
@@ -216,7 +263,21 @@
                 print_verbose(
                     f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
                 )
+                end_time = time.time()
+                _duration = end_time - start_time
+                asyncio.create_task(
+                    self.service_logger_obj.async_service_success_hook(
+                        service=ServiceTypes.REDIS, duration=_duration
+                    )
+                )
         except Exception as e:
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_failure_hook(
+                    service=ServiceTypes.REDIS, duration=_duration, error=e
+                )
+            )
             # NON blocking - notify users Redis is throwing an exception
             verbose_logger.error(
                 "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
@@ -230,6 +291,7 @@ class RedisCache(BaseCache):
         Use Redis Pipelines for bulk write operations
         """
         _redis_client = self.init_async_client()
+        start_time = time.time()
         try:
             async with _redis_client as redis_client:
                 async with redis_client.pipeline(transaction=True) as pipe:
@@ -249,8 +311,25 @@ class RedisCache(BaseCache):
                 print_verbose(f"pipeline results: {results}")
                 # Optionally, you could process 'results' to make sure that all set operations were successful.
+                ## LOGGING ##
+                end_time = time.time()
+                _duration = end_time - start_time
+                asyncio.create_task(
+                    self.service_logger_obj.async_service_success_hook(
+                        service=ServiceTypes.REDIS, duration=_duration
+                    )
+                )
                 return results
         except Exception as e:
+            ## LOGGING ##
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_failure_hook(
+                    service=ServiceTypes.REDIS, duration=_duration, error=e
+                )
+            )
+
             verbose_logger.error(
                 "LiteLLM Redis Caching: async set_cache_pipeline() - Got exception from REDIS %s, Writing value=%s",
                 str(e),
@@ -265,15 +344,33 @@ class RedisCache(BaseCache):
         key = self.check_and_fix_namespace(key=key)
         self.redis_batch_writing_buffer.append((key, value))
         if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
-            await self.flush_cache_buffer()
+            await self.flush_cache_buffer()  # logging done in here

     async def async_increment(self, key, value: int, **kwargs) -> int:
         _redis_client = self.init_async_client()
+        start_time = time.time()
         try:
             async with _redis_client as redis_client:
                 result = await redis_client.incr(name=key, amount=value)
+                ## LOGGING ##
+                end_time = time.time()
+                _duration = end_time - start_time
+                asyncio.create_task(
+                    self.service_logger_obj.async_service_success_hook(
+                        service=ServiceTypes.REDIS,
+                        duration=_duration,
+                    )
+                )
                 return result
         except Exception as e:
+            ## LOGGING ##
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_failure_hook(
+                    service=ServiceTypes.REDIS, duration=_duration, error=e
+                )
+            )
             verbose_logger.error(
                 "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
                 str(e),
@@ -348,6 +445,7 @@ class RedisCache(BaseCache):
     async def async_get_cache(self, key, **kwargs):
         _redis_client = self.init_async_client()
         key = self.check_and_fix_namespace(key=key)
+        start_time = time.time()
         async with _redis_client as redis_client:
             try:
                 print_verbose(f"Get Async Redis Cache: key: {key}")
@@ -356,8 +454,24 @@
                     f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
                 )
                 response = self._get_cache_logic(cached_response=cached_response)
+                ## LOGGING ##
+                end_time = time.time()
+                _duration = end_time - start_time
+                asyncio.create_task(
+                    self.service_logger_obj.async_service_success_hook(
+                        service=ServiceTypes.REDIS, duration=_duration
+                    )
+                )
                 return response
             except Exception as e:
+                ## LOGGING ##
+                end_time = time.time()
+                _duration = end_time - start_time
+                asyncio.create_task(
+                    self.service_logger_obj.async_service_failure_hook(
+                        service=ServiceTypes.REDIS, duration=_duration, error=e
+                    )
+                )
                 # NON blocking - notify users Redis is throwing an exception
                 print_verbose(
                     f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
@@ -369,6 +483,7 @@
         """
         _redis_client = await self.init_async_client()
         key_value_dict = {}
+        start_time = time.time()
         try:
             async with _redis_client as redis_client:
                 _keys = []
@@ -377,6 +492,15 @@
                     _keys.append(cache_key)
                 results = await redis_client.mget(keys=_keys)

+                ## LOGGING ##
+                end_time = time.time()
+                _duration = end_time - start_time
+                asyncio.create_task(
+                    self.service_logger_obj.async_service_success_hook(
+                        service=ServiceTypes.REDIS, duration=_duration
+                    )
+                )
+
                 # Associate the results back with their keys.
                 # 'results' is a list of values corresponding to the order of keys in 'key_list'.
                 key_value_dict = dict(zip(key_list, results))
@@ -388,6 +512,14 @@
                 return decoded_results
         except Exception as e:
+            ## LOGGING ##
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_failure_hook(
+                    service=ServiceTypes.REDIS, duration=_duration, error=e
+                )
+            )
             print_verbose(f"Error occurred in pipeline read - {str(e)}")
             return key_value_dict
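
Every `RedisCache` method above inlines the same time-then-fire-and-forget block. A hypothetical decorator that captures the pattern (a sketch only; the PR deliberately inlines it, which keeps per-method control over re-raising):

    import asyncio
    import functools
    import time

    from litellm.types.services import ServiceTypes


    def monitor_redis(func):
        """Time an async RedisCache method and report the outcome without blocking."""

        @functools.wraps(func)
        async def wrapper(self, *args, **kwargs):
            start_time = time.time()
            try:
                result = await func(self, *args, **kwargs)
                # fire-and-forget so logging never slows the cache call
                asyncio.create_task(
                    self.service_logger_obj.async_service_success_hook(
                        service=ServiceTypes.REDIS, duration=time.time() - start_time
                    )
                )
                return result
            except Exception as e:
                asyncio.create_task(
                    self.service_logger_obj.async_service_failure_hook(
                        service=ServiceTypes.REDIS,
                        duration=time.time() - start_time,
                        error=e,
                    )
                )
                raise

        return wrapper
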
diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py
index 28e84e6f8..d5f10262f 100644
--- a/litellm/integrations/prometheus.py
+++ b/litellm/integrations/prometheus.py
@@ -1,6 +1,6 @@
 # used for /metrics endpoint on LiteLLM Proxy
 #### What this does ####
-# On success + failure, log events to Supabase
+# On success, log events to Prometheus
 import dotenv, os
 import requests
diff --git a/litellm/integrations/prometheus_services.py b/litellm/integrations/prometheus_services.py
new file mode 100644
index 000000000..a26271987
--- /dev/null
+++ b/litellm/integrations/prometheus_services.py
@@ -0,0 +1,139 @@
+# used for monitoring litellm services health on `/metrics` endpoint on LiteLLM Proxy
+#### What this does ####
+# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
+
+
+import dotenv, os
+import requests
+
+dotenv.load_dotenv()  # Loading env variables using dotenv
+import traceback
+import datetime, subprocess, sys
+import litellm, uuid
+from litellm._logging import print_verbose, verbose_logger
+from litellm.types.services import ServiceLoggerPayload, ServiceTypes
+
+
+class PrometheusServicesLogger:
+    # Class variables or attributes
+    litellm_service_latency = None  # Class-level attribute to store the Histogram
+
+    def __init__(
+        self,
+        mock_testing: bool = False,
+        **kwargs,
+    ):
+        try:
+            from prometheus_client import Counter, Histogram, REGISTRY
+
+            self.Histogram = Histogram
+            self.Counter = Counter
+            self.REGISTRY = REGISTRY
+
+            verbose_logger.debug(f"in init prometheus services metrics")
+
+            self.services = [item.value for item in ServiceTypes]
+
+            self.payload_to_prometheus_map = (
+                {}
+            )  # store the prometheus histogram/counter we need to call for each field in payload
+
+            for service in self.services:
+                histogram = self.create_histogram(service)
+                self.payload_to_prometheus_map[service] = histogram
+
+            self.prometheus_to_amount_map: dict = (
+                {}
+            )  # the field / value in ServiceLoggerPayload the object needs to be incremented by
+
+            # self.payload_to_prometheus_map["service"] = [self.litellm_service_latency]
+            # self.prometheus_to_amount_map[self.litellm_service_latency._name] = (
+            #     "duration"
+            # )
+
+            ### MOCK TESTING ###
+            self.mock_testing = mock_testing
+            self.mock_testing_success_calls = 0
+            self.mock_testing_failure_calls = 0
+
+        except Exception as e:
+            print_verbose(f"Got exception on init prometheus client {str(e)}")
+            raise e
+
+    def is_metric_registered(self, metric_name) -> bool:
+        for metric in self.REGISTRY.collect():
+            print(f"metric name: {metric.name}")
+            if metric_name == metric.name:
+                return True
+        return False
+
+    def get_metric(self, metric_name):
+        for metric in self.REGISTRY.collect():
+            for sample in metric.samples:
+                if metric_name == sample.name:
+                    return metric
+        return None
+
+    def create_histogram(self, label: str):
+        metric_name = "litellm_{}_latency".format(label)
+        is_registered = self.is_metric_registered(metric_name)
+        if is_registered:
+            return self.get_metric(metric_name)
+        return self.Histogram(
+            metric_name,
+            "Latency for {} service".format(label),
+            labelnames=[label],
+        )
+
+    def observe_histogram(
+        self,
+        histogram,
+        labels: str,
+        amount: float,
+    ):
+        assert isinstance(histogram, self.Histogram)
+
+        histogram.labels(labels).observe(amount)
+
+    def increment_counter(
+        self,
+        counter,
+        labels: list,
+        amount: float,
+    ):
+        assert isinstance(counter, self.Counter)
+
+        counter.labels(labels).inc(amount)
+
+    def service_success_hook(self, payload: ServiceLoggerPayload):
+        if self.mock_testing:
+            self.mock_testing_success_calls += 1
+
+        if payload.service.value in self.payload_to_prometheus_map:
+            self.observe_histogram(
+                histogram=self.payload_to_prometheus_map[payload.service.value],
+                labels=payload.service.value,
+                amount=payload.duration,
+            )
+
+    def service_failure_hook(self, payload: ServiceLoggerPayload):
+        if self.mock_testing:
+            self.mock_testing_failure_calls += 1
+
+    async def async_service_success_hook(self, payload: ServiceLoggerPayload):
+        """
+        Log successful call to prometheus
+        """
+        if self.mock_testing:
+            self.mock_testing_success_calls += 1
+
+        if payload.service.value in self.payload_to_prometheus_map:
+            self.observe_histogram(
+                histogram=self.payload_to_prometheus_map[payload.service.value],
+                labels=payload.service.value,
+                amount=payload.duration,
+            )
+
+    async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
+        if self.mock_testing:
+            self.mock_testing_failure_calls += 1
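
`PrometheusServicesLogger.__init__` registers one `litellm_<service>_latency` histogram per `ServiceTypes` member (`litellm_redis_latency`, `litellm_postgres_latency`). A quick way to inspect what reaches the default registry, assuming `prometheus_client` is installed (output shape is indicative):

    from prometheus_client import generate_latest

    from litellm.integrations.prometheus_services import PrometheusServicesLogger
    from litellm.types.services import ServiceLoggerPayload, ServiceTypes

    pl = PrometheusServicesLogger()
    pl.service_success_hook(
        payload=ServiceLoggerPayload(
            is_error=False, error=None, service=ServiceTypes.REDIS, duration=0.05
        )
    )
    # dumps litellm_redis_latency_* histogram samples, i.e. what /metrics would serve
    print(generate_latest().decode())
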
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 59499fc99..9b920224e 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -25,6 +25,7 @@ model_list:

 litellm_settings:
   success_callback: ["prometheus"]
+  service_callback: ["prometheus_system"]
   upperbound_key_generate_params:
     max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
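
Outside the proxy, the same callback can be enabled programmatically; this mirrors the setup in the tests below and assumes a reachable Redis plus the usual REDIS_HOST / REDIS_PORT / REDIS_PASSWORD env vars:

    import litellm
    from litellm import Cache

    litellm.service_callback = ["prometheus_system"]
    litellm.cache = Cache(type="redis")  # Redis connection details come from env vars
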
diff --git a/litellm/tests/test_prometheus_service.py b/litellm/tests/test_prometheus_service.py
new file mode 100644
index 000000000..2cb505fd9
--- /dev/null
+++ b/litellm/tests/test_prometheus_service.py
@@ -0,0 +1,199 @@
+# What is this?
+## Unit Tests for prometheus service monitoring
+
+import json
+import sys
+import os
+import io, asyncio
+
+sys.path.insert(0, os.path.abspath("../.."))
+import pytest
+from litellm import acompletion, Cache
+from litellm._service_logger import ServiceLogging
+from litellm.integrations.prometheus_services import PrometheusServicesLogger
+import litellm
+
+"""
+- Check if it receives a call when redis is used
+- Check if it fires messages accordingly
+"""
+
+
+@pytest.mark.asyncio
+async def test_init_prometheus():
+    """
+    - Assert the prometheus services logger initializes without raising
+    """
+
+    pl = PrometheusServicesLogger(mock_testing=True)
+
+
+@pytest.mark.asyncio
+async def test_completion_with_caching():
+    """
+    - Run completion with caching
+    - Assert success callback gets called
+    """
+
+    litellm.set_verbose = True
+    litellm.cache = Cache(type="redis")
+    litellm.service_callback = ["prometheus_system"]
+
+    sl = ServiceLogging(mock_testing=True)
+    sl.prometheusServicesLogger.mock_testing = True
+    litellm.cache.cache.service_logger_obj = sl
+
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    response1 = await acompletion(
+        model="gpt-3.5-turbo", messages=messages, caching=True
+    )
+    response1 = await acompletion(
+        model="gpt-3.5-turbo", messages=messages, caching=True
+    )
+
+    assert sl.mock_testing_async_success_hook > 0
+    assert sl.prometheusServicesLogger.mock_testing_success_calls > 0
+
+
+@pytest.mark.asyncio
+async def test_completion_with_caching_bad_call():
+    """
+    - Run completion with caching (incorrect credentials)
+    - Assert failure callback gets called
+    """
+    litellm.set_verbose = True
+    sl = ServiceLogging(mock_testing=True)
+    try:
+        litellm.cache = Cache(type="redis", host="hello-world")
+        litellm.service_callback = ["prometheus_system"]
+
+        litellm.cache.cache.service_logger_obj = sl
+
+        messages = [{"role": "user", "content": "Hey, how's it going?"}]
+        response1 = await acompletion(
+            model="gpt-3.5-turbo", messages=messages, caching=True
+        )
+        response1 = await acompletion(
+            model="gpt-3.5-turbo", messages=messages, caching=True
+        )
+    except Exception as e:
+        pass
+
+    assert sl.mock_testing_async_failure_hook > 0
+
+
+@pytest.mark.asyncio
+async def test_router_with_caching():
+    """
+    - Run router with usage-based-routing-v2
+    - Assert success callback gets called
+    """
+    try:
+
+        def get_azure_params(deployment_name: str):
+            params = {
+                "model": f"azure/{deployment_name}",
+                "api_key": os.environ["AZURE_API_KEY"],
+                "api_version": os.environ["AZURE_API_VERSION"],
+                "api_base": os.environ["AZURE_API_BASE"],
+            }
+            return params
+
+        model_list = [
+            {
+                "model_name": "azure/gpt-4",
+                "litellm_params": get_azure_params("chatgpt-v-2"),
+                "tpm": 100,
+            },
+            {
+                "model_name": "azure/gpt-4",
+                "litellm_params": get_azure_params("chatgpt-v-2"),
+                "tpm": 1000,
+            },
+        ]
+
+        router = litellm.Router(
+            model_list=model_list,
+            set_verbose=True,
+            debug_level="DEBUG",
+            routing_strategy="usage-based-routing-v2",
+            redis_host=os.environ["REDIS_HOST"],
+            redis_port=os.environ["REDIS_PORT"],
+            redis_password=os.environ["REDIS_PASSWORD"],
+        )
+
+        litellm.service_callback = ["prometheus_system"]
+
+        sl = ServiceLogging(mock_testing=True)
+        sl.prometheusServicesLogger.mock_testing = True
+        router.cache.redis_cache.service_logger_obj = sl
+
+        messages = [{"role": "user", "content": "Hey, how's it going?"}]
+        response1 = await router.acompletion(model="azure/gpt-4", messages=messages)
+        response1 = await router.acompletion(model="azure/gpt-4", messages=messages)
+
+        assert sl.mock_testing_async_success_hook > 0
+        assert sl.prometheusServicesLogger.mock_testing_success_calls > 0
+
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {str(e)}")
+
+
+@pytest.mark.asyncio
+async def test_router_with_caching_bad_call():
+    """
+    - Run router with caching (incorrect credentials)
+    - Assert failure callback gets called
+    """
+    try:
+
+        def get_azure_params(deployment_name: str):
+            params = {
+                "model": f"azure/{deployment_name}",
+                "api_key": os.environ["AZURE_API_KEY"],
+                "api_version": os.environ["AZURE_API_VERSION"],
+                "api_base": os.environ["AZURE_API_BASE"],
+            }
+            return params
+
+        model_list = [
+            {
+                "model_name": "azure/gpt-4",
+                "litellm_params": get_azure_params("chatgpt-v-2"),
+                "tpm": 100,
+            },
+            {
+                "model_name": "azure/gpt-4",
+                "litellm_params": get_azure_params("chatgpt-v-2"),
+                "tpm": 1000,
+            },
+        ]
+
+        router = litellm.Router(
+            model_list=model_list,
+            set_verbose=True,
+            debug_level="DEBUG",
+            routing_strategy="usage-based-routing-v2",
+            redis_host="hello world",
+            redis_port=os.environ["REDIS_PORT"],
+            redis_password=os.environ["REDIS_PASSWORD"],
+        )
+
+        litellm.service_callback = ["prometheus_system"]
+
+        sl = ServiceLogging(mock_testing=True)
+        sl.prometheusServicesLogger.mock_testing = True
+        router.cache.redis_cache.service_logger_obj = sl
+
+        messages = [{"role": "user", "content": "Hey, how's it going?"}]
+        try:
+            response1 = await router.acompletion(model="azure/gpt-4", messages=messages)
+            response1 = await router.acompletion(model="azure/gpt-4", messages=messages)
+        except Exception as e:
+            pass
+
+        assert sl.mock_testing_async_failure_hook > 0
+
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {str(e)}")
diff --git a/litellm/types/services.py b/litellm/types/services.py
new file mode 100644
index 000000000..ea5172ebc
--- /dev/null
+++ b/litellm/types/services.py
@@ -0,0 +1,30 @@
+import uuid, enum
+from pydantic import BaseModel, Field
+from typing import Optional
+
+
+class ServiceTypes(enum.Enum):
+    """
+    Enum for litellm-adjacent services (redis/postgres/etc.)
+    """
+
+    REDIS = "redis"
+    DB = "postgres"
+
+
+class ServiceLoggerPayload(BaseModel):
+    """
+    The payload logged during service success/failure
+    """
+
+    is_error: bool = Field(description="did an error occur")
+    error: Optional[str] = Field(None, description="what was the error")
+    service: ServiceTypes = Field(description="who is this for? - postgres/redis")
- postgres/redis") + duration: float = Field(description="How long did the request take?") + + def to_json(self, **kwargs): + try: + return self.model_dump(**kwargs) # noqa + except Exception as e: + # if using pydantic v1 + return self.dict(**kwargs) diff --git a/litellm/utils.py b/litellm/utils.py index d1133affb..b94c22bc9 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -29,7 +29,9 @@ from tokenizers import Tokenizer from dataclasses import ( dataclass, field, -) # for storing API inputs, outputs, and metadata +) + +import litellm._service_logger # for storing API inputs, outputs, and metadata try: # this works in python 3.8 @@ -69,6 +71,7 @@ from .integrations.custom_logger import CustomLogger from .integrations.langfuse import LangFuseLogger from .integrations.datadog import DataDogLogger from .integrations.prometheus import PrometheusLogger +from .integrations.prometheus_services import PrometheusServicesLogger from .integrations.dynamodb import DyanmoDBLogger from .integrations.s3 import S3Logger from .integrations.clickhouse import ClickhouseLogger @@ -6564,7 +6567,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k for detail in additional_details: slack_msg += f"{detail}: {additional_details[detail]}\n" slack_msg += f"Traceback: {traceback_exception}" - truncated_slack_msg = textwrap.shorten(slack_msg, width=512, placeholder="...") + truncated_slack_msg = textwrap.shorten( + slack_msg, width=512, placeholder="..." + ) slack_app.client.chat_postMessage( channel=alerts_channel, text=truncated_slack_msg )