[Feat] Improve OTEL Tracking - Require all Redis Cache reads to be logged on OTEL (#5881)

* fix: use previous internal usage caching logic

* fix test_dual_cache_uses_redis

* redis track event_metadata in service logging

* show OTEL error in _get_parent_otel_span_from_kwargs

* track parent otel span on internal usage cache

* update_request_status

* fix internal usage cache

* fix linting

* fix test internal usage cache

* fix linting error

* show event metadata in redis set

* fix test_get_team_redis

* fix test_get_team_redis

* test_proxy_logging_setup
Ishaan Jaff, 2024-09-25 10:57:08 -07:00 (committed by GitHub)
commit 7cbcf538c6 (parent 4ec4d02474)
9 changed files with 243 additions and 79 deletions
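
At a glance: every read and write against the proxy's internal usage cache now has to state its parent OTEL span, and the Redis cache attaches the cache key to the service-log event. A minimal sketch of the pattern, using the names introduced in the diffs below (the helper call is real; the surrounding function is illustrative only):

from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs


async def log_cache_read(internal_usage_cache, key: str, kwargs: dict):
    # in the logging hooks the parent span is recovered from the request kwargs ...
    parent_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
    # ... and every cache read is required to carry it, so the Redis round-trip can be
    # emitted as a child span of the originating request
    return await internal_usage_cache.async_get_cache(
        key=key,
        litellm_parent_otel_span=parent_span,
    )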

View file

@@ -432,6 +432,7 @@ class RedisCache(BaseCache):
                    start_time=start_time,
                    end_time=end_time,
                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    event_metadata={"key": key},
                )
            )
        except Exception as e:
@@ -446,6 +447,7 @@ class RedisCache(BaseCache):
                    start_time=start_time,
                    end_time=end_time,
                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    event_metadata={"key": key},
                )
            )
            # NON blocking - notify users Redis is throwing an exception
@@ -753,6 +755,7 @@ class RedisCache(BaseCache):
                    start_time=start_time,
                    end_time=end_time,
                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    event_metadata={"key": key},
                )
            )
            return response
@@ -769,6 +772,7 @@ class RedisCache(BaseCache):
                    start_time=start_time,
                    end_time=end_time,
                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    event_metadata={"key": key},
                )
            )
            # NON blocking - notify users Redis is throwing an exception
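
For reference, the service-log event produced by these call sites now carries the Redis key. A self-contained sketch of that payload, not the upstream implementation; only the argument names visible in the hunks above come from the diff:

from datetime import datetime
from typing import Any, Dict, Optional


def build_redis_cache_log_event(
    key: str,
    start_time: datetime,
    end_time: datetime,
    parent_otel_span: Optional[Any] = None,
) -> Dict[str, Any]:
    # mirrors the arguments passed to service logging above; event_metadata is the new field
    return {
        "start_time": start_time,
        "end_time": end_time,
        "parent_otel_span": parent_otel_span,
        "event_metadata": {"key": key},
    }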

View file

@@ -1,10 +1,17 @@
# What is this?
## Helper utilities
import os
-from typing import List, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union

from litellm._logging import verbose_logger

+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+

def map_finish_reason(
    finish_reason: str,
@@ -68,10 +75,12 @@ def get_litellm_metadata_from_kwargs(kwargs: dict):
# Helper functions used for OTEL logging
-def _get_parent_otel_span_from_kwargs(kwargs: Optional[dict] = None):
+def _get_parent_otel_span_from_kwargs(
+    kwargs: Optional[dict] = None,
+) -> Union[Span, None]:
    try:
        if kwargs is None:
-            return None
+            raise ValueError("kwargs is None")
        litellm_params = kwargs.get("litellm_params")
        _metadata = kwargs.get("metadata") or {}
        if "litellm_parent_otel_span" in _metadata:
@@ -84,5 +93,9 @@ def _get_parent_otel_span_from_kwargs(kwargs: Optional[dict] = None):
            return litellm_params["metadata"]["litellm_parent_otel_span"]
        elif "litellm_parent_otel_span" in kwargs:
            return kwargs["litellm_parent_otel_span"]
-    except:
+        return None
+    except Exception as e:
+        verbose_logger.exception(
+            "Error in _get_parent_otel_span_from_kwargs: " + str(e)
+        )
        return None
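
Usage sketch for the updated helper: it looks for "litellm_parent_otel_span" in kwargs["metadata"], in kwargs["litellm_params"]["metadata"], or directly in kwargs, and now logs unexpected failures instead of silently swallowing them before returning None.

from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs

span = object()  # stand-in for an opentelemetry.trace.Span

assert _get_parent_otel_span_from_kwargs({"metadata": {"litellm_parent_otel_span": span}}) is span
assert _get_parent_otel_span_from_kwargs({"litellm_parent_otel_span": span}) is span
assert _get_parent_otel_span_from_kwargs({}) is None
assert _get_parent_otel_span_from_kwargs(None) is None  # raises internally, is logged, returns None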

View file

@@ -440,7 +440,7 @@ async def _cache_management_object(
        exclude_unset=True, exclude={"parent_otel_span": True}
    )
    await proxy_logging_obj.internal_usage_cache.async_set_cache(
-        key=key, value=_value
+        key=key, value=_value, litellm_parent_otel_span=None
    )
@@ -493,7 +493,9 @@ async def _delete_cache_key_object(
    ## UPDATE REDIS CACHE ##
    if proxy_logging_obj is not None:
-        await proxy_logging_obj.internal_usage_cache.async_delete_cache(key=key)
+        await proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache(
+            key=key
+        )


@log_to_opentelemetry
@@ -522,13 +524,11 @@ async def get_team_object(
    ## CHECK REDIS CACHE ##
    if (
        proxy_logging_obj is not None
-        and proxy_logging_obj.internal_usage_cache.redis_cache is not None
+        and proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache is not None
    ):
-        cached_team_obj = (
-            await proxy_logging_obj.internal_usage_cache.redis_cache.async_get_cache(
-                key=key
-            )
-        )
+        cached_team_obj = await proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache.async_get_cache(
+            key=key
+        )

    if cached_team_obj is None:
        cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
@@ -595,13 +595,11 @@ async def get_key_object(
    ## CHECK REDIS CACHE ##
    if (
        proxy_logging_obj is not None
-        and proxy_logging_obj.internal_usage_cache.redis_cache is not None
+        and proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache is not None
    ):
-        cached_team_obj = (
-            await proxy_logging_obj.internal_usage_cache.redis_cache.async_get_cache(
-                key=key
-            )
-        )
+        cached_team_obj = await proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache.async_get_cache(
+            key=key
+        )

    if cached_team_obj is None:
        cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
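
Condensed, the lookup order after this change is: the wrapped dual cache's Redis cache first, then the local user_api_key_cache. A sketch of that order (not the full get_team_object/get_key_object bodies):

async def lookup_cached_object(key: str, proxy_logging_obj, user_api_key_cache):
    cached_obj = None
    # 1) Redis, reached through the new InternalUsageCache wrapper's dual_cache
    if (
        proxy_logging_obj is not None
        and proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache is not None
    ):
        cached_obj = await proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache.async_get_cache(
            key=key
        )
    # 2) fall back to the in-memory user_api_key_cache
    if cached_obj is None:
        cached_obj = await user_api_key_cache.async_get_cache(key=key)
    return cached_obj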

View file

@@ -1,7 +1,7 @@
import sys
import traceback
from datetime import datetime, timedelta
-from typing import Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union

from fastapi import HTTPException
@@ -10,17 +10,28 @@ from litellm import ModelResponse
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import (
    get_key_model_rpm_limit,
    get_key_model_tpm_limit,
)

+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
+
+    Span = _Span
+    InternalUsageCache = _InternalUsageCache
+else:
+    Span = Any
+    InternalUsageCache = Any
+

class _PROXY_MaxParallelRequestsHandler(CustomLogger):
    # Class variables or attributes
-    def __init__(self, internal_usage_cache: DualCache):
+    def __init__(self, internal_usage_cache: InternalUsageCache):
        self.internal_usage_cache = internal_usage_cache

    def print_verbose(self, print_statement):
@@ -44,7 +55,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        rate_limit_type: Literal["user", "customer", "team"],
    ):
        current = await self.internal_usage_cache.async_get_cache(
-            key=request_count_api_key
+            key=request_count_api_key,
+            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
        )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
        if current is None:
            if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
@@ -58,7 +70,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                "current_rpm": 0,
            }
            await self.internal_usage_cache.async_set_cache(
-                request_count_api_key, new_val
+                key=request_count_api_key,
+                value=new_val,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )
        elif (
            int(current["current_requests"]) < max_parallel_requests
@@ -72,7 +86,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                "current_rpm": current["current_rpm"],
            }
            await self.internal_usage_cache.async_set_cache(
-                request_count_api_key, new_val
+                key=request_count_api_key,
+                value=new_val,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )
        else:
            raise HTTPException(
@@ -135,12 +151,14 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        # ------------
        # Setup values
        # ------------
+        new_val: Optional[dict] = None

        if global_max_parallel_requests is not None:
            # get value from cache
            _key = "global_max_parallel_requests"
            current_global_requests = await self.internal_usage_cache.async_get_cache(
-                key=_key, local_only=True
+                key=_key,
+                local_only=True,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )
            # check if below limit
            if current_global_requests is None:
@@ -153,7 +171,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            # if below -> increment
            else:
                await self.internal_usage_cache.async_increment_cache(
-                    key=_key, value=1, local_only=True
+                    key=_key,
+                    value=1,
+                    local_only=True,
+                    litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
                )

        current_date = datetime.now().strftime("%Y-%m-%d")
@@ -167,7 +188,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        # CHECK IF REQUEST ALLOWED for key
        current = await self.internal_usage_cache.async_get_cache(
-            key=request_count_api_key
+            key=request_count_api_key,
+            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
        )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
        self.print_verbose(f"current: {current}")
        if (
@@ -187,7 +209,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                "current_rpm": 0,
            }
            await self.internal_usage_cache.async_set_cache(
-                request_count_api_key, new_val
+                key=request_count_api_key,
+                value=new_val,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )
        elif (
            int(current["current_requests"]) < max_parallel_requests
@@ -201,7 +225,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                "current_rpm": current["current_rpm"],
            }
            await self.internal_usage_cache.async_set_cache(
-                request_count_api_key, new_val
+                key=request_count_api_key,
+                value=new_val,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )
        else:
            return self.raise_rate_limit_error(
@@ -219,7 +245,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        )
        current = await self.internal_usage_cache.async_get_cache(
-            key=request_count_api_key
+            key=request_count_api_key,
+            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
        )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}

        tpm_limit_for_model = None
@@ -242,7 +269,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                "current_rpm": 0,
            }
            await self.internal_usage_cache.async_set_cache(
-                request_count_api_key, new_val
+                key=request_count_api_key,
+                value=new_val,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )
        elif tpm_limit_for_model is not None or rpm_limit_for_model is not None:
            # Increase count for this token
@@ -267,12 +296,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                )
            else:
                await self.internal_usage_cache.async_set_cache(
-                    request_count_api_key, new_val
+                    key=request_count_api_key,
+                    value=new_val,
+                    litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
                )

        _remaining_tokens = None
        _remaining_requests = None
        # Add remaining tokens, requests to metadata
+        if new_val:
            if tpm_limit_for_model is not None:
                _remaining_tokens = tpm_limit_for_model - new_val["current_tpm"]
            if rpm_limit_for_model is not None:
@@ -291,7 +323,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        user_id = user_api_key_dict.user_id
        if user_id is not None:
            _user_id_rate_limits = await self.internal_usage_cache.async_get_cache(
-                key=user_id
+                key=user_id,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )
            # get user tpm/rpm limits
            if _user_id_rate_limits is not None and isinstance(
@@ -388,6 +421,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            get_model_group_from_litellm_kwargs,
        )

+        litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs(
+            kwargs=kwargs
+        )
        try:
            self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
            global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
@@ -416,7 +452,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                _key = "global_max_parallel_requests"
                # decrement
                await self.internal_usage_cache.async_increment_cache(
-                    key=_key, value=-1, local_only=True
+                    key=_key,
+                    value=-1,
+                    local_only=True,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
                )

            current_date = datetime.now().strftime("%Y-%m-%d")
@@ -427,7 +466,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            total_tokens = 0

            if isinstance(response_obj, ModelResponse):
-                total_tokens = response_obj.usage.total_tokens
+                total_tokens = response_obj.usage.total_tokens  # type: ignore

            # ------------
            # Update usage - API Key
@@ -439,7 +478,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                )
                current = await self.internal_usage_cache.async_get_cache(
-                    key=request_count_api_key
+                    key=request_count_api_key,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": total_tokens,
@@ -456,7 +496,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                await self.internal_usage_cache.async_set_cache(
-                    request_count_api_key, new_val, ttl=60
+                    request_count_api_key,
+                    new_val,
+                    ttl=60,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
                )  # store in cache for 1 min.

            # ------------
@@ -476,7 +519,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                )
                current = await self.internal_usage_cache.async_get_cache(
-                    key=request_count_api_key
+                    key=request_count_api_key,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": total_tokens,
@@ -493,7 +537,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                await self.internal_usage_cache.async_set_cache(
-                    request_count_api_key, new_val, ttl=60
+                    request_count_api_key,
+                    new_val,
+                    ttl=60,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
                )

            # ------------
@@ -503,14 +550,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            total_tokens = 0

            if isinstance(response_obj, ModelResponse):
-                total_tokens = response_obj.usage.total_tokens
+                total_tokens = response_obj.usage.total_tokens  # type: ignore

            request_count_api_key = (
                f"{user_api_key_user_id}::{precise_minute}::request_count"
            )
            current = await self.internal_usage_cache.async_get_cache(
-                key=request_count_api_key
+                key=request_count_api_key,
+                litellm_parent_otel_span=litellm_parent_otel_span,
            ) or {
                "current_requests": 1,
                "current_tpm": total_tokens,
@@ -527,7 +575,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
            )
            await self.internal_usage_cache.async_set_cache(
-                request_count_api_key, new_val, ttl=60
+                request_count_api_key,
+                new_val,
+                ttl=60,
+                litellm_parent_otel_span=litellm_parent_otel_span,
            )  # store in cache for 1 min.

            # ------------
@@ -537,14 +588,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            total_tokens = 0

            if isinstance(response_obj, ModelResponse):
-                total_tokens = response_obj.usage.total_tokens
+                total_tokens = response_obj.usage.total_tokens  # type: ignore

            request_count_api_key = (
                f"{user_api_key_team_id}::{precise_minute}::request_count"
            )
            current = await self.internal_usage_cache.async_get_cache(
-                key=request_count_api_key
+                key=request_count_api_key,
+                litellm_parent_otel_span=litellm_parent_otel_span,
            ) or {
                "current_requests": 1,
                "current_tpm": total_tokens,
@@ -561,7 +613,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
            )
            await self.internal_usage_cache.async_set_cache(
-                request_count_api_key, new_val, ttl=60
+                request_count_api_key,
+                new_val,
+                ttl=60,
+                litellm_parent_otel_span=litellm_parent_otel_span,
            )  # store in cache for 1 min.

            # ------------
@@ -571,14 +626,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            total_tokens = 0

            if isinstance(response_obj, ModelResponse):
-                total_tokens = response_obj.usage.total_tokens
+                total_tokens = response_obj.usage.total_tokens  # type: ignore

            request_count_api_key = (
                f"{user_api_key_end_user_id}::{precise_minute}::request_count"
            )
            current = await self.internal_usage_cache.async_get_cache(
-                key=request_count_api_key
+                key=request_count_api_key,
+                litellm_parent_otel_span=litellm_parent_otel_span,
            ) or {
                "current_requests": 1,
                "current_tpm": total_tokens,
@@ -595,7 +651,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
            )
            await self.internal_usage_cache.async_set_cache(
-                request_count_api_key, new_val, ttl=60
+                request_count_api_key,
+                new_val,
+                ttl=60,
+                litellm_parent_otel_span=litellm_parent_otel_span,
            )  # store in cache for 1 min.

        except Exception as e:
@@ -604,6 +663,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.print_verbose("Inside Max Parallel Request Failure Hook")
+            litellm_parent_otel_span: Union[Span, None] = (
+                _get_parent_otel_span_from_kwargs(kwargs=kwargs)
+            )
            _metadata = kwargs["litellm_params"].get("metadata", {}) or {}
            global_max_parallel_requests = _metadata.get(
                "global_max_parallel_requests", None
@@ -626,12 +688,17 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                _key = "global_max_parallel_requests"
                current_global_requests = (
                    await self.internal_usage_cache.async_get_cache(
-                        key=_key, local_only=True
+                        key=_key,
+                        local_only=True,
+                        litellm_parent_otel_span=litellm_parent_otel_span,
                    )
                )
                # decrement
                await self.internal_usage_cache.async_increment_cache(
-                    key=_key, value=-1, local_only=True
+                    key=_key,
+                    value=-1,
+                    local_only=True,
+                    litellm_parent_otel_span=litellm_parent_otel_span,
                )

            current_date = datetime.now().strftime("%Y-%m-%d")
@@ -647,7 +714,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            # Update usage
            # ------------
            current = await self.internal_usage_cache.async_get_cache(
-                key=request_count_api_key
+                key=request_count_api_key,
+                litellm_parent_otel_span=litellm_parent_otel_span,
            ) or {
                "current_requests": 1,
                "current_tpm": 0,
@@ -662,7 +730,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            self.print_verbose(f"updated_value in failure call: {new_val}")
            await self.internal_usage_cache.async_set_cache(
-                request_count_api_key, new_val, ttl=60
+                request_count_api_key,
+                new_val,
+                ttl=60,
+                litellm_parent_otel_span=litellm_parent_otel_span,
            )  # save in cache for up to 1 min.
        except Exception as e:
            verbose_proxy_logger.exception(
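
The hunks above all apply the same pattern: resolve the parent span once per hook (from user_api_key_dict in the pre-call hook, from kwargs in the success/failure hooks), then pass it into every internal_usage_cache call. A condensed sketch of the success-path update; the key format and ttl come from the diff, while the function itself is illustrative:

from typing import Any, Optional

from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs


async def record_minute_usage(
    internal_usage_cache: Any,
    kwargs: dict,
    request_count_api_key: str,  # e.g. f"{user_api_key}::{precise_minute}::request_count"
    new_val: dict,
) -> None:
    parent_span: Optional[Any] = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
    await internal_usage_cache.async_set_cache(
        request_count_api_key,
        new_val,
        ttl=60,  # per-minute counters, kept for ~1 minute
        litellm_parent_otel_span=parent_span,
    )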

View file

@@ -21,7 +21,7 @@ model_list:
litellm_settings:
  cache: true
-  # callbacks: ["otel"]
+  callbacks: ["otel"]

general_settings:

View file

@@ -205,6 +205,83 @@ def log_to_opentelemetry(func):
    return wrapper


+class InternalUsageCache:
+    def __init__(self, dual_cache: DualCache):
+        self.dual_cache: DualCache = dual_cache
+
+    async def async_get_cache(
+        self,
+        key,
+        litellm_parent_otel_span: Union[Span, None],
+        local_only: bool = False,
+        **kwargs,
+    ) -> Any:
+        return await self.dual_cache.async_get_cache(
+            key=key,
+            local_only=local_only,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+            **kwargs,
+        )
+
+    async def async_set_cache(
+        self,
+        key,
+        value,
+        litellm_parent_otel_span: Union[Span, None],
+        local_only: bool = False,
+        **kwargs,
+    ) -> None:
+        return await self.dual_cache.async_set_cache(
+            key=key,
+            value=value,
+            local_only=local_only,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+            **kwargs,
+        )
+
+    async def async_increment_cache(
+        self,
+        key,
+        value: float,
+        litellm_parent_otel_span: Union[Span, None],
+        local_only: bool = False,
+        **kwargs,
+    ):
+        return await self.dual_cache.async_increment_cache(
+            key=key,
+            value=value,
+            local_only=local_only,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+            **kwargs,
+        )
+
+    def set_cache(
+        self,
+        key,
+        value,
+        local_only: bool = False,
+        **kwargs,
+    ) -> None:
+        return self.dual_cache.set_cache(
+            key=key,
+            value=value,
+            local_only=local_only,
+            **kwargs,
+        )
+
+    def get_cache(
+        self,
+        key,
+        local_only: bool = False,
+        **kwargs,
+    ) -> Any:
+        return self.dual_cache.get_cache(
+            key=key,
+            local_only=local_only,
+            **kwargs,
+        )
+
+
### LOGGING ###
class ProxyLogging:
    """
@@ -222,9 +299,9 @@ class ProxyLogging:
        ## INITIALIZE LITELLM CALLBACKS ##
        self.call_details: dict = {}
        self.call_details["user_api_key_cache"] = user_api_key_cache
-        self.internal_usage_cache = DualCache(
-            default_in_memory_ttl=1
-        )  # ping redis cache every 1s
+        self.internal_usage_cache: InternalUsageCache = InternalUsageCache(
+            dual_cache=DualCache(default_in_memory_ttl=1)  # ping redis cache every 1s
+        )
        self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler(
            self.internal_usage_cache
        )
@@ -238,7 +315,7 @@ class ProxyLogging:
            alerting_threshold=self.alerting_threshold,
            alerting=self.alerting,
            alert_types=self.alert_types,
-            internal_usage_cache=self.internal_usage_cache,
+            internal_usage_cache=self.internal_usage_cache.dual_cache,
        )

    def update_values(
@@ -283,7 +360,7 @@ class ProxyLogging:
            litellm.callbacks.append(self.slack_alerting_instance)  # type: ignore

        if redis_cache is not None:
-            self.internal_usage_cache.redis_cache = redis_cache
+            self.internal_usage_cache.dual_cache.redis_cache = redis_cache

    def _init_litellm_callbacks(self, llm_router: Optional[litellm.Router] = None):
        self.service_logging_obj = ServiceLogging()
@@ -298,7 +375,7 @@ class ProxyLogging:
            if isinstance(callback, str):
                callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
                    callback,
-                    internal_usage_cache=self.internal_usage_cache,
+                    internal_usage_cache=self.internal_usage_cache.dual_cache,
                    llm_router=llm_router,
                )
                if callback not in litellm.input_callback:
@@ -347,6 +424,7 @@ class ProxyLogging:
            value=status,
            local_only=True,
            ttl=alerting_threshold,
+            litellm_parent_otel_span=None,
        )

    async def process_pre_call_hook_response(self, response, data, call_type):
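
Usage sketch: the proxy's internal usage cache is now an InternalUsageCache wrapping a DualCache, and the async read/write/increment paths require the caller to state the parent span explicitly (None is accepted). The cache key below is illustrative:

from litellm.caching import DualCache
from litellm.proxy.utils import InternalUsageCache

internal_usage_cache = InternalUsageCache(dual_cache=DualCache(default_in_memory_ttl=1))


async def warm_up() -> None:
    await internal_usage_cache.async_set_cache(
        key="daily_spend", value=0, litellm_parent_otel_span=None  # illustrative key
    )
    current = await internal_usage_cache.async_get_cache(
        key="daily_spend", litellm_parent_otel_span=None
    )
    print(current)
    # Redis is only consulted once a redis_cache is attached to the wrapped cache,
    # e.g. internal_usage_cache.dual_cache.redis_cache = RedisCache()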

View file

@@ -2045,7 +2045,7 @@ async def test_proxy_logging_setup():
    from litellm.proxy.utils import ProxyLogging

    pl_obj = ProxyLogging(user_api_key_cache=DualCache())
-    assert pl_obj.internal_usage_cache.always_read_redis is True
+    assert pl_obj.internal_usage_cache.dual_cache.always_read_redis is True


@pytest.mark.skip(reason="local test. Requires sentinel setup.")

View file

@@ -28,7 +28,7 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.parallel_request_limiter import (
    _PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler,
)
-from litellm.proxy.utils import ProxyLogging, hash_token
+from litellm.proxy.utils import InternalUsageCache, ProxyLogging, hash_token

## On Request received
## On Request success
@@ -48,7 +48,7 @@ async def test_global_max_parallel_requests():
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=100)
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler(
-        internal_usage_cache=local_cache
+        internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
    )

    for _ in range(3):
@@ -78,7 +78,7 @@ async def test_pre_call_hook():
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler(
-        internal_usage_cache=local_cache
+        internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
    )

    await parallel_request_handler.async_pre_call_hook(
@@ -115,7 +115,7 @@ async def test_pre_call_hook_rpm_limits():
    )
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler(
-        internal_usage_cache=local_cache
+        internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
    )

    await parallel_request_handler.async_pre_call_hook(
@@ -157,7 +157,7 @@ async def test_pre_call_hook_rpm_limits_retry_after():
    )
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler(
-        internal_usage_cache=local_cache
+        internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
    )

    await parallel_request_handler.async_pre_call_hook(
@@ -208,7 +208,7 @@ async def test_pre_call_hook_team_rpm_limits():
    )
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler(
-        internal_usage_cache=local_cache
+        internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
    )

    await parallel_request_handler.async_pre_call_hook(
@@ -256,7 +256,7 @@ async def test_pre_call_hook_tpm_limits():
    )
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler(
-        internal_usage_cache=local_cache
+        internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
    )

    await parallel_request_handler.async_pre_call_hook(
@@ -308,7 +308,7 @@ async def test_pre_call_hook_user_tpm_limits():
    print("dict user", res)

    parallel_request_handler = MaxParallelRequestsHandler(
-        internal_usage_cache=local_cache
+        internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
    )

    await parallel_request_handler.async_pre_call_hook(
@@ -353,7 +353,7 @@ async def test_success_call_hook():
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler(
-        internal_usage_cache=local_cache
+        internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
    )

    await parallel_request_handler.async_pre_call_hook(
@@ -397,7 +397,7 @@ async def test_failure_call_hook():
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler(
-        internal_usage_cache=local_cache
+        internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
    )

    await parallel_request_handler.async_pre_call_hook(
@@ -975,7 +975,7 @@ async def test_bad_router_tpm_limit_per_model():
    print(
        "internal usage cache: ",
-        parallel_request_handler.internal_usage_cache.in_memory_cache.cache_dict,
+        parallel_request_handler.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,
    )

    assert (
@@ -1161,7 +1161,7 @@ async def test_pre_call_hook_tpm_limits_per_model():
    print(
        "internal usage cache: ",
-        parallel_request_handler.internal_usage_cache.in_memory_cache.cache_dict,
+        parallel_request_handler.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,
    )

    assert (
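
The repeated test change above, shown once in full: the handler is constructed with an InternalUsageCache wrapping the test's DualCache, and the wrapped cache stays reachable for assertions.

from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import (
    _PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler,
)
from litellm.proxy.utils import InternalUsageCache

local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler(
    internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
)
assert parallel_request_handler.internal_usage_cache.dual_cache is local_cache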

View file

@@ -753,10 +753,10 @@ async def test_team_update_redis():
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

-    proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()
+    proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache = RedisCache()

    with patch.object(
-        proxy_logging_obj.internal_usage_cache.redis_cache,
+        proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache,
        "async_set_cache",
        new=AsyncMock(),
    ) as mock_client:
@@ -782,10 +782,10 @@ async def test_get_team_redis(client_no_auth):
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

-    proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()
+    proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache = RedisCache()

    with patch.object(
-        proxy_logging_obj.internal_usage_cache.redis_cache,
+        proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache,
        "async_get_cache",
        new=AsyncMock(),
    ) as mock_client: