forked from phoenix/litellm-mirror
(feat) log error class, function_name on prometheus service failure hook + only log DB related failures on DB service hook (#6650)
* log error on prometheus service failure hook
* use a more accurate function name for wrapper that handles logging db metrics
* fix log_db_metrics
* test_log_db_metrics_failure_error_types
* fix linting
* fix auth checks
parent: ae385cfcdc
commit: eb47117800
8 changed files with 249 additions and 110 deletions
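For context, a minimal sketch (not part of this diff) of the metric shape this commit produces, using prometheus_client directly. The label layout mirrors labelnames=[service] + (additional_labels or []) in the PrometheusServicesLogger hunks below; the service string "postgres" and the sample label values are illustrative assumptions, not taken from the diff.

    from prometheus_client import Counter

    # Failure counter for the DB service, now carrying two extra labels.
    failed_requests = Counter(
        "litellm_postgres_failed_requests",
        "Total failed_requests for postgres service",
        labelnames=["postgres", "error_class", "function_name"],
    )

    # On a DB failure, one series is incremented per (error class, DB helper name):
    failed_requests.labels("postgres", "ClientNotConnectedError", "get_user_object").inc(1)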
@@ -241,7 +241,8 @@ class ServiceLogging(CustomLogger):
        if callback == "prometheus_system":
            await self.init_prometheus_services_logger_if_none()
            await self.prometheusServicesLogger.async_service_failure_hook(
-                payload=payload
+                payload=payload,
+                error=error,
            )
        elif callback == "datadog":
            await self.init_datadog_logger_if_none()
@@ -9,6 +9,7 @@ import subprocess
 import sys
 import traceback
 import uuid
+from typing import List, Optional, Union

 import dotenv
 import requests  # type: ignore
@@ -51,7 +52,9 @@ class PrometheusServicesLogger:
         for service in self.services:
             histogram = self.create_histogram(service, type_of_request="latency")
             counter_failed_request = self.create_counter(
-                service, type_of_request="failed_requests"
+                service,
+                type_of_request="failed_requests",
+                additional_labels=["error_class", "function_name"],
             )
             counter_total_requests = self.create_counter(
                 service, type_of_request="total_requests"
@@ -99,7 +102,12 @@ class PrometheusServicesLogger:
             buckets=LATENCY_BUCKETS,
         )

-    def create_counter(self, service: str, type_of_request: str):
+    def create_counter(
+        self,
+        service: str,
+        type_of_request: str,
+        additional_labels: Optional[List[str]] = None,
+    ):
         metric_name = "litellm_{}_{}".format(service, type_of_request)
         is_registered = self.is_metric_registered(metric_name)
         if is_registered:
@@ -107,7 +115,7 @@ class PrometheusServicesLogger:
         return self.Counter(
             metric_name,
             "Total {} for {} service".format(type_of_request, service),
-            labelnames=[service],
+            labelnames=[service] + (additional_labels or []),
         )

     def observe_histogram(
@@ -125,9 +133,13 @@ class PrometheusServicesLogger:
         counter,
         labels: str,
         amount: float,
+        additional_labels: Optional[List[str]] = [],
     ):
         assert isinstance(counter, self.Counter)

-        counter.labels(labels).inc(amount)
+        if additional_labels:
+            counter.labels(labels, *additional_labels).inc(amount)
+        else:
+            counter.labels(labels).inc(amount)

     def service_success_hook(self, payload: ServiceLoggerPayload):
@@ -187,16 +199,25 @@ class PrometheusServicesLogger:
                     amount=1,  # LOG TOTAL REQUESTS TO PROMETHEUS
                 )

-    async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
+    async def async_service_failure_hook(
+        self,
+        payload: ServiceLoggerPayload,
+        error: Union[str, Exception],
+    ):
         if self.mock_testing:
             self.mock_testing_failure_calls += 1
+        error_class = error.__class__.__name__
+        function_name = payload.call_type

         if payload.service.value in self.payload_to_prometheus_map:
             prom_objects = self.payload_to_prometheus_map[payload.service.value]
             for obj in prom_objects:
+                # increment both failed and total requests
                 if isinstance(obj, self.Counter):
                     self.increment_counter(
                         counter=obj,
                         labels=payload.service.value,
+                        # log additional_labels=["error_class", "function_name"], used for debugging what's going wrong with the DB
+                        additional_labels=[error_class, function_name],
                         amount=1,  # LOG ERROR COUNT TO PROMETHEUS
                     )
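To make the new labels concrete: an illustrative snippet (not from the diff) showing how the two extra label values passed to increment_counter above are derived; the example exception and call_type value are assumptions.

    import httpx

    error = httpx.ConnectError("connection refused")  # example DB-related failure
    error_class = error.__class__.__name__            # -> "ConnectError"
    function_name = "get_user_object"                 # payload.call_type, the name of the decorated DB helper

    # async_service_failure_hook then calls:
    # increment_counter(..., additional_labels=[error_class, function_name], amount=1)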
@@ -32,7 +32,7 @@ from litellm.proxy._types import (
     UserAPIKeyAuth,
 )
 from litellm.proxy.auth.route_checks import RouteChecks
-from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
+from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes

 from .auth_checks_organization import organization_role_based_access_check
@@ -290,7 +290,7 @@ def get_actual_routes(allowed_routes: list) -> list:
     return actual_routes


-@log_to_opentelemetry
+@log_db_metrics
 async def get_end_user_object(
     end_user_id: Optional[str],
     prisma_client: Optional[PrismaClient],
@@ -415,7 +415,7 @@ def _update_last_db_access_time(
     last_db_access_time[key] = (value, time.time())


-@log_to_opentelemetry
+@log_db_metrics
 async def get_user_object(
     user_id: str,
     prisma_client: Optional[PrismaClient],
@@ -562,7 +562,7 @@ async def _delete_cache_key_object(
     )


-@log_to_opentelemetry
+@log_db_metrics
 async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
     return await prisma_client.db.litellm_teamtable.find_unique(
         where={"team_id": team_id}
@@ -658,7 +658,7 @@ async def get_team_object(
     )


-@log_to_opentelemetry
+@log_db_metrics
 async def get_key_object(
     hashed_token: str,
     prisma_client: Optional[PrismaClient],
@@ -766,7 +766,7 @@ async def _handle_failed_db_connection_for_get_key_object(
     raise e


-@log_to_opentelemetry
+@log_db_metrics
 async def get_org_object(
     org_id: str,
     prisma_client: Optional[PrismaClient],
@@ -58,7 +58,7 @@ from litellm.proxy.auth.auth_checks import (
     get_org_object,
     get_team_object,
     get_user_object,
-    log_to_opentelemetry,
+    log_db_metrics,
 )
 from litellm.proxy.auth.auth_utils import (
     _get_request_ip_address,
litellm/proxy/db/log_db_metrics.py (new file, 138 additions)

@@ -0,0 +1,138 @@
+"""
+Handles logging DB success/failure to ServiceLogger()
+
+ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc
+"""
+
+from datetime import datetime
+from functools import wraps
+from typing import Callable, Dict, Tuple
+
+from litellm._service_logger import ServiceTypes
+from litellm.litellm_core_utils.core_helpers import (
+    _get_parent_otel_span_from_kwargs,
+    get_litellm_metadata_from_kwargs,
+)
+
+
+def log_db_metrics(func):
+    """
+    Decorator to log the duration of a DB related function to ServiceLogger()
+
+    Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog
+
+    When logging Failure it checks if the Exception is a PrismaError, httpx.ConnectError or httpx.TimeoutException and then logs that as a DB Service Failure
+
+    Args:
+        func: The function to be decorated
+
+    Returns:
+        Result from the decorated function
+
+    Raises:
+        Exception: If the decorated function raises an exception
+    """
+
+    @wraps(func)
+    async def wrapper(*args, **kwargs):
+        from prisma.errors import PrismaError
+
+        start_time: datetime = datetime.now()
+
+        try:
+            result = await func(*args, **kwargs)
+            end_time: datetime = datetime.now()
+            from litellm.proxy.proxy_server import proxy_logging_obj
+
+            if "PROXY" not in func.__name__:
+                await proxy_logging_obj.service_logging_obj.async_service_success_hook(
+                    service=ServiceTypes.DB,
+                    call_type=func.__name__,
+                    parent_otel_span=kwargs.get("parent_otel_span", None),
+                    duration=(end_time - start_time).total_seconds(),
+                    start_time=start_time,
+                    end_time=end_time,
+                    event_metadata={
+                        "function_name": func.__name__,
+                        "function_kwargs": kwargs,
+                        "function_args": args,
+                    },
+                )
+            elif (
+                # in litellm custom callbacks kwargs is passed as arg[0]
+                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
+                args is not None
+                and len(args) > 0
+                and isinstance(args[0], dict)
+            ):
+                passed_kwargs = args[0]
+                parent_otel_span = _get_parent_otel_span_from_kwargs(
+                    kwargs=passed_kwargs
+                )
+                if parent_otel_span is not None:
+                    metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
+                    await proxy_logging_obj.service_logging_obj.async_service_success_hook(
+                        service=ServiceTypes.BATCH_WRITE_TO_DB,
+                        call_type=func.__name__,
+                        parent_otel_span=parent_otel_span,
+                        duration=0.0,
+                        start_time=start_time,
+                        end_time=end_time,
+                        event_metadata=metadata,
+                    )
+            # end of logging to otel
+            return result
+        except Exception as e:
+            end_time: datetime = datetime.now()
+            await _handle_logging_db_exception(
+                e=e,
+                func=func,
+                kwargs=kwargs,
+                args=args,
+                start_time=start_time,
+                end_time=end_time,
+            )
+            raise e
+
+    return wrapper
+
+
+def _is_exception_related_to_db(e: Exception) -> bool:
+    """
+    Returns True if the exception is related to the DB
+    """
+    import httpx
+    from prisma.errors import PrismaError
+
+    return isinstance(e, (PrismaError, httpx.ConnectError, httpx.TimeoutException))
+
+
+async def _handle_logging_db_exception(
+    e: Exception,
+    func: Callable,
+    kwargs: Dict,
+    args: Tuple,
+    start_time: datetime,
+    end_time: datetime,
+) -> None:
+    from litellm.proxy.proxy_server import proxy_logging_obj
+
+    # don't log this as a DB Service Failure, if the DB did not raise an exception
+    if _is_exception_related_to_db(e) is not True:
+        return
+
+    await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
+        error=e,
+        service=ServiceTypes.DB,
+        call_type=func.__name__,
+        parent_otel_span=kwargs.get("parent_otel_span"),
+        duration=(end_time - start_time).total_seconds(),
+        start_time=start_time,
+        end_time=end_time,
+        event_metadata={
+            "function_name": func.__name__,
+            "function_kwargs": kwargs,
+            "function_args": args,
+        },
+    )
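A hypothetical usage sketch of the new decorator; the helper name, arguments, and query below are illustrative (modeled on _get_team_db_check from the auth_checks hunks above), not part of this commit.

    from litellm.proxy.db.log_db_metrics import log_db_metrics


    @log_db_metrics
    async def get_team_row(team_id: str, prisma_client, parent_otel_span=None):
        # success -> async_service_success_hook(service=ServiceTypes.DB, call_type="get_team_row", ...)
        # PrismaError / httpx.ConnectError / httpx.TimeoutException -> async_service_failure_hook(...)
        # any other exception is re-raised without being logged as a DB service failure
        return await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id}
        )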
@@ -125,7 +125,7 @@ from litellm.proxy._types import *
 from litellm.proxy.analytics_endpoints.analytics_endpoints import (
     router as analytics_router,
 )
-from litellm.proxy.auth.auth_checks import log_to_opentelemetry
+from litellm.proxy.auth.auth_checks import log_db_metrics
 from litellm.proxy.auth.auth_utils import check_response_size_is_safe
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.auth.litellm_license import LicenseCheck
@@ -747,7 +747,7 @@ async def _PROXY_failure_handler(
     pass


-@log_to_opentelemetry
+@log_db_metrics
 async def _PROXY_track_cost_callback(
     kwargs,  # kwargs to completion
     completion_response: litellm.ModelResponse,  # response from completion
@@ -55,10 +55,6 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
 from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
-from litellm.litellm_core_utils.core_helpers import (
-    _get_parent_otel_span_from_kwargs,
-    get_litellm_metadata_from_kwargs,
-)
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy._types import (
@@ -77,6 +73,7 @@ from litellm.proxy.db.create_views import (
     create_missing_views,
     should_create_missing_views,
 )
+from litellm.proxy.db.log_db_metrics import log_db_metrics
 from litellm.proxy.db.prisma_client import PrismaWrapper
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
@@ -137,83 +134,6 @@ def safe_deep_copy(data):
     return new_data


-def log_to_opentelemetry(func):
-    """
-    Decorator to log the duration of a DB related function to ServiceLogger()
-
-    Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog
-    """
-
-    @wraps(func)
-    async def wrapper(*args, **kwargs):
-        start_time: datetime = datetime.now()
-
-        try:
-            result = await func(*args, **kwargs)
-            end_time: datetime = datetime.now()
-            from litellm.proxy.proxy_server import proxy_logging_obj
-
-            if "PROXY" not in func.__name__:
-                await proxy_logging_obj.service_logging_obj.async_service_success_hook(
-                    service=ServiceTypes.DB,
-                    call_type=func.__name__,
-                    parent_otel_span=kwargs.get("parent_otel_span", None),
-                    duration=(end_time - start_time).total_seconds(),
-                    start_time=start_time,
-                    end_time=end_time,
-                    event_metadata={
-                        "function_name": func.__name__,
-                        "function_kwargs": kwargs,
-                        "function_args": args,
-                    },
-                )
-            elif (
-                # in litellm custom callbacks kwargs is passed as arg[0]
-                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
-                args is not None
-                and len(args) > 0
-                and isinstance(args[0], dict)
-            ):
-                passed_kwargs = args[0]
-                parent_otel_span = _get_parent_otel_span_from_kwargs(
-                    kwargs=passed_kwargs
-                )
-                if parent_otel_span is not None:
-                    metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
-                    await proxy_logging_obj.service_logging_obj.async_service_success_hook(
-                        service=ServiceTypes.BATCH_WRITE_TO_DB,
-                        call_type=func.__name__,
-                        parent_otel_span=parent_otel_span,
-                        duration=0.0,
-                        start_time=start_time,
-                        end_time=end_time,
-                        event_metadata=metadata,
-                    )
-            # end of logging to otel
-            return result
-        except Exception as e:
-            from litellm.proxy.proxy_server import proxy_logging_obj
-
-            end_time: datetime = datetime.now()
-            await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
-                error=e,
-                service=ServiceTypes.DB,
-                call_type=func.__name__,
-                parent_otel_span=kwargs.get("parent_otel_span"),
-                duration=(end_time - start_time).total_seconds(),
-                start_time=start_time,
-                end_time=end_time,
-                event_metadata={
-                    "function_name": func.__name__,
-                    "function_kwargs": kwargs,
-                    "function_args": args,
-                },
-            )
-            raise e
-
-    return wrapper
-
-
 class InternalUsageCache:
     def __init__(self, dual_cache: DualCache):
         self.dual_cache: DualCache = dual_cache
@@ -1397,7 +1317,7 @@ class PrismaClient:

         return

-    @log_to_opentelemetry
+    @log_db_metrics
     @backoff.on_exception(
         backoff.expo,
         Exception,  # base exception to catch for the backoff
@@ -1463,7 +1383,7 @@
         max_time=10,  # maximum total time to retry for
         on_backoff=on_backoff,  # specifying the function to call on backoff
     )
-    @log_to_opentelemetry
+    @log_db_metrics
     async def get_data(  # noqa: PLR0915
         self,
         token: Optional[Union[str, list]] = None,
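A side note on imports, inferred from the hunk above (proxy/utils.py now imports the decorator from the new module): existing litellm.proxy.utils imports keep working, which is why the tests below still import it from there. A small illustrative check, assuming both modules import cleanly:

    from litellm.proxy.db.log_db_metrics import log_db_metrics as direct  # new canonical location
    from litellm.proxy.utils import log_db_metrics as reexported          # re-export used by the tests below

    assert direct is reexported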
@@ -17,23 +17,25 @@ import pytest
 import litellm
 from litellm import completion
 from litellm._logging import verbose_logger
-from litellm.proxy.utils import log_to_opentelemetry, ServiceTypes
+from litellm.proxy.utils import log_db_metrics, ServiceTypes
 from datetime import datetime
+import httpx
+from prisma.errors import ClientNotConnectedError


 # Test async function to decorate
-@log_to_opentelemetry
+@log_db_metrics
 async def sample_db_function(*args, **kwargs):
     return "success"


-@log_to_opentelemetry
+@log_db_metrics
 async def sample_proxy_function(*args, **kwargs):
     return "success"


 @pytest.mark.asyncio
-async def test_log_to_opentelemetry_success():
+async def test_log_db_metrics_success():
     # Mock the proxy_logging_obj
     with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
         # Setup mock
@@ -61,14 +63,14 @@ async def test_log_to_opentelemetry_success():


 @pytest.mark.asyncio
-async def test_log_to_opentelemetry_duration():
+async def test_log_db_metrics_duration():
     # Mock the proxy_logging_obj
     with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
         # Setup mock
         mock_proxy_logging.service_logging_obj.async_service_success_hook = AsyncMock()

         # Add a delay to the function to test duration
-        @log_to_opentelemetry
+        @log_db_metrics
         async def delayed_function(**kwargs):
             await asyncio.sleep(1)  # 1 second delay
             return "success"
@@ -95,23 +97,28 @@ async def test_log_to_opentelemetry_duration():


 @pytest.mark.asyncio
-async def test_log_to_opentelemetry_failure():
+async def test_log_db_metrics_failure():
+    """
+    should log a failure if a prisma error is raised
+    """
     # Mock the proxy_logging_obj
+    from prisma.errors import ClientNotConnectedError
+
     with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
         # Setup mock
         mock_proxy_logging.service_logging_obj.async_service_failure_hook = AsyncMock()

         # Create a failing function
-        @log_to_opentelemetry
+        @log_db_metrics
         async def failing_function(**kwargs):
-            raise ValueError("Test error")
+            raise ClientNotConnectedError()

         # Call the decorated function and expect it to raise
-        with pytest.raises(ValueError) as exc_info:
+        with pytest.raises(ClientNotConnectedError) as exc_info:
             await failing_function(parent_otel_span="test_span")

         # Assertions
-        assert str(exc_info.value) == "Test error"
+        assert "Client is not connected to the query engine" in str(exc_info.value)
         mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_called_once()
         call_args = (
             mock_proxy_logging.service_logging_obj.async_service_failure_hook.call_args[
@@ -125,4 +132,56 @@ async def test_log_to_opentelemetry_failure():
         assert isinstance(call_args["duration"], float)
         assert isinstance(call_args["start_time"], datetime)
         assert isinstance(call_args["end_time"], datetime)
-        assert isinstance(call_args["error"], ValueError)
+        assert isinstance(call_args["error"], ClientNotConnectedError)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "exception,should_log",
+    [
+        (ValueError("Generic error"), False),
+        (KeyError("Missing key"), False),
+        (TypeError("Type error"), False),
+        (httpx.ConnectError("Failed to connect"), True),
+        (httpx.TimeoutException("Request timed out"), True),
+        (ClientNotConnectedError(), True),  # Prisma error
+    ],
+)
+async def test_log_db_metrics_failure_error_types(exception, should_log):
+    """
+    Why Test?
+    Users were seeing that non-DB errors were being logged as DB Service Failures
+    Example a failure to read a value from cache was being logged as a DB Service Failure
+
+    Parameterized test to verify:
+    - DB-related errors (Prisma, httpx) are logged as service failures
+    - Non-DB errors (ValueError, KeyError, etc.) are not logged
+    """
+    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
+        mock_proxy_logging.service_logging_obj.async_service_failure_hook = AsyncMock()
+
+        @log_db_metrics
+        async def failing_function(**kwargs):
+            raise exception
+
+        # Call the function and expect it to raise the exception
+        with pytest.raises(type(exception)):
+            await failing_function(parent_otel_span="test_span")
+
+        if should_log:
+            # Assert failure was logged for DB-related errors
+            mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_called_once()
+            call_args = mock_proxy_logging.service_logging_obj.async_service_failure_hook.call_args[
+                1
+            ]
+            assert call_args["service"] == ServiceTypes.DB
+            assert call_args["call_type"] == "failing_function"
+            assert call_args["parent_otel_span"] == "test_span"
+            assert isinstance(call_args["duration"], float)
+            assert isinstance(call_args["start_time"], datetime)
+            assert isinstance(call_args["end_time"], datetime)
+            assert isinstance(call_args["error"], type(exception))
+        else:
+            # Assert failure was NOT logged for non-DB errors
+            mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_not_called()