forked from phoenix/litellm-mirror
feat(prometheus_services.py): monitor health of proxy adjacent services (redis / postgres / etc.)
parent a06a0e7b81, commit 4e81acf2c6
9 changed files with 591 additions and 13 deletions
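At a glance: the commit adds a module-level `service_callback` list to litellm, a `ServiceLogging` dispatcher, and a `PrometheusServicesLogger` that exports per-service latency histograms on the proxy's `/metrics` endpoint. A minimal sketch of how the pieces fit together (names are taken from the diff below; the commented call site is illustrative):

import litellm
from litellm._service_logger import ServiceLogging
from litellm.types.services import ServiceTypes

litellm.service_callback = ["prometheus_system"]  # enable the prometheus service logger
service_logger = ServiceLogging()

# a caller reports the duration of a redis/postgres call, e.g.:
# await service_logger.async_service_success_hook(service=ServiceTypes.REDIS, duration=0.003)
# and on failure:
# await service_logger.async_service_failure_hook(service=ServiceTypes.REDIS, duration=0.003, error=err)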
@@ -19,6 +19,7 @@ if set_verbose == True:
input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_async_input_callback: List[Callable] = (
    []
litellm/_service_logger.py  (new file, 71 lines)
@@ -0,0 +1,71 @@
import litellm
from typing import Optional
from .types.services import ServiceTypes, ServiceLoggerPayload
from .integrations.prometheus_services import PrometheusServicesLogger


class ServiceLogging:
    """
    Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
    """

    def __init__(self, mock_testing: bool = False) -> None:
        self.mock_testing = mock_testing
        self.mock_testing_sync_success_hook = 0
        self.mock_testing_async_success_hook = 0
        self.mock_testing_sync_failure_hook = 0
        self.mock_testing_async_failure_hook = 0

        # default to None so the lazy-init check in the failure hook is safe
        self.prometheusServicesLogger: Optional[PrometheusServicesLogger] = None
        if "prometheus_system" in litellm.service_callback:
            self.prometheusServicesLogger = PrometheusServicesLogger()

    def service_success_hook(self, service: ServiceTypes, duration: float):
        """
        [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
        """
        if self.mock_testing:
            self.mock_testing_sync_success_hook += 1

    def service_failure_hook(
        self, service: ServiceTypes, duration: float, error: Exception
    ):
        """
        [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
        """
        if self.mock_testing:
            self.mock_testing_sync_failure_hook += 1

    async def async_service_success_hook(self, service: ServiceTypes, duration: float):
        """
        - For counting if the redis, postgres call is successful
        """
        if self.mock_testing:
            self.mock_testing_async_success_hook += 1

        payload = ServiceLoggerPayload(
            is_error=False, error=None, service=service, duration=duration
        )
        for callback in litellm.service_callback:
            if callback == "prometheus_system":
                await self.prometheusServicesLogger.async_service_success_hook(
                    payload=payload
                )

    async def async_service_failure_hook(
        self, service: ServiceTypes, duration: float, error: Exception
    ):
        """
        - For counting if the redis, postgres call is unsuccessful
        """
        if self.mock_testing:
            self.mock_testing_async_failure_hook += 1

        payload = ServiceLoggerPayload(
            is_error=True, error=str(error), service=service, duration=duration
        )
        for callback in litellm.service_callback:
            if callback == "prometheus_system":
                if self.prometheusServicesLogger is None:
                    # lazily instantiate the prometheus logger if it wasn't set up in __init__
                    self.prometheusServicesLogger = PrometheusServicesLogger()
                await self.prometheusServicesLogger.async_service_failure_hook(
                    payload=payload
                )
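A sketch of how callers are expected to drive these hooks; the wrapper below is illustrative (the real call sites are in the RedisCache diff that follows):

import asyncio, time
import litellm
from litellm._service_logger import ServiceLogging
from litellm.types.services import ServiceTypes

litellm.service_callback = ["prometheus_system"]
service_logger = ServiceLogging()

async def timed_redis_op():
    start_time = time.time()
    try:
        result = "..."  # stand-in for the awaited redis operation
        await service_logger.async_service_success_hook(
            service=ServiceTypes.REDIS, duration=time.time() - start_time
        )
        return result
    except Exception as e:
        await service_logger.async_service_failure_hook(
            service=ServiceTypes.REDIS, duration=time.time() - start_time, error=e
        )
        raise

asyncio.run(timed_redis_op())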
@@ -13,6 +13,8 @@ import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any, BinaryIO
from openai._models import BaseModel as OpenAIObject
from litellm._logging import verbose_logger
from litellm._service_logger import ServiceLogging
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
import traceback
@@ -163,6 +165,9 @@ class RedisCache(BaseCache):
        except Exception as e:
            pass

        ### HEALTH MONITORING OBJECT ###
        self.service_logger_obj = ServiceLogging()

    def init_async_client(self):
        from ._redis import get_redis_async_client
@@ -194,17 +199,59 @@ class RedisCache(BaseCache):
            )

    async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
        keys = []
        _redis_client = self.init_async_client()
        async with _redis_client as redis_client:
            async for key in redis_client.scan_iter(match=pattern + "*", count=count):
                keys.append(key)
                if len(keys) >= count:
                    break
        return keys
        start_time = time.time()
        try:
            keys = []
            _redis_client = self.init_async_client()
            async with _redis_client as redis_client:
                async for key in redis_client.scan_iter(
                    match=pattern + "*", count=count
                ):
                    keys.append(key)
                    if len(keys) >= count:
                        break

            ## LOGGING ##
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.service_logger_obj.async_service_success_hook(
                    service=ServiceTypes.REDIS, duration=_duration
                )
            )  # DO NOT SLOW DOWN CALL B/C OF THIS
            return keys
        except Exception as e:
            # NON blocking - notify users Redis is throwing an exception
            ## LOGGING ##
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.service_logger_obj.async_service_failure_hook(
                    service=ServiceTypes.REDIS, duration=_duration, error=e
                )
            )
            raise e

    async def async_set_cache(self, key, value, **kwargs):
        _redis_client = self.init_async_client()
        start_time = time.time()
        try:
            _redis_client = self.init_async_client()
        except Exception as e:
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.service_logger_obj.async_service_failure_hook(
                    service=ServiceTypes.REDIS, duration=_duration, error=e
                )
            )
            # NON blocking - notify users Redis is throwing an exception
            verbose_logger.error(
                "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
                str(e),
                value,
            )
            traceback.print_exc()

        key = self.check_and_fix_namespace(key=key)
        async with _redis_client as redis_client:
            ttl = kwargs.get("ttl", None)
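The hunk above establishes the pattern repeated for every redis call in this commit: time the call, then hand the duration to the service logger via asyncio.create_task() so metric emission never blocks the cache operation (hence the "DO NOT SLOW DOWN CALL" comment). A self-contained sketch of that fire-and-forget pattern, with all names here purely illustrative:

import asyncio, time

async def emit_metric(duration: float):
    await asyncio.sleep(0.05)  # stand-in for a slow metrics exporter
    print(f"observed {duration:.4f}s")

async def cached_call():
    start_time = time.time()
    result = "cache-hit"  # stand-in for the awaited redis operation
    # schedule the hook without awaiting it - the caller returns immediately
    asyncio.create_task(emit_metric(time.time() - start_time))
    return result

async def main():
    await cached_call()
    await asyncio.sleep(0.1)  # keep the loop alive so the background task finishes

asyncio.run(main())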
@@ -216,7 +263,21 @@ class RedisCache(BaseCache):
                print_verbose(
                    f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
                )
                end_time = time.time()
                _duration = end_time - start_time
                asyncio.create_task(
                    self.service_logger_obj.async_service_success_hook(
                        service=ServiceTypes.REDIS, duration=_duration
                    )
                )
            except Exception as e:
                end_time = time.time()
                _duration = end_time - start_time
                asyncio.create_task(
                    self.service_logger_obj.async_service_failure_hook(
                        service=ServiceTypes.REDIS, duration=_duration, error=e
                    )
                )
                # NON blocking - notify users Redis is throwing an exception
                verbose_logger.error(
                    "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
@@ -230,6 +291,7 @@ class RedisCache(BaseCache):
        Use Redis Pipelines for bulk write operations
        """
        _redis_client = self.init_async_client()
        start_time = time.time()
        try:
            async with _redis_client as redis_client:
                async with redis_client.pipeline(transaction=True) as pipe:
@@ -249,8 +311,25 @@ class RedisCache(BaseCache):

            print_verbose(f"pipeline results: {results}")
            # Optionally, you could process 'results' to make sure that all set operations were successful.
            ## LOGGING ##
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.service_logger_obj.async_service_success_hook(
                    service=ServiceTypes.REDIS, duration=_duration
                )
            )
            return results
        except Exception as e:
            ## LOGGING ##
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.service_logger_obj.async_service_failure_hook(
                    service=ServiceTypes.REDIS, duration=_duration, error=e
                )
            )

            verbose_logger.error(
                "LiteLLM Redis Caching: async set_cache_pipeline() - Got exception from REDIS %s, Writing value=%s",
                str(e),
@@ -265,15 +344,33 @@ class RedisCache(BaseCache):
        key = self.check_and_fix_namespace(key=key)
        self.redis_batch_writing_buffer.append((key, value))
        if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
            await self.flush_cache_buffer()
            await self.flush_cache_buffer()  # logging done in here

    async def async_increment(self, key, value: int, **kwargs) -> int:
        _redis_client = self.init_async_client()
        start_time = time.time()
        try:
            async with _redis_client as redis_client:
                result = await redis_client.incr(name=key, amount=value)
                ## LOGGING ##
                end_time = time.time()
                _duration = end_time - start_time
                asyncio.create_task(
                    self.service_logger_obj.async_service_success_hook(
                        service=ServiceTypes.REDIS,
                        duration=_duration,
                    )
                )
                return result
        except Exception as e:
            ## LOGGING ##
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.service_logger_obj.async_service_failure_hook(
                    service=ServiceTypes.REDIS, duration=_duration, error=e
                )
            )
            verbose_logger.error(
                "LiteLLM Redis Caching: async_increment() - Got exception from REDIS %s, Writing value=%s",
                str(e),
@@ -348,6 +445,7 @@ class RedisCache(BaseCache):
    async def async_get_cache(self, key, **kwargs):
        _redis_client = self.init_async_client()
        key = self.check_and_fix_namespace(key=key)
        start_time = time.time()
        async with _redis_client as redis_client:
            try:
                print_verbose(f"Get Async Redis Cache: key: {key}")
@@ -356,8 +454,24 @@ class RedisCache(BaseCache):
                    f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
                )
                response = self._get_cache_logic(cached_response=cached_response)
                ## LOGGING ##
                end_time = time.time()
                _duration = end_time - start_time
                asyncio.create_task(
                    self.service_logger_obj.async_service_success_hook(
                        service=ServiceTypes.REDIS, duration=_duration
                    )
                )
                return response
            except Exception as e:
                ## LOGGING ##
                end_time = time.time()
                _duration = end_time - start_time
                asyncio.create_task(
                    self.service_logger_obj.async_service_failure_hook(
                        service=ServiceTypes.REDIS, duration=_duration, error=e
                    )
                )
                # NON blocking - notify users Redis is throwing an exception
                print_verbose(
                    f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
@@ -369,6 +483,7 @@ class RedisCache(BaseCache):
        """
        _redis_client = await self.init_async_client()
        key_value_dict = {}
        start_time = time.time()
        try:
            async with _redis_client as redis_client:
                _keys = []
@@ -377,6 +492,15 @@ class RedisCache(BaseCache):
                    _keys.append(cache_key)
                results = await redis_client.mget(keys=_keys)

            ## LOGGING ##
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.service_logger_obj.async_service_success_hook(
                    service=ServiceTypes.REDIS, duration=_duration
                )
            )

            # Associate the results back with their keys.
            # 'results' is a list of values corresponding to the order of keys in 'key_list'.
            key_value_dict = dict(zip(key_list, results))
@@ -388,6 +512,14 @@ class RedisCache(BaseCache):

            return decoded_results
        except Exception as e:
            ## LOGGING ##
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.service_logger_obj.async_service_failure_hook(
                    service=ServiceTypes.REDIS, duration=_duration, error=e
                )
            )
            print_verbose(f"Error occurred in pipeline read - {str(e)}")
            return key_value_dict
@@ -1,6 +1,6 @@
# used for /metrics endpoint on LiteLLM Proxy
#### What this does ####
# On success + failure, log events to Supabase
# On success, log events to Prometheus

import dotenv, os
import requests
litellm/integrations/prometheus_services.py  (new file, 139 lines)
@@ -0,0 +1,139 @@
# used for monitoring litellm services health on `/metrics` endpoint on LiteLLM Proxy
#### What this does ####
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)


import dotenv, os
import requests

dotenv.load_dotenv()  # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose, verbose_logger
from litellm.types.services import ServiceLoggerPayload, ServiceTypes


class PrometheusServicesLogger:
    # Class variables or attributes
    litellm_service_latency = None  # Class-level attribute to store the Histogram

    def __init__(
        self,
        mock_testing: bool = False,
        **kwargs,
    ):
        try:
            from prometheus_client import Counter, Histogram, REGISTRY

            self.Histogram = Histogram
            self.Counter = Counter
            self.REGISTRY = REGISTRY

            verbose_logger.debug(f"in init prometheus services metrics")

            self.services = [item.value for item in ServiceTypes]

            self.payload_to_prometheus_map = (
                {}
            )  # store the prometheus histogram/counter we need to call for each field in payload

            for service in self.services:
                histogram = self.create_histogram(service)
                self.payload_to_prometheus_map[service] = histogram

            self.prometheus_to_amount_map: dict = (
                {}
            )  # the field / value in ServiceLoggerPayload the object needs to be incremented by

            # self.payload_to_prometheus_map["service"] = [self.litellm_service_latency]
            # self.prometheus_to_amount_map[self.litellm_service_latency._name] = (
            #     "duration"
            # )

            ### MOCK TESTING ###
            self.mock_testing = mock_testing
            self.mock_testing_success_calls = 0
            self.mock_testing_failure_calls = 0

        except Exception as e:
            print_verbose(f"Got exception on init prometheus client {str(e)}")
            raise e

    def is_metric_registered(self, metric_name) -> bool:
        for metric in self.REGISTRY.collect():
            print(f"metric name: {metric.name}")
            if metric_name == metric.name:
                return True
        return False

    def get_metric(self, metric_name):
        for metric in self.REGISTRY.collect():
            for sample in metric.samples:
                if metric_name == sample.name:
                    return metric
        return None

    def create_histogram(self, label: str):
        metric_name = "litellm_{}_latency".format(label)
        is_registered = self.is_metric_registered(metric_name)
        if is_registered:
            return self.get_metric(metric_name)
        return self.Histogram(
            metric_name,
            "Latency for {} service".format(label),
            labelnames=[label],
        )

    def observe_histogram(
        self,
        histogram,
        labels: str,
        amount: float,
    ):
        assert isinstance(histogram, self.Histogram)

        histogram.labels(labels).observe(amount)

    def increment_counter(
        self,
        counter,
        labels: list,
        amount: float,
    ):
        assert isinstance(counter, self.Counter)

        counter.labels(labels).inc(amount)

    def service_success_hook(self, payload: ServiceLoggerPayload):
        if self.mock_testing:
            self.mock_testing_success_calls += 1

        if payload.service.value in self.payload_to_prometheus_map:
            self.observe_histogram(
                histogram=self.payload_to_prometheus_map[payload.service.value],
                labels=payload.service.value,
                amount=payload.duration,
            )

    def service_failure_hook(self, payload: ServiceLoggerPayload):
        if self.mock_testing:
            self.mock_testing_failure_calls += 1

    async def async_service_success_hook(self, payload: ServiceLoggerPayload):
        """
        Log successful call to prometheus
        """
        if self.mock_testing:
            self.mock_testing_success_calls += 1

        if payload.service.value in self.payload_to_prometheus_map:
            self.observe_histogram(
                histogram=self.payload_to_prometheus_map[payload.service.value],
                labels=payload.service.value,
                amount=payload.duration,
            )

    async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
        if self.mock_testing:
            self.mock_testing_failure_calls += 1
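To see what this produces on `/metrics`, here is a standalone sketch of the same create-then-observe histogram pattern using only prometheus_client; the metric and label names match the diff above, the observed value is made up:

from prometheus_client import Histogram, REGISTRY, generate_latest

# mirrors create_histogram() above: one histogram per service, e.g. "redis"
service = "redis"
histogram = Histogram(
    "litellm_{}_latency".format(service),
    "Latency for {} service".format(service),
    labelnames=[service],
)
histogram.labels(service).observe(0.0042)  # what observe_histogram() does

# the scrape output will contain litellm_redis_latency_bucket / _sum / _count series
print(generate_latest(REGISTRY).decode())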
@@ -25,6 +25,7 @@ model_list:

litellm_settings:
  success_callback: ["prometheus"]
  service_callback: ["prometheus_system"]
  upperbound_key_generate_params:
    max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
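With `service_callback: ["prometheus_system"]` set, the proxy's existing `/metrics` endpoint should also expose the per-service latency histograms (`litellm_redis_latency`, `litellm_postgres_latency`, given the ServiceTypes values in this commit), alongside the request metrics from the `prometheus` success callback.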
litellm/tests/test_prometheus_service.py  (new file, 199 lines)
@@ -0,0 +1,199 @@
# What is this?
## Unit Tests for prometheus service monitoring

import json
import sys
import os
import io, asyncio

sys.path.insert(0, os.path.abspath("../.."))
import pytest
from litellm import acompletion, Cache
from litellm._service_logger import ServiceLogging
from litellm.integrations.prometheus_services import PrometheusServicesLogger
import litellm

"""
- Check if it receives a call when redis is used
- Check if it fires messages accordingly
"""


@pytest.mark.asyncio
async def test_init_prometheus():
    """
    - Run completion with caching
    - Assert success callback gets called
    """

    pl = PrometheusServicesLogger(mock_testing=True)


@pytest.mark.asyncio
async def test_completion_with_caching():
    """
    - Run completion with caching
    - Assert success callback gets called
    """

    litellm.set_verbose = True
    litellm.cache = Cache(type="redis")
    litellm.service_callback = ["prometheus_system"]

    sl = ServiceLogging(mock_testing=True)
    sl.prometheusServicesLogger.mock_testing = True
    litellm.cache.cache.service_logger_obj = sl

    messages = [{"role": "user", "content": "Hey, how's it going?"}]
    response1 = await acompletion(
        model="gpt-3.5-turbo", messages=messages, caching=True
    )
    response1 = await acompletion(
        model="gpt-3.5-turbo", messages=messages, caching=True
    )

    assert sl.mock_testing_async_success_hook > 0
    assert sl.prometheusServicesLogger.mock_testing_success_calls > 0


@pytest.mark.asyncio
async def test_completion_with_caching_bad_call():
    """
    - Run completion with caching (incorrect credentials)
    - Assert failure callback gets called
    """
    litellm.set_verbose = True
    sl = ServiceLogging(mock_testing=True)
    try:
        litellm.cache = Cache(type="redis", host="hello-world")
        litellm.service_callback = ["prometheus_system"]

        litellm.cache.cache.service_logger_obj = sl

        messages = [{"role": "user", "content": "Hey, how's it going?"}]
        response1 = await acompletion(
            model="gpt-3.5-turbo", messages=messages, caching=True
        )
        response1 = await acompletion(
            model="gpt-3.5-turbo", messages=messages, caching=True
        )
    except Exception as e:
        pass

    assert sl.mock_testing_async_failure_hook > 0


@pytest.mark.asyncio
async def test_router_with_caching():
    """
    - Run router with usage-based-routing-v2
    - Assert success callback gets called
    """
    try:

        def get_azure_params(deployment_name: str):
            params = {
                "model": f"azure/{deployment_name}",
                "api_key": os.environ["AZURE_API_KEY"],
                "api_version": os.environ["AZURE_API_VERSION"],
                "api_base": os.environ["AZURE_API_BASE"],
            }
            return params

        model_list = [
            {
                "model_name": "azure/gpt-4",
                "litellm_params": get_azure_params("chatgpt-v-2"),
                "tpm": 100,
            },
            {
                "model_name": "azure/gpt-4",
                "litellm_params": get_azure_params("chatgpt-v-2"),
                "tpm": 1000,
            },
        ]

        router = litellm.Router(
            model_list=model_list,
            set_verbose=True,
            debug_level="DEBUG",
            routing_strategy="usage-based-routing-v2",
            redis_host=os.environ["REDIS_HOST"],
            redis_port=os.environ["REDIS_PORT"],
            redis_password=os.environ["REDIS_PASSWORD"],
        )

        litellm.service_callback = ["prometheus_system"]

        sl = ServiceLogging(mock_testing=True)
        sl.prometheusServicesLogger.mock_testing = True
        router.cache.redis_cache.service_logger_obj = sl

        messages = [{"role": "user", "content": "Hey, how's it going?"}]
        response1 = await router.acompletion(model="azure/gpt-4", messages=messages)
        response1 = await router.acompletion(model="azure/gpt-4", messages=messages)

        assert sl.mock_testing_async_success_hook > 0
        assert sl.prometheusServicesLogger.mock_testing_success_calls > 0

    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")


@pytest.mark.asyncio
async def test_router_with_caching_bad_call():
    """
    - Run completion with caching (incorrect credentials)
    - Assert failure callback gets called
    """
    try:

        def get_azure_params(deployment_name: str):
            params = {
                "model": f"azure/{deployment_name}",
                "api_key": os.environ["AZURE_API_KEY"],
                "api_version": os.environ["AZURE_API_VERSION"],
                "api_base": os.environ["AZURE_API_BASE"],
            }
            return params

        model_list = [
            {
                "model_name": "azure/gpt-4",
                "litellm_params": get_azure_params("chatgpt-v-2"),
                "tpm": 100,
            },
            {
                "model_name": "azure/gpt-4",
                "litellm_params": get_azure_params("chatgpt-v-2"),
                "tpm": 1000,
            },
        ]

        router = litellm.Router(
            model_list=model_list,
            set_verbose=True,
            debug_level="DEBUG",
            routing_strategy="usage-based-routing-v2",
            redis_host="hello world",
            redis_port=os.environ["REDIS_PORT"],
            redis_password=os.environ["REDIS_PASSWORD"],
        )

        litellm.service_callback = ["prometheus_system"]

        sl = ServiceLogging(mock_testing=True)
        sl.prometheusServicesLogger.mock_testing = True
        router.cache.redis_cache.service_logger_obj = sl

        messages = [{"role": "user", "content": "Hey, how's it going?"}]
        try:
            response1 = await router.acompletion(model="azure/gpt-4", messages=messages)
            response1 = await router.acompletion(model="azure/gpt-4", messages=messages)
        except Exception as e:
            pass

        assert sl.mock_testing_async_failure_hook > 0

    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
litellm/types/services.py  (new file, 30 lines)
@@ -0,0 +1,30 @@
import uuid, enum
from pydantic import BaseModel, Field
from typing import Optional


class ServiceTypes(enum.Enum):
    """
    Enum for litellm-adjacent services (redis/postgres/etc.)
    """

    REDIS = "redis"
    DB = "postgres"


class ServiceLoggerPayload(BaseModel):
    """
    The payload logged during service success/failure
    """

    is_error: bool = Field(description="did an error occur")
    error: Optional[str] = Field(None, description="what was the error")
    service: ServiceTypes = Field(description="who is this for? - postgres/redis")
    duration: float = Field(description="How long did the request take?")

    def to_json(self, **kwargs):
        try:
            return self.model_dump(**kwargs)  # noqa
        except Exception as e:
            # if using pydantic v1
            return self.dict(**kwargs)
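A quick illustration of the payload model; the values below are made up:

from litellm.types.services import ServiceLoggerPayload, ServiceTypes

payload = ServiceLoggerPayload(
    is_error=True,
    error="Connection refused",
    service=ServiceTypes.REDIS,
    duration=0.37,
)
# returns a dict with the is_error / error / service / duration fields,
# via model_dump() on pydantic v2 or dict() on pydantic v1
print(payload.to_json())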
@@ -29,7 +29,9 @@ from tokenizers import Tokenizer
from dataclasses import (
    dataclass,
    field,
)  # for storing API inputs, outputs, and metadata
)

import litellm._service_logger  # for storing API inputs, outputs, and metadata

try:
    # this works in python 3.8
@@ -69,6 +71,7 @@ from .integrations.custom_logger import CustomLogger
from .integrations.langfuse import LangFuseLogger
from .integrations.datadog import DataDogLogger
from .integrations.prometheus import PrometheusLogger
from .integrations.prometheus_services import PrometheusServicesLogger
from .integrations.dynamodb import DyanmoDBLogger
from .integrations.s3 import S3Logger
from .integrations.clickhouse import ClickhouseLogger
@@ -6564,7 +6567,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
    for detail in additional_details:
        slack_msg += f"{detail}: {additional_details[detail]}\n"
    slack_msg += f"Traceback: {traceback_exception}"
    truncated_slack_msg = textwrap.shorten(slack_msg, width=512, placeholder="...")
    truncated_slack_msg = textwrap.shorten(
        slack_msg, width=512, placeholder="..."
    )
    slack_app.client.chat_postMessage(
        channel=alerts_channel, text=truncated_slack_msg
    )