(feat) log error class, function_name on prometheus service failure hook + only log DB related failures on DB service hook (#6650)

* log error on prometheus service failure hook

* use a more accurate function name for wrapper that handles logging db metrics

* fix log_db_metrics

* test_log_db_metrics_failure_error_types

* fix linting

* fix auth checks
This commit is contained in:
Ishaan Jaff 2024-11-07 17:01:18 -08:00 committed by GitHub
parent ae385cfcdc
commit eb47117800
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 249 additions and 110 deletions

View file

@@ -241,7 +241,8 @@ class ServiceLogging(CustomLogger):
if callback == "prometheus_system":
await self.init_prometheus_services_logger_if_none()
await self.prometheusServicesLogger.async_service_failure_hook(
payload=payload
payload=payload,
error=error,
)
elif callback == "datadog":
await self.init_datadog_logger_if_none()

View file

@ -9,6 +9,7 @@ import subprocess
import sys
import traceback
import uuid
from typing import List, Optional, Union
import dotenv
import requests # type: ignore
@ -51,7 +52,9 @@ class PrometheusServicesLogger:
for service in self.services:
histogram = self.create_histogram(service, type_of_request="latency")
counter_failed_request = self.create_counter(
service, type_of_request="failed_requests"
service,
type_of_request="failed_requests",
additional_labels=["error_class", "function_name"],
)
counter_total_requests = self.create_counter(
service, type_of_request="total_requests"
@ -99,7 +102,12 @@ class PrometheusServicesLogger:
buckets=LATENCY_BUCKETS,
)
def create_counter(self, service: str, type_of_request: str):
def create_counter(
self,
service: str,
type_of_request: str,
additional_labels: Optional[List[str]] = None,
):
metric_name = "litellm_{}_{}".format(service, type_of_request)
is_registered = self.is_metric_registered(metric_name)
if is_registered:
@ -107,7 +115,7 @@ class PrometheusServicesLogger:
return self.Counter(
metric_name,
"Total {} for {} service".format(type_of_request, service),
labelnames=[service],
labelnames=[service] + (additional_labels or []),
)
def observe_histogram(
@ -125,10 +133,14 @@ class PrometheusServicesLogger:
counter,
labels: str,
amount: float,
additional_labels: Optional[List[str]] = [],
):
assert isinstance(counter, self.Counter)
counter.labels(labels).inc(amount)
if additional_labels:
counter.labels(labels, *additional_labels).inc(amount)
else:
counter.labels(labels).inc(amount)
def service_success_hook(self, payload: ServiceLoggerPayload):
if self.mock_testing:
@ -187,16 +199,25 @@ class PrometheusServicesLogger:
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
)
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
async def async_service_failure_hook(
self,
payload: ServiceLoggerPayload,
error: Union[str, Exception],
):
if self.mock_testing:
self.mock_testing_failure_calls += 1
error_class = error.__class__.__name__
function_name = payload.call_type
if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
# increment both failed and total requests
if isinstance(obj, self.Counter):
self.increment_counter(
counter=obj,
labels=payload.service.value,
# log additional_labels=["error_class", "function_name"], used for debugging what's going wrong with the DB
additional_labels=[error_class, function_name],
amount=1, # LOG ERROR COUNT TO PROMETHEUS
)

View file

@ -32,7 +32,7 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from .auth_checks_organization import organization_role_based_access_check
@ -290,7 +290,7 @@ def get_actual_routes(allowed_routes: list) -> list:
return actual_routes
@log_to_opentelemetry
@log_db_metrics
async def get_end_user_object(
end_user_id: Optional[str],
prisma_client: Optional[PrismaClient],
@ -415,7 +415,7 @@ def _update_last_db_access_time(
last_db_access_time[key] = (value, time.time())
@log_to_opentelemetry
@log_db_metrics
async def get_user_object(
user_id: str,
prisma_client: Optional[PrismaClient],
@ -562,7 +562,7 @@ async def _delete_cache_key_object(
)
@log_to_opentelemetry
@log_db_metrics
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
return await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
@ -658,7 +658,7 @@ async def get_team_object(
)
@log_to_opentelemetry
@log_db_metrics
async def get_key_object(
hashed_token: str,
prisma_client: Optional[PrismaClient],
@ -766,7 +766,7 @@ async def _handle_failed_db_connection_for_get_key_object(
raise e
@log_to_opentelemetry
@log_db_metrics
async def get_org_object(
org_id: str,
prisma_client: Optional[PrismaClient],

View file

@ -58,7 +58,7 @@ from litellm.proxy.auth.auth_checks import (
get_org_object,
get_team_object,
get_user_object,
log_to_opentelemetry,
log_db_metrics,
)
from litellm.proxy.auth.auth_utils import (
_get_request_ip_address,

View file

@@ -0,0 +1,138 @@
"""
Handles logging DB success/failure to ServiceLogger()
ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc
"""
from datetime import datetime
from functools import wraps
from typing import Callable, Dict, Tuple
from litellm._service_logger import ServiceTypes
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs,
)
def log_db_metrics(func):
    """
    Decorator to log the duration of a DB related function to ServiceLogger().

    Handles logging DB success/failure to ServiceLogger(), which forwards the
    logs to Prometheus, OTEL, Datadog, etc.

    On failure, logging is delegated to ``_handle_logging_db_exception``, which
    only records a DB Service Failure when the exception is DB-related
    (PrismaError, httpx.ConnectError or httpx.TimeoutException).

    Args:
        func: The async function to be decorated.

    Returns:
        Result from the decorated function.

    Raises:
        Exception: Re-raises whatever the decorated function raises.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time: datetime = datetime.now()
        try:
            result = await func(*args, **kwargs)
            end_time: datetime = datetime.now()
            # imported lazily to avoid a circular import at module load time
            from litellm.proxy.proxy_server import proxy_logging_obj

            if "PROXY" not in func.__name__:
                # plain DB helper function -> log as a DB service success
                await proxy_logging_obj.service_logging_obj.async_service_success_hook(
                    service=ServiceTypes.DB,
                    call_type=func.__name__,
                    parent_otel_span=kwargs.get("parent_otel_span", None),
                    duration=(end_time - start_time).total_seconds(),
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "function_name": func.__name__,
                        "function_kwargs": kwargs,
                        "function_args": args,
                    },
                )
            elif (
                # in litellm custom callbacks kwargs is passed as arg[0]
                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
                args is not None
                and len(args) > 0
                and isinstance(args[0], dict)
            ):
                passed_kwargs = args[0]
                parent_otel_span = _get_parent_otel_span_from_kwargs(
                    kwargs=passed_kwargs
                )
                if parent_otel_span is not None:
                    metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
                    await proxy_logging_obj.service_logging_obj.async_service_success_hook(
                        service=ServiceTypes.BATCH_WRITE_TO_DB,
                        call_type=func.__name__,
                        parent_otel_span=parent_otel_span,
                        duration=0.0,
                        start_time=start_time,
                        end_time=end_time,
                        event_metadata=metadata,
                    )
            # end of logging to otel
            return result
        except Exception as e:
            end_time: datetime = datetime.now()
            # only DB-related exceptions are logged; others pass through silently
            await _handle_logging_db_exception(
                e=e,
                func=func,
                kwargs=kwargs,
                args=args,
                start_time=start_time,
                end_time=end_time,
            )
            raise e

    return wrapper
def _is_exception_related_to_db(e: Exception) -> bool:
    """
    Classify an exception as DB-related or not.

    Returns True only for Prisma errors and httpx connection/timeout errors;
    everything else is treated as a non-DB failure.
    """
    import httpx
    from prisma.errors import PrismaError

    db_error_types = (PrismaError, httpx.ConnectError, httpx.TimeoutException)
    return isinstance(e, db_error_types)
async def _handle_logging_db_exception(
    e: Exception,
    func: Callable,
    kwargs: Dict,
    args: Tuple,
    start_time: datetime,
    end_time: datetime,
) -> None:
    """
    Log a failed DB call to the service logger — but only if the exception is
    actually DB-related (see ``_is_exception_related_to_db``).

    Args:
        e: The exception raised by the decorated function.
        func: The decorated function (its __name__ is used as the call type).
        kwargs: Keyword args the function was called with; ``parent_otel_span``
            is read from here if present.
        args: Positional args the function was called with (logged as metadata).
        start_time: When the DB call started.
        end_time: When the DB call failed.
    """
    # imported lazily to avoid a circular import at module load time
    from litellm.proxy.proxy_server import proxy_logging_obj

    # don't log this as a DB Service Failure, if the DB did not raise an exception
    if _is_exception_related_to_db(e) is not True:
        return

    await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
        error=e,
        service=ServiceTypes.DB,
        call_type=func.__name__,
        parent_otel_span=kwargs.get("parent_otel_span"),
        duration=(end_time - start_time).total_seconds(),
        start_time=start_time,
        end_time=end_time,
        event_metadata={
            "function_name": func.__name__,
            "function_kwargs": kwargs,
            "function_args": args,
        },
    )

View file

@ -125,7 +125,7 @@ from litellm.proxy._types import *
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
router as analytics_router,
)
from litellm.proxy.auth.auth_checks import log_to_opentelemetry
from litellm.proxy.auth.auth_checks import log_db_metrics
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.litellm_license import LicenseCheck
@ -747,7 +747,7 @@ async def _PROXY_failure_handler(
pass
@log_to_opentelemetry
@log_db_metrics
async def _PROXY_track_cost_callback(
kwargs, # kwargs to completion
completion_response: litellm.ModelResponse, # response from completion

View file

@ -55,10 +55,6 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs,
)
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import (
@ -77,6 +73,7 @@ from litellm.proxy.db.create_views import (
create_missing_views,
should_create_missing_views,
)
from litellm.proxy.db.log_db_metrics import log_db_metrics
from litellm.proxy.db.prisma_client import PrismaWrapper
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
@ -137,83 +134,6 @@ def safe_deep_copy(data):
return new_data
def log_to_opentelemetry(func):
    """
    Decorator to log the duration of a DB related function to ServiceLogger()
    Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog

    NOTE(review): this is the pre-refactor version being removed in this change;
    it is superseded by ``log_db_metrics`` in litellm/proxy/db/log_db_metrics.py.
    Unlike the replacement, this version logs EVERY exception as a DB service
    failure, including non-DB errors such as ValueError.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time: datetime = datetime.now()
        try:
            result = await func(*args, **kwargs)
            end_time: datetime = datetime.now()
            # imported lazily to avoid a circular import at module load time
            from litellm.proxy.proxy_server import proxy_logging_obj
            if "PROXY" not in func.__name__:
                # plain DB helper function -> log as a DB service success
                await proxy_logging_obj.service_logging_obj.async_service_success_hook(
                    service=ServiceTypes.DB,
                    call_type=func.__name__,
                    parent_otel_span=kwargs.get("parent_otel_span", None),
                    duration=(end_time - start_time).total_seconds(),
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "function_name": func.__name__,
                        "function_kwargs": kwargs,
                        "function_args": args,
                    },
                )
            elif (
                # in litellm custom callbacks kwargs is passed as arg[0]
                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
                args is not None
                and len(args) > 0
                and isinstance(args[0], dict)
            ):
                passed_kwargs = args[0]
                parent_otel_span = _get_parent_otel_span_from_kwargs(
                    kwargs=passed_kwargs
                )
                if parent_otel_span is not None:
                    metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
                    await proxy_logging_obj.service_logging_obj.async_service_success_hook(
                        service=ServiceTypes.BATCH_WRITE_TO_DB,
                        call_type=func.__name__,
                        parent_otel_span=parent_otel_span,
                        duration=0.0,
                        start_time=start_time,
                        end_time=end_time,
                        event_metadata=metadata,
                    )
            # end of logging to otel
            return result
        except Exception as e:
            from litellm.proxy.proxy_server import proxy_logging_obj
            end_time: datetime = datetime.now()
            # logs unconditionally — does not check whether the error is DB-related
            await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
                error=e,
                service=ServiceTypes.DB,
                call_type=func.__name__,
                parent_otel_span=kwargs.get("parent_otel_span"),
                duration=(end_time - start_time).total_seconds(),
                start_time=start_time,
                end_time=end_time,
                event_metadata={
                    "function_name": func.__name__,
                    "function_kwargs": kwargs,
                    "function_args": args,
                },
            )
            raise e
    return wrapper
class InternalUsageCache:
def __init__(self, dual_cache: DualCache):
self.dual_cache: DualCache = dual_cache
@ -1397,7 +1317,7 @@ class PrismaClient:
return
@log_to_opentelemetry
@log_db_metrics
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
@ -1463,7 +1383,7 @@ class PrismaClient:
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
@log_to_opentelemetry
@log_db_metrics
async def get_data( # noqa: PLR0915
self,
token: Optional[Union[str, list]] = None,

View file

@ -17,23 +17,25 @@ import pytest
import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.proxy.utils import log_to_opentelemetry, ServiceTypes
from litellm.proxy.utils import log_db_metrics, ServiceTypes
from datetime import datetime
import httpx
from prisma.errors import ClientNotConnectedError
# Test async function to decorate
@log_to_opentelemetry
@log_db_metrics
async def sample_db_function(*args, **kwargs):
    # Minimal async stand-in for a DB call; exercises the decorator's success path.
    return "success"
@log_to_opentelemetry
@log_db_metrics
async def sample_proxy_function(*args, **kwargs):
    # NOTE(review): the decorator checks for the uppercase substring "PROXY" in
    # func.__name__; this lowercase name will NOT match that check — confirm intended.
    return "success"
@pytest.mark.asyncio
async def test_log_to_opentelemetry_success():
async def test_log_db_metrics_success():
# Mock the proxy_logging_obj
with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
# Setup mock
@ -61,14 +63,14 @@ async def test_log_to_opentelemetry_success():
@pytest.mark.asyncio
async def test_log_to_opentelemetry_duration():
async def test_log_db_metrics_duration():
# Mock the proxy_logging_obj
with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
# Setup mock
mock_proxy_logging.service_logging_obj.async_service_success_hook = AsyncMock()
# Add a delay to the function to test duration
@log_to_opentelemetry
@log_db_metrics
async def delayed_function(**kwargs):
await asyncio.sleep(1) # 1 second delay
return "success"
@ -95,23 +97,28 @@ async def test_log_to_opentelemetry_duration():
@pytest.mark.asyncio
async def test_log_to_opentelemetry_failure():
async def test_log_db_metrics_failure():
"""
should log a failure if a prisma error is raised
"""
# Mock the proxy_logging_obj
from prisma.errors import ClientNotConnectedError
with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
# Setup mock
mock_proxy_logging.service_logging_obj.async_service_failure_hook = AsyncMock()
# Create a failing function
@log_to_opentelemetry
@log_db_metrics
async def failing_function(**kwargs):
raise ValueError("Test error")
raise ClientNotConnectedError()
# Call the decorated function and expect it to raise
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ClientNotConnectedError) as exc_info:
await failing_function(parent_otel_span="test_span")
# Assertions
assert str(exc_info.value) == "Test error"
assert "Client is not connected to the query engine" in str(exc_info.value)
mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_called_once()
call_args = (
mock_proxy_logging.service_logging_obj.async_service_failure_hook.call_args[
@ -125,4 +132,56 @@ async def test_log_to_opentelemetry_failure():
assert isinstance(call_args["duration"], float)
assert isinstance(call_args["start_time"], datetime)
assert isinstance(call_args["end_time"], datetime)
assert isinstance(call_args["error"], ValueError)
assert isinstance(call_args["error"], ClientNotConnectedError)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "exception,should_log",
    [
        (ValueError("Generic error"), False),
        (KeyError("Missing key"), False),
        (TypeError("Type error"), False),
        (httpx.ConnectError("Failed to connect"), True),
        (httpx.TimeoutException("Request timed out"), True),
        (ClientNotConnectedError(), True),  # Prisma error
    ],
)
async def test_log_db_metrics_failure_error_types(exception, should_log):
    """
    Why Test?
    Users were seeing that non-DB errors were being logged as DB Service Failures
    Example a failure to read a value from cache was being logged as a DB Service Failure

    Parameterized test to verify:
    - DB-related errors (Prisma, httpx) are logged as service failures
    - Non-DB errors (ValueError, KeyError, etc.) are not logged
    """
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        mock_proxy_logging.service_logging_obj.async_service_failure_hook = AsyncMock()

        # Decorated function that always raises the parametrized exception
        @log_db_metrics
        async def failing_function(**kwargs):
            raise exception

        # Call the function and expect it to raise the exception
        with pytest.raises(type(exception)):
            await failing_function(parent_otel_span="test_span")

        if should_log:
            # Assert failure was logged for DB-related errors
            mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_called_once()
            call_args = mock_proxy_logging.service_logging_obj.async_service_failure_hook.call_args[
                1
            ]
            assert call_args["service"] == ServiceTypes.DB
            assert call_args["call_type"] == "failing_function"
            assert call_args["parent_otel_span"] == "test_span"
            assert isinstance(call_args["duration"], float)
            assert isinstance(call_args["start_time"], datetime)
            assert isinstance(call_args["end_time"], datetime)
            assert isinstance(call_args["error"], type(exception))
        else:
            # Assert failure was NOT logged for non-DB errors
            mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_not_called()