diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py
index d4aad68bb..f777c93d4 100644
--- a/litellm/_service_logger.py
+++ b/litellm/_service_logger.py
@@ -241,7 +241,8 @@ class ServiceLogging(CustomLogger):
             if callback == "prometheus_system":
                 await self.init_prometheus_services_logger_if_none()
                 await self.prometheusServicesLogger.async_service_failure_hook(
-                    payload=payload
+                    payload=payload,
+                    error=error,
                 )
             elif callback == "datadog":
                 await self.init_datadog_logger_if_none()
diff --git a/litellm/integrations/prometheus_services.py b/litellm/integrations/prometheus_services.py
index a36ac9b9c..df94ffcd8 100644
--- a/litellm/integrations/prometheus_services.py
+++ b/litellm/integrations/prometheus_services.py
@@ -9,6 +9,7 @@ import subprocess
 import sys
 import traceback
 import uuid
+from typing import List, Optional, Union
 
 import dotenv
 import requests  # type: ignore
@@ -51,7 +52,9 @@ class PrometheusServicesLogger:
         for service in self.services:
             histogram = self.create_histogram(service, type_of_request="latency")
             counter_failed_request = self.create_counter(
-                service, type_of_request="failed_requests"
+                service,
+                type_of_request="failed_requests",
+                additional_labels=["error_class", "function_name"],
             )
             counter_total_requests = self.create_counter(
                 service, type_of_request="total_requests"
@@ -99,7 +102,12 @@ class PrometheusServicesLogger:
             buckets=LATENCY_BUCKETS,
         )
 
-    def create_counter(self, service: str, type_of_request: str):
+    def create_counter(
+        self,
+        service: str,
+        type_of_request: str,
+        additional_labels: Optional[List[str]] = None,
+    ):
         metric_name = "litellm_{}_{}".format(service, type_of_request)
         is_registered = self.is_metric_registered(metric_name)
         if is_registered:
@@ -107,7 +115,7 @@
         return self.Counter(
             metric_name,
             "Total {} for {} service".format(type_of_request, service),
-            labelnames=[service],
+            labelnames=[service] + (additional_labels or []),
         )
 
     def observe_histogram(
@@ -125,10 +133,14 @@
         counter,
         labels: str,
         amount: float,
+        additional_labels: Optional[List[str]] = [],
     ):
         assert isinstance(counter, self.Counter)
 
-        counter.labels(labels).inc(amount)
+        if additional_labels:
+            counter.labels(labels, *additional_labels).inc(amount)
+        else:
+            counter.labels(labels).inc(amount)
 
     def service_success_hook(self, payload: ServiceLoggerPayload):
         if self.mock_testing:
@@ -187,16 +199,25 @@
                         amount=1,  # LOG TOTAL REQUESTS TO PROMETHEUS
                     )
 
-    async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
+    async def async_service_failure_hook(
+        self,
+        payload: ServiceLoggerPayload,
+        error: Union[str, Exception],
+    ):
         if self.mock_testing:
             self.mock_testing_failure_calls += 1
+        error_class = error.__class__.__name__
+        function_name = payload.call_type
 
         if payload.service.value in self.payload_to_prometheus_map:
             prom_objects = self.payload_to_prometheus_map[payload.service.value]
             for obj in prom_objects:
+                # increment both failed and total requests
                 if isinstance(obj, self.Counter):
                     self.increment_counter(
                         counter=obj,
                         labels=payload.service.value,
+                        # log additional_labels=["error_class", "function_name"], used for debugging what's going wrong with the DB
+                        additional_labels=[error_class, function_name],
                         amount=1,  # LOG ERROR COUNT TO PROMETHEUS
                     )
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 8d504c739..12b6ec372 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -32,7 +32,7 @@ from litellm.proxy._types import (
     UserAPIKeyAuth,
 )
 from litellm.proxy.auth.route_checks import RouteChecks
-from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
+from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
 
 from .auth_checks_organization import organization_role_based_access_check
@@ -290,7 +290,7 @@ def get_actual_routes(allowed_routes: list) -> list:
     return actual_routes
 
 
-@log_to_opentelemetry
+@log_db_metrics
 async def get_end_user_object(
     end_user_id: Optional[str],
     prisma_client: Optional[PrismaClient],
@@ -415,7 +415,7 @@ def _update_last_db_access_time(
     last_db_access_time[key] = (value, time.time())
 
 
-@log_to_opentelemetry
+@log_db_metrics
 async def get_user_object(
     user_id: str,
     prisma_client: Optional[PrismaClient],
@@ -562,7 +562,7 @@ async def _delete_cache_key_object(
         )
 
 
-@log_to_opentelemetry
+@log_db_metrics
 async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
     return await prisma_client.db.litellm_teamtable.find_unique(
         where={"team_id": team_id}
@@ -658,7 +658,7 @@ async def get_team_object(
     )
 
 
-@log_to_opentelemetry
+@log_db_metrics
 async def get_key_object(
     hashed_token: str,
     prisma_client: Optional[PrismaClient],
@@ -766,7 +766,7 @@
         raise e
 
 
-@log_to_opentelemetry
+@log_db_metrics
 async def get_org_object(
     org_id: str,
     prisma_client: Optional[PrismaClient],
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index d25b6f620..f11bfcbc9 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -58,7 +58,7 @@ from litellm.proxy.auth.auth_checks import (
     get_org_object,
     get_team_object,
     get_user_object,
-    log_to_opentelemetry,
+    log_db_metrics,
 )
 from litellm.proxy.auth.auth_utils import (
     _get_request_ip_address,
diff --git a/litellm/proxy/db/log_db_metrics.py b/litellm/proxy/db/log_db_metrics.py
new file mode 100644
index 000000000..e8040ae60
--- /dev/null
+++ b/litellm/proxy/db/log_db_metrics.py
@@ -0,0 +1,138 @@
+"""
+Handles logging DB success/failure to ServiceLogger()
+
+ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc
+"""
+
+from datetime import datetime
+from functools import wraps
+from typing import Callable, Dict, Tuple
+
+from litellm._service_logger import ServiceTypes
+from litellm.litellm_core_utils.core_helpers import (
+    _get_parent_otel_span_from_kwargs,
+    get_litellm_metadata_from_kwargs,
+)
+
+
+def log_db_metrics(func):
+    """
+    Decorator to log the duration of a DB related function to ServiceLogger()
+
+    Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog
+
+    When logging Failure it checks if the Exception is a PrismaError, httpx.ConnectError or httpx.TimeoutException and then logs that as a DB Service Failure
+
+    Args:
+        func: The function to be decorated
+
+    Returns:
+        Result from the decorated function
+
+    Raises:
+        Exception: If the decorated function raises an exception
+    """
+
+    @wraps(func)
+    async def wrapper(*args, **kwargs):
+        from prisma.errors import PrismaError
+
+        start_time: datetime = datetime.now()
+
+        try:
+            result = await func(*args, **kwargs)
+            end_time: datetime = datetime.now()
+            from litellm.proxy.proxy_server import proxy_logging_obj
+
+            if "PROXY" not in func.__name__:
+                await proxy_logging_obj.service_logging_obj.async_service_success_hook(
+                    service=ServiceTypes.DB,
+                    call_type=func.__name__,
+                    parent_otel_span=kwargs.get("parent_otel_span", None),
+                    duration=(end_time - start_time).total_seconds(),
+                    start_time=start_time,
+                    end_time=end_time,
+                    event_metadata={
+                        "function_name": func.__name__,
+                        "function_kwargs": kwargs,
+                        "function_args": args,
+                    },
+                )
+            elif (
+                # in litellm custom callbacks kwargs is passed as arg[0]
+                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
+                args is not None
+                and len(args) > 0
+                and isinstance(args[0], dict)
+            ):
+                passed_kwargs = args[0]
+                parent_otel_span = _get_parent_otel_span_from_kwargs(
+                    kwargs=passed_kwargs
+                )
+                if parent_otel_span is not None:
+                    metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
+                    await proxy_logging_obj.service_logging_obj.async_service_success_hook(
+                        service=ServiceTypes.BATCH_WRITE_TO_DB,
+                        call_type=func.__name__,
+                        parent_otel_span=parent_otel_span,
+                        duration=0.0,
+                        start_time=start_time,
+                        end_time=end_time,
+                        event_metadata=metadata,
+                    )
+            # end of logging to otel
+            return result
+        except Exception as e:
+            end_time: datetime = datetime.now()
+            await _handle_logging_db_exception(
+                e=e,
+                func=func,
+                kwargs=kwargs,
+                args=args,
+                start_time=start_time,
+                end_time=end_time,
+            )
+            raise e
+
+    return wrapper
+
+
+def _is_exception_related_to_db(e: Exception) -> bool:
+    """
+    Returns True if the exception is related to the DB
+    """
+
+    import httpx
+    from prisma.errors import PrismaError
+
+    return isinstance(e, (PrismaError, httpx.ConnectError, httpx.TimeoutException))
+
+
+async def _handle_logging_db_exception(
+    e: Exception,
+    func: Callable,
+    kwargs: Dict,
+    args: Tuple,
+    start_time: datetime,
+    end_time: datetime,
+) -> None:
+    from litellm.proxy.proxy_server import proxy_logging_obj
+
+    # don't log this as a DB Service Failure, if the DB did not raise an exception
+    if _is_exception_related_to_db(e) is not True:
+        return
+
+    await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
+        error=e,
+        service=ServiceTypes.DB,
+        call_type=func.__name__,
+        parent_otel_span=kwargs.get("parent_otel_span"),
+        duration=(end_time - start_time).total_seconds(),
+        start_time=start_time,
+        end_time=end_time,
+        event_metadata={
+            "function_name": func.__name__,
+            "function_kwargs": kwargs,
+            "function_args": args,
+        },
+    )
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index ce58c4d75..12e80876c 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -125,7 +125,7 @@ from litellm.proxy._types import *
 from litellm.proxy.analytics_endpoints.analytics_endpoints import (
     router as analytics_router,
 )
-from litellm.proxy.auth.auth_checks import log_to_opentelemetry
+from litellm.proxy.auth.auth_checks import log_db_metrics
 from litellm.proxy.auth.auth_utils import check_response_size_is_safe
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.auth.litellm_license import LicenseCheck
@@ -747,7 +747,7 @@ async def _PROXY_failure_handler(
         pass
 
 
-@log_to_opentelemetry
+@log_db_metrics
 async def _PROXY_track_cost_callback(
     kwargs,  # kwargs to completion
     completion_response: litellm.ModelResponse,  # response from completion
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 44e9d151d..9d33244a0 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -55,10 +55,6 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
 from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
-from litellm.litellm_core_utils.core_helpers import (
-    _get_parent_otel_span_from_kwargs,
-    get_litellm_metadata_from_kwargs,
-)
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy._types import (
@@ -77,6 +73,7 @@ from litellm.proxy.db.create_views import (
     create_missing_views,
     should_create_missing_views,
 )
+from litellm.proxy.db.log_db_metrics import log_db_metrics
 from litellm.proxy.db.prisma_client import PrismaWrapper
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
@@ -137,83 +134,6 @@ def safe_deep_copy(data):
     return new_data
 
 
-def log_to_opentelemetry(func):
-    """
-    Decorator to log the duration of a DB related function to ServiceLogger()
-
-    Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog
-    """
-
-    @wraps(func)
-    async def wrapper(*args, **kwargs):
-        start_time: datetime = datetime.now()
-
-        try:
-            result = await func(*args, **kwargs)
-            end_time: datetime = datetime.now()
-            from litellm.proxy.proxy_server import proxy_logging_obj
-
-            if "PROXY" not in func.__name__:
-                await proxy_logging_obj.service_logging_obj.async_service_success_hook(
-                    service=ServiceTypes.DB,
-                    call_type=func.__name__,
-                    parent_otel_span=kwargs.get("parent_otel_span", None),
-                    duration=(end_time - start_time).total_seconds(),
-                    start_time=start_time,
-                    end_time=end_time,
-                    event_metadata={
-                        "function_name": func.__name__,
-                        "function_kwargs": kwargs,
-                        "function_args": args,
-                    },
-                )
-            elif (
-                # in litellm custom callbacks kwargs is passed as arg[0]
-                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
-                args is not None
-                and len(args) > 0
-                and isinstance(args[0], dict)
-            ):
-                passed_kwargs = args[0]
-                parent_otel_span = _get_parent_otel_span_from_kwargs(
-                    kwargs=passed_kwargs
-                )
-                if parent_otel_span is not None:
-                    metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
-                    await proxy_logging_obj.service_logging_obj.async_service_success_hook(
-                        service=ServiceTypes.BATCH_WRITE_TO_DB,
-                        call_type=func.__name__,
-                        parent_otel_span=parent_otel_span,
-                        duration=0.0,
-                        start_time=start_time,
-                        end_time=end_time,
-                        event_metadata=metadata,
-                    )
-            # end of logging to otel
-            return result
-        except Exception as e:
-            from litellm.proxy.proxy_server import proxy_logging_obj
-
-            end_time: datetime = datetime.now()
-            await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
-                error=e,
-                service=ServiceTypes.DB,
-                call_type=func.__name__,
-                parent_otel_span=kwargs.get("parent_otel_span"),
-                duration=(end_time - start_time).total_seconds(),
-                start_time=start_time,
-                end_time=end_time,
-                event_metadata={
-                    "function_name": func.__name__,
-                    "function_kwargs": kwargs,
-                    "function_args": args,
-                },
-            )
-            raise e
-
-    return wrapper
-
-
 class InternalUsageCache:
     def __init__(self, dual_cache: DualCache):
         self.dual_cache: DualCache = dual_cache
@@ -1397,7 +1317,7 @@ class PrismaClient:
         return
 
 
-    @log_to_opentelemetry
+    @log_db_metrics
     @backoff.on_exception(
         backoff.expo,
         Exception,  # base exception to catch for the backoff
@@ -1463,7 +1383,7 @@
         max_time=10,  # maximum total time to retry for
         on_backoff=on_backoff,  # specifying the function to call on backoff
     )
-    @log_to_opentelemetry
+    @log_db_metrics
     async def get_data(  # noqa: PLR0915
         self,
         token: Optional[Union[str, list]] = None,
diff --git a/tests/logging_callback_tests/test_log_db_redis_services.py b/tests/logging_callback_tests/test_log_db_redis_services.py
index 9f5db8009..9824e1a5b 100644
--- a/tests/logging_callback_tests/test_log_db_redis_services.py
+++ b/tests/logging_callback_tests/test_log_db_redis_services.py
@@ -17,23 +17,25 @@ import pytest
 import litellm
 from litellm import completion
 from litellm._logging import verbose_logger
-from litellm.proxy.utils import log_to_opentelemetry, ServiceTypes
+from litellm.proxy.utils import log_db_metrics, ServiceTypes
 from datetime import datetime
+import httpx
+from prisma.errors import ClientNotConnectedError
 
 
 # Test async function to decorate
-@log_to_opentelemetry
+@log_db_metrics
 async def sample_db_function(*args, **kwargs):
     return "success"
 
 
-@log_to_opentelemetry
+@log_db_metrics
 async def sample_proxy_function(*args, **kwargs):
     return "success"
 
 
 @pytest.mark.asyncio
-async def test_log_to_opentelemetry_success():
+async def test_log_db_metrics_success():
     # Mock the proxy_logging_obj
     with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
         # Setup mock
@@ -61,14 +63,14 @@
 
 
 @pytest.mark.asyncio
-async def test_log_to_opentelemetry_duration():
+async def test_log_db_metrics_duration():
     # Mock the proxy_logging_obj
     with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
         # Setup mock
         mock_proxy_logging.service_logging_obj.async_service_success_hook = AsyncMock()
 
         # Add a delay to the function to test duration
-        @log_to_opentelemetry
+        @log_db_metrics
         async def delayed_function(**kwargs):
             await asyncio.sleep(1)  # 1 second delay
             return "success"
@@ -95,23 +97,28 @@
 
 
 @pytest.mark.asyncio
-async def test_log_to_opentelemetry_failure():
+async def test_log_db_metrics_failure():
+    """
+    should log a failure if a prisma error is raised
+    """
     # Mock the proxy_logging_obj
+    from prisma.errors import ClientNotConnectedError
+
     with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
         # Setup mock
         mock_proxy_logging.service_logging_obj.async_service_failure_hook = AsyncMock()
 
         # Create a failing function
-        @log_to_opentelemetry
+        @log_db_metrics
         async def failing_function(**kwargs):
-            raise ValueError("Test error")
+            raise ClientNotConnectedError()
 
         # Call the decorated function and expect it to raise
-        with pytest.raises(ValueError) as exc_info:
+        with pytest.raises(ClientNotConnectedError) as exc_info:
             await failing_function(parent_otel_span="test_span")
 
         # Assertions
-        assert str(exc_info.value) == "Test error"
+        assert "Client is not connected to the query engine" in str(exc_info.value)
         mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_called_once()
         call_args = (
             mock_proxy_logging.service_logging_obj.async_service_failure_hook.call_args[
@@ -125,4 +132,56 @@
         assert isinstance(call_args["duration"], float)
         assert isinstance(call_args["start_time"], datetime)
         assert isinstance(call_args["end_time"], datetime)
-        assert isinstance(call_args["error"], ValueError)
+        assert isinstance(call_args["error"], ClientNotConnectedError)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "exception,should_log",
+    [
+        (ValueError("Generic error"), False),
+        (KeyError("Missing key"), False),
+        (TypeError("Type error"), False),
+        (httpx.ConnectError("Failed to connect"), True),
+        (httpx.TimeoutException("Request timed out"), True),
+        (ClientNotConnectedError(), True),  # Prisma error
+    ],
+)
+async def test_log_db_metrics_failure_error_types(exception, should_log):
+    """
+    Why Test?
+    Users were seeing that non-DB errors were being logged as DB Service Failures
+    Example a failure to read a value from cache was being logged as a DB Service Failure
+
+    Parameterized test to verify:
+    - DB-related errors (Prisma, httpx) are logged as service failures
+    - Non-DB errors (ValueError, KeyError, etc.) are not logged
+    """
+    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
+        mock_proxy_logging.service_logging_obj.async_service_failure_hook = AsyncMock()
+
+        @log_db_metrics
+        async def failing_function(**kwargs):
+            raise exception
+
+        # Call the function and expect it to raise the exception
+        with pytest.raises(type(exception)):
+            await failing_function(parent_otel_span="test_span")
+
+        if should_log:
+            # Assert failure was logged for DB-related errors
+            mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_called_once()
+            call_args = mock_proxy_logging.service_logging_obj.async_service_failure_hook.call_args[
+                1
+            ]
+            assert call_args["service"] == ServiceTypes.DB
+            assert call_args["call_type"] == "failing_function"
+            assert call_args["parent_otel_span"] == "test_span"
+            assert isinstance(call_args["duration"], float)
+            assert isinstance(call_args["start_time"], datetime)
+            assert isinstance(call_args["end_time"], datetime)
+            assert isinstance(call_args["error"], type(exception))
+        else:
+            # Assert failure was NOT logged for non-DB errors
+            mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_not_called()
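A minimal usage sketch (not part of the patch) of the new log_db_metrics decorator added in litellm/proxy/db/log_db_metrics.py, showing how a DB helper is expected to be decorated and how exceptions are classified. The helper name lookup_team and its prisma_client argument are hypothetical stand-ins for the real decorated call sites (get_team_object, PrismaClient.get_data, etc.); the decorator and _is_exception_related_to_db come from this diff.

import httpx

from litellm.proxy.db.log_db_metrics import _is_exception_related_to_db, log_db_metrics


@log_db_metrics
async def lookup_team(team_id: str, prisma_client, parent_otel_span=None):
    # On success the decorator emits a ServiceTypes.DB success event (call_type,
    # duration, parent_otel_span read from kwargs). On failure it emits a DB
    # failure event only when the exception is DB-related, then re-raises.
    return await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id}
    )


# Only Prisma errors and httpx connection/timeout errors are treated as DB
# service failures; generic application errors (e.g. a cache-read ValueError)
# are re-raised but not logged as DB failures.
assert _is_exception_related_to_db(httpx.ConnectError("failed to connect")) is True
assert _is_exception_related_to_db(ValueError("cache miss")) is False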