feat(prometheus_services.py): monitor health of proxy adjacent services (redis / postgres / etc.)

This commit is contained in:
Krrish Dholakia 2024-04-13 18:15:02 -07:00
parent a06a0e7b81
commit 4e81acf2c6
9 changed files with 591 additions and 13 deletions

View file

@ -19,6 +19,7 @@ if set_verbose == True:
# Module-level callback registries; each entry is either a callback name (str) or a callable.
input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
# Callbacks used for monitoring litellm-adjacent services (redis / postgres);
# ServiceLogging checks this list for the name "prometheus_system".
service_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_async_input_callback: List[Callable] = (
[]

View file

@ -0,0 +1,71 @@
import litellm
from .types.services import ServiceTypes, ServiceLoggerPayload
from .integrations.prometheus_services import PrometheusServicesLogger
class ServiceLogging:
    """
    Separate class used for monitoring health of litellm-adjacent services (redis/postgres).

    The async hooks build a `ServiceLoggerPayload` and forward it to the integration
    named in `litellm.service_callback` (only "prometheus_system" is handled here).
    When `mock_testing` is True, every hook additionally bumps its own call counter
    so tests can assert that the hook was reached.
    """

    def __init__(self, mock_testing: bool = False) -> None:
        self.mock_testing = mock_testing
        self.mock_testing_sync_success_hook = 0
        self.mock_testing_async_success_hook = 0
        self.mock_testing_sync_failure_hook = 0
        self.mock_testing_async_failure_hook = 0

        # Always define the attribute so the hooks can safely test it against None,
        # even when "prometheus_system" is only registered after construction.
        self.prometheusServicesLogger = None
        if "prometheus_system" in litellm.service_callback:
            self.prometheusServicesLogger = PrometheusServicesLogger()

    def _get_prometheus_logger(self):
        """Return the prometheus services logger, creating it on first use."""
        if getattr(self, "prometheusServicesLogger", None) is None:
            self.prometheusServicesLogger = PrometheusServicesLogger()
        return self.prometheusServicesLogger

    def service_success_hook(self, service: ServiceTypes, duration: float):
        """
        [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).

        Only counts the call when mock testing is enabled.
        """
        if self.mock_testing:
            self.mock_testing_sync_success_hook += 1

    def service_failure_hook(
        self, service: ServiceTypes, duration: float, error: Exception
    ):
        """
        [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).

        Only counts the call when mock testing is enabled.
        """
        if self.mock_testing:
            self.mock_testing_sync_failure_hook += 1

    async def async_service_success_hook(self, service: ServiceTypes, duration: float):
        """
        - For counting if the redis, postgres call is successful

        Args:
            service: the adjacent service the call was made against.
            duration: how long the call took.
        """
        if self.mock_testing:
            self.mock_testing_async_success_hook += 1

        payload = ServiceLoggerPayload(
            is_error=False, error=None, service=service, duration=duration
        )
        for callback in litellm.service_callback:
            if callback == "prometheus_system":
                await self._get_prometheus_logger().async_service_success_hook(
                    payload=payload
                )

    async def async_service_failure_hook(
        self, service: ServiceTypes, duration: float, error: Exception
    ):
        """
        - For counting if the redis, postgres call is unsuccessful

        Args:
            service: the adjacent service the call was made against.
            duration: how long the call took before failing.
            error: the exception raised by the call; stored on the payload as `str(error)`.
        """
        if self.mock_testing:
            self.mock_testing_async_failure_hook += 1

        payload = ServiceLoggerPayload(
            is_error=True, error=str(error), service=service, duration=duration
        )
        for callback in litellm.service_callback:
            if callback == "prometheus_system":
                # Lazily build the logger: the callback may have been registered
                # after this object was constructed.
                await self._get_prometheus_logger().async_service_failure_hook(
                    payload=payload
                )

View file

@ -13,6 +13,8 @@ import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any, BinaryIO
from openai._models import BaseModel as OpenAIObject
from litellm._logging import verbose_logger
from litellm._service_logger import ServiceLogging
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
import traceback
@ -163,6 +165,9 @@ class RedisCache(BaseCache):
except Exception as e:
pass
### HEALTH MONITORING OBJECT ###
self.service_logger_obj = ServiceLogging()
def init_async_client(self):
from ._redis import get_redis_async_client
@ -194,17 +199,59 @@ class RedisCache(BaseCache):
)
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
    """
    Collect keys matching `pattern*` from redis and report the call to the service logger.

    Args:
        pattern: key prefix; a trailing "*" is appended before scanning.
        count: passed to `scan_iter` as its `count` and also used as the
            maximum number of keys collected before the scan stops.

    Returns:
        The list of keys collected.

    Raises:
        Re-raises any exception from the redis call after scheduling the failure hook.
    """
    start_time = time.time()
    try:
        keys = []
        _redis_client = self.init_async_client()
        async with _redis_client as redis_client:
            async for key in redis_client.scan_iter(
                match=pattern + "*", count=count
            ):
                keys.append(key)
                if len(keys) >= count:
                    break

        ## LOGGING ##
        end_time = time.time()
        _duration = end_time - start_time
        asyncio.create_task(
            self.service_logger_obj.async_service_success_hook(
                service=ServiceTypes.REDIS, duration=_duration
            )
        )  # DO NOT SLOW DOWN CALL B/C OF THIS
        return keys
    except Exception as e:
        # NON blocking - notify users Redis is throwing an exception
        ## LOGGING ##
        end_time = time.time()
        _duration = end_time - start_time
        asyncio.create_task(
            self.service_logger_obj.async_service_failure_hook(
                service=ServiceTypes.REDIS, duration=_duration, error=e
            )
        )
        # bare raise keeps the original traceback intact
        raise
async def async_set_cache(self, key, value, **kwargs):
_redis_client = self.init_async_client()
start_time = time.time()
try:
_redis_client = self.init_async_client()
except Exception as e:
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
)
)
# NON blocking - notify users Redis is throwing an exception
verbose_logger.error(
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client:
ttl = kwargs.get("ttl", None)
@ -216,7 +263,21 @@ class RedisCache(BaseCache):
print_verbose(
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
)
)
except Exception as e:
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
)
)
# NON blocking - notify users Redis is throwing an exception
verbose_logger.error(
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
@ -230,6 +291,7 @@ class RedisCache(BaseCache):
Use Redis Pipelines for bulk write operations
"""
_redis_client = self.init_async_client()
start_time = time.time()
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
@ -249,8 +311,25 @@ class RedisCache(BaseCache):
print_verbose(f"pipeline results: {results}")
# Optionally, you could process 'results' to make sure that all set operations were successful.
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
)
)
return results
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
)
)
verbose_logger.error(
"LiteLLM Redis Caching: async set_cache_pipeline() - Got exception from REDIS %s, Writing value=%s",
str(e),
@ -265,15 +344,33 @@ class RedisCache(BaseCache):
key = self.check_and_fix_namespace(key=key)
self.redis_batch_writing_buffer.append((key, value))
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer()
await self.flush_cache_buffer() # logging done in here
async def async_increment(self, key, value: int, **kwargs) -> int:
_redis_client = self.init_async_client()
start_time = time.time()
try:
async with _redis_client as redis_client:
result = await redis_client.incr(name=key, amount=value)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
)
)
return result
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
)
)
verbose_logger.error(
"LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
str(e),
@ -348,6 +445,7 @@ class RedisCache(BaseCache):
async def async_get_cache(self, key, **kwargs):
_redis_client = self.init_async_client()
key = self.check_and_fix_namespace(key=key)
start_time = time.time()
async with _redis_client as redis_client:
try:
print_verbose(f"Get Async Redis Cache: key: {key}")
@ -356,8 +454,24 @@ class RedisCache(BaseCache):
f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
)
response = self._get_cache_logic(cached_response=cached_response)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
)
)
return response
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
)
)
# NON blocking - notify users Redis is throwing an exception
print_verbose(
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
@ -369,6 +483,7 @@ class RedisCache(BaseCache):
"""
_redis_client = await self.init_async_client()
key_value_dict = {}
start_time = time.time()
try:
async with _redis_client as redis_client:
_keys = []
@ -377,6 +492,15 @@ class RedisCache(BaseCache):
_keys.append(cache_key)
results = await redis_client.mget(keys=_keys)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
)
)
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
key_value_dict = dict(zip(key_list, results))
@ -388,6 +512,14 @@ class RedisCache(BaseCache):
return decoded_results
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
)
)
print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict

View file

@ -1,6 +1,6 @@
# used for /metrics endpoint on LiteLLM Proxy
#### What this does ####
# On success + failure, log events to Supabase
# On success, log events to Prometheus
import dotenv, os
import requests

View file

@ -0,0 +1,139 @@
# used for monitoring litellm services health on `/metrics` endpoint on LiteLLM Proxy
#### What this does ####
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
import dotenv, os
import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose, verbose_logger
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
class PrometheusServicesLogger:
    """
    Logs latency of litellm-adjacent services to Prometheus.

    One histogram named `litellm_<service>_latency` is kept per `ServiceTypes`
    value. Success hooks observe the payload's `duration` on the histogram of
    the payload's service; failure hooks currently only count calls when mock
    testing is enabled.
    """

    # Class variables or attributes
    litellm_service_latency = None  # Class-level attribute to store the Histogram

    def __init__(
        self,
        mock_testing: bool = False,
        **kwargs,
    ):
        """
        Args:
            mock_testing: when True, hooks count their invocations in
                `mock_testing_success_calls` / `mock_testing_failure_calls`.
            **kwargs: accepted and ignored.

        Raises:
            Re-raises any exception hit during setup (e.g. `prometheus_client`
            not being importable) after logging it.
        """
        try:
            from prometheus_client import Counter, Histogram, REGISTRY

            self.Histogram = Histogram
            self.Counter = Counter
            self.REGISTRY = REGISTRY

            verbose_logger.debug("in init prometheus services metrics")

            self.services = [item.value for item in ServiceTypes]

            # store the prometheus histogram/counter we need to call for each field in payload
            self.payload_to_prometheus_map = {}
            for service in self.services:
                self.payload_to_prometheus_map[service] = self.create_histogram(
                    service
                )

            # the field / value in ServiceLoggerPayload the object needs to be incremented by
            self.prometheus_to_amount_map: dict = {}

            ### MOCK TESTING ###
            self.mock_testing = mock_testing
            self.mock_testing_success_calls = 0
            self.mock_testing_failure_calls = 0
        except Exception as e:
            print_verbose(f"Got exception on init prometheus client {str(e)}")
            raise e

    def is_metric_registered(self, metric_name) -> bool:
        """Return True if the registry already collects a metric named `metric_name`."""
        return any(metric_name == metric.name for metric in self.REGISTRY.collect())

    def get_metric(self, metric_name):
        """
        Return the live collector registered under `metric_name`, or None.

        `REGISTRY.collect()` only yields metric snapshots, which cannot be
        observed and do not pass the `isinstance` check in `observe_histogram`,
        so the collector is looked up by name in the registry instead.
        """
        collectors = getattr(self.REGISTRY, "_names_to_collectors", None)
        if collectors is None:
            return None
        return collectors.get(metric_name)

    def create_histogram(self, label: str):
        """
        Return the latency histogram for `label`, reusing an already registered one.

        The histogram is declared with a single label whose *name* is `label`.
        """
        metric_name = "litellm_{}_latency".format(label)
        if self.is_metric_registered(metric_name):
            # Registering the same name twice would raise; reuse the existing collector.
            return self.get_metric(metric_name)
        return self.Histogram(
            metric_name,
            "Latency for {} service".format(label),
            labelnames=[label],
        )

    def observe_histogram(
        self,
        histogram,
        labels: str,
        amount: float,
    ):
        """Record `amount` on `histogram` under the label value `labels`."""
        assert isinstance(histogram, self.Histogram)

        histogram.labels(labels).observe(amount)

    def increment_counter(
        self,
        counter,
        labels: list,
        amount: float,
    ):
        """Increase `counter` by `amount`; `labels` is passed to `counter.labels` as one argument."""
        assert isinstance(counter, self.Counter)

        counter.labels(labels).inc(amount)

    def _observe_service_latency(self, payload: ServiceLoggerPayload):
        """Observe the payload's duration on its service's histogram, if one exists."""
        service_name = payload.service.value
        if service_name in self.payload_to_prometheus_map:
            self.observe_histogram(
                histogram=self.payload_to_prometheus_map[service_name],
                labels=service_name,
                amount=payload.duration,
            )

    def service_success_hook(self, payload: ServiceLoggerPayload):
        """Log a successful (sync) service call to prometheus."""
        if self.mock_testing:
            self.mock_testing_success_calls += 1

        self._observe_service_latency(payload)

    def service_failure_hook(self, payload: ServiceLoggerPayload):
        """Count a failed (sync) service call; no prometheus metric is recorded."""
        if self.mock_testing:
            self.mock_testing_failure_calls += 1

    async def async_service_success_hook(self, payload: ServiceLoggerPayload):
        """
        Log successful call to prometheus
        """
        if self.mock_testing:
            self.mock_testing_success_calls += 1

        self._observe_service_latency(payload)

    async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
        """Count a failed (async) service call; no prometheus metric is recorded."""
        if self.mock_testing:
            self.mock_testing_failure_calls += 1

View file

@ -25,6 +25,7 @@ model_list:
litellm_settings:
success_callback: ["prometheus"]
service_callback: ["prometheus_system"]
upperbound_key_generate_params:
max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET

View file

@ -0,0 +1,199 @@
# What is this?
## Unit Tests for prometheus service monitoring
import json
import sys
import os
import io, asyncio
sys.path.insert(0, os.path.abspath("../.."))
import pytest
from litellm import acompletion, Cache
from litellm._service_logger import ServiceLogging
from litellm.integrations.prometheus_services import PrometheusServicesLogger
import litellm
"""
- Check if it receives a call when redis is used
- Check if it fires messages accordingly
"""
@pytest.mark.asyncio
async def test_init_prometheus():
    """
    - Initialize the prometheus services logger in mock-testing mode
    - Assert it starts with mock testing enabled and no recorded calls
    """
    pl = PrometheusServicesLogger(mock_testing=True)

    assert pl.mock_testing is True
    assert pl.mock_testing_success_calls == 0
    assert pl.mock_testing_failure_calls == 0
@pytest.mark.asyncio
async def test_completion_with_caching():
    """
    Two identical cached completions against redis must trigger the
    service-success hook and the prometheus success hook.
    """
    litellm.set_verbose = True
    litellm.cache = Cache(type="redis")
    litellm.service_callback = ["prometheus_system"]

    # swap in a mock-testing service logger so hook calls can be counted
    service_logger = ServiceLogging(mock_testing=True)
    service_logger.prometheusServicesLogger.mock_testing = True
    litellm.cache.cache.service_logger_obj = service_logger

    chat = [{"role": "user", "content": "Hey, how's it going?"}]
    for _ in range(2):
        await acompletion(model="gpt-3.5-turbo", messages=chat, caching=True)

    prometheus_logger = service_logger.prometheusServicesLogger
    assert service_logger.mock_testing_async_success_hook > 0
    assert prometheus_logger.mock_testing_success_calls > 0
@pytest.mark.asyncio
async def test_completion_with_caching_bad_call():
    """
    Point the redis cache at an unreachable host and check that the
    service-failure hook is reached.
    """
    litellm.set_verbose = True
    service_logger = ServiceLogging(mock_testing=True)
    try:
        litellm.cache = Cache(type="redis", host="hello-world")
        litellm.service_callback = ["prometheus_system"]
        litellm.cache.cache.service_logger_obj = service_logger

        chat = [{"role": "user", "content": "Hey, how's it going?"}]
        for _ in range(2):
            await acompletion(model="gpt-3.5-turbo", messages=chat, caching=True)
    except Exception:
        # errors are expected here; only the hook counter matters
        pass

    assert service_logger.mock_testing_async_failure_hook > 0
@pytest.mark.asyncio
async def test_router_with_caching():
    """
    Route two completions through a usage-based-routing-v2 router backed by
    redis and check that the success hooks were reached.
    """
    try:
        deployment = "chatgpt-v-2"
        # two deployments of the same model group, differing only in tpm
        model_list = [
            {
                "model_name": "azure/gpt-4",
                "litellm_params": {
                    "model": f"azure/{deployment}",
                    "api_key": os.environ["AZURE_API_KEY"],
                    "api_version": os.environ["AZURE_API_VERSION"],
                    "api_base": os.environ["AZURE_API_BASE"],
                },
                "tpm": tpm_limit,
            }
            for tpm_limit in (100, 1000)
        ]

        router = litellm.Router(
            model_list=model_list,
            set_verbose=True,
            debug_level="DEBUG",
            routing_strategy="usage-based-routing-v2",
            redis_host=os.environ["REDIS_HOST"],
            redis_port=os.environ["REDIS_PORT"],
            redis_password=os.environ["REDIS_PASSWORD"],
        )

        litellm.service_callback = ["prometheus_system"]

        # swap in a mock-testing service logger so hook calls can be counted
        service_logger = ServiceLogging(mock_testing=True)
        service_logger.prometheusServicesLogger.mock_testing = True
        router.cache.redis_cache.service_logger_obj = service_logger

        chat = [{"role": "user", "content": "Hey, how's it going?"}]
        for _ in range(2):
            await router.acompletion(model="azure/gpt-4", messages=chat)

        prometheus_logger = service_logger.prometheusServicesLogger
        assert service_logger.mock_testing_async_success_hook > 0
        assert prometheus_logger.mock_testing_success_calls > 0
    except Exception as e:
        pytest.fail(f"An exception occured - {str(e)}")
@pytest.mark.asyncio
async def test_router_with_caching_bad_call():
    """
    Build a router whose redis host is invalid and check that the
    service-failure hook is reached.
    """
    try:
        deployment = "chatgpt-v-2"
        # two deployments of the same model group, differing only in tpm
        model_list = [
            {
                "model_name": "azure/gpt-4",
                "litellm_params": {
                    "model": f"azure/{deployment}",
                    "api_key": os.environ["AZURE_API_KEY"],
                    "api_version": os.environ["AZURE_API_VERSION"],
                    "api_base": os.environ["AZURE_API_BASE"],
                },
                "tpm": tpm_limit,
            }
            for tpm_limit in (100, 1000)
        ]

        router = litellm.Router(
            model_list=model_list,
            set_verbose=True,
            debug_level="DEBUG",
            routing_strategy="usage-based-routing-v2",
            redis_host="hello world",
            redis_port=os.environ["REDIS_PORT"],
            redis_password=os.environ["REDIS_PASSWORD"],
        )

        litellm.service_callback = ["prometheus_system"]

        # swap in a mock-testing service logger so hook calls can be counted
        service_logger = ServiceLogging(mock_testing=True)
        service_logger.prometheusServicesLogger.mock_testing = True
        router.cache.redis_cache.service_logger_obj = service_logger

        chat = [{"role": "user", "content": "Hey, how's it going?"}]
        try:
            for _ in range(2):
                await router.acompletion(model="azure/gpt-4", messages=chat)
        except Exception:
            # completion errors are tolerated; only the hook counter matters
            pass

        assert service_logger.mock_testing_async_failure_hook > 0
    except Exception as e:
        pytest.fail(f"An exception occured - {str(e)}")

30
litellm/types/services.py Normal file
View file

@ -0,0 +1,30 @@
import uuid, enum
from pydantic import BaseModel, Field
from typing import Optional
class ServiceTypes(enum.Enum):
    """
    Services adjacent to litellm (redis/postgres/etc.) that can be monitored.

    Each member's value is the service name as a plain string.
    """

    REDIS = "redis"
    DB = "postgres"
class ServiceLoggerPayload(BaseModel):
    """
    The payload logged during service success/failure
    """

    # One record per monitored service call; `error` carries the failure text, if any.
    is_error: bool = Field(description="did an error occur")
    error: Optional[str] = Field(None, description="what was the error")
    service: ServiceTypes = Field(description="who is this for? - postgres/redis")
    duration: float = Field(description="How long did the request take?")

    def to_json(self, **kwargs):
        # Serialize the payload to a dict; kwargs are forwarded to the pydantic dump call.
        try:
            serialized = self.model_dump(**kwargs)  # noqa
        except Exception:
            # if using pydantic v1
            serialized = self.dict(**kwargs)
        return serialized

View file

@ -29,7 +29,9 @@ from tokenizers import Tokenizer
from dataclasses import (
dataclass,
field,
) # for storing API inputs, outputs, and metadata
)
import litellm._service_logger # for storing API inputs, outputs, and metadata
try:
# this works in python 3.8
@ -69,6 +71,7 @@ from .integrations.custom_logger import CustomLogger
from .integrations.langfuse import LangFuseLogger
from .integrations.datadog import DataDogLogger
from .integrations.prometheus import PrometheusLogger
from .integrations.prometheus_services import PrometheusServicesLogger
from .integrations.dynamodb import DyanmoDBLogger
from .integrations.s3 import S3Logger
from .integrations.clickhouse import ClickhouseLogger
@ -6564,7 +6567,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
for detail in additional_details:
slack_msg += f"{detail}: {additional_details[detail]}\n"
slack_msg += f"Traceback: {traceback_exception}"
truncated_slack_msg = textwrap.shorten(slack_msg, width=512, placeholder="...")
truncated_slack_msg = textwrap.shorten(
slack_msg, width=512, placeholder="..."
)
slack_app.client.chat_postMessage(
channel=alerts_channel, text=truncated_slack_msg
)