refactor: instrument 'dynamic_rate_limiting' callback on proxy

Krrish Dholakia 2024-06-22 00:32:29 -07:00
parent 6a7982fa40
commit 8f95381276
8 changed files with 136 additions and 28 deletions

View file

@@ -152,3 +152,40 @@ litellm_remaining_team_budget_metric{team_alias="QA Prod Bot",team_id="de35b29e-
```
### Dynamic TPM Allocation

Prevent any one team from gobbling too much of a model's quota.

1. Set up config.yaml

```yaml
model_list:
  - model_name: my-fake-model
    litellm_params:
      model: gpt-3.5-turbo
      api_key: my-fake-key
      mock_response: hello-world
      tpm: 60

litellm_settings:
  callbacks: ["dynamic_rate_limiter"]
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```python
"""
- Run 2 concurrent teams calling the same model
- The model has 60 TPM
- The mock response returns 30 total tokens per request
- Each team will only be able to make 1 request per minute
"""
```
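
Since the handler splits a model's TPM evenly across active projects, 60 TPM shared by 2 teams leaves 30 TPM per team, and a 30-token mock response means each team's second request inside the same minute should be rejected. A minimal sketch of such a test, assuming the proxy runs on its default port and that `sk-team-1` / `sk-team-2` are placeholder virtual keys created for two different teams:

```python
import asyncio

import openai

PROXY_BASE_URL = "http://0.0.0.0:4000"  # default litellm proxy address
TEAM_KEYS = ["sk-team-1", "sk-team-2"]  # placeholder virtual keys, one per team


async def call_model(api_key: str) -> str:
    client = openai.AsyncOpenAI(api_key=api_key, base_url=PROXY_BASE_URL)
    try:
        response = await client.chat.completions.create(
            model="my-fake-model",
            messages=[{"role": "user", "content": "hey!"}],
        )
        return response.choices[0].message.content or ""
    except openai.RateLimitError as e:
        return f"rate limited: {e}"


async def main() -> None:
    # 2 requests per team: the second should exceed the team's 30 TPM slice
    results = await asyncio.gather(
        *[call_model(key) for key in TEAM_KEYS for _ in range(2)]
    )
    for result in results:
        print(result)


asyncio.run(main())
```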

View file

@@ -37,7 +37,9 @@ input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 service_callback: List[Union[str, Callable]] = []
-_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter", "logfire"]
+_custom_logger_compatible_callbacks_literal = Literal[
+    "lago", "openmeter", "logfire", "dynamic_rate_limiter"
+]
 callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
 _langfuse_default_tags: Optional[
     List[
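
With the literal extended, the new handler can be registered by name just like the other custom-logger-compatible callbacks. A quick illustration (not part of this commit):

```python
import litellm

# the string now satisfies _custom_logger_compatible_callbacks_literal,
# so it is accepted alongside "lago", "openmeter", and "logfire"
litellm.callbacks = ["dynamic_rate_limiter"]
```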

View file

@@ -19,7 +19,7 @@ from litellm import (
     turn_off_message_logging,
     verbose_logger,
 )
-from litellm.caching import S3Cache
+from litellm.caching import DualCache, S3Cache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,
@@ -1840,7 +1840,11 @@ def set_callbacks(callback_list, function_id=None):

 def _init_custom_logger_compatible_class(
     logging_integration: litellm._custom_logger_compatible_callbacks_literal,
-) -> Callable:
+    internal_usage_cache: Optional[DualCache],
+    llm_router: Optional[
+        Any
+    ],  # expect litellm.Router, but typing errors due to circular import
+) -> CustomLogger:
     if logging_integration == "lago":
         for callback in _in_memory_loggers:
             if isinstance(callback, LagoLogger):
@@ -1876,3 +1880,58 @@
         _otel_logger = OpenTelemetry(config=otel_config)
         _in_memory_loggers.append(_otel_logger)
         return _otel_logger  # type: ignore
+    elif logging_integration == "dynamic_rate_limiter":
+        from litellm.proxy.hooks.dynamic_rate_limiter import (
+            _PROXY_DynamicRateLimitHandler,
+        )
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, _PROXY_DynamicRateLimitHandler):
+                return callback  # type: ignore
+
+        if internal_usage_cache is None:
+            raise Exception(
+                "Internal Error: Cache cannot be empty - internal_usage_cache={}".format(
+                    internal_usage_cache
+                )
+            )
+
+        dynamic_rate_limiter_obj = _PROXY_DynamicRateLimitHandler(
+            internal_usage_cache=internal_usage_cache
+        )
+
+        if llm_router is not None and isinstance(llm_router, litellm.Router):
+            dynamic_rate_limiter_obj.update_variables(llm_router=llm_router)
+        _in_memory_loggers.append(dynamic_rate_limiter_obj)
+        return dynamic_rate_limiter_obj  # type: ignore
+
+
+def get_custom_logger_compatible_class(
+    logging_integration: litellm._custom_logger_compatible_callbacks_literal,
+) -> Optional[CustomLogger]:
+    if logging_integration == "lago":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, LagoLogger):
+                return callback
+    elif logging_integration == "openmeter":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, OpenMeterLogger):
+                return callback
+    elif logging_integration == "logfire":
+        if "LOGFIRE_TOKEN" not in os.environ:
+            raise ValueError("LOGFIRE_TOKEN not found in environment variables")
+
+        from litellm.integrations.opentelemetry import OpenTelemetry
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, OpenTelemetry):
+                return callback  # type: ignore
+    elif logging_integration == "dynamic_rate_limiter":
+        from litellm.proxy.hooks.dynamic_rate_limiter import (
+            _PROXY_DynamicRateLimitHandler,
+        )
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, _PROXY_DynamicRateLimitHandler):
+                return callback  # type: ignore
+    return None
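
For orientation, a sketch of how the reworked factory could be driven directly; the argument values here are illustrative, not from this commit:

```python
from litellm.caching import DualCache
from litellm.litellm_core_utils.litellm_logging import (
    _init_custom_logger_compatible_class,
)

# the first call constructs the handler and stores it in _in_memory_loggers;
# repeat calls for the same integration return the cached instance
handler = _init_custom_logger_compatible_class(
    "dynamic_rate_limiter",
    internal_usage_cache=DualCache(),
    llm_router=None,  # a router can be attached later via update_variables()
)
```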

View file

@@ -67,8 +67,7 @@ model_list:
       max_input_tokens: 80920

 litellm_settings:
-  success_callback: ["langfuse"]
-  failure_callback: ["langfuse"]
+  callbacks: ["dynamic_rate_limiter"]
   # default_team_settings:
   #   - team_id: proj1
   #     success_callback: ["langfuse"]

View file

@ -131,7 +131,6 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
- Check if tpm available
- Raise RateLimitError if no tpm available
"""
if "model" in data:
available_tpm, model_tpm, active_projects = await self.check_available_tpm(
model=data["model"]
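
The unpacked tuple lines up with the even split described in the docs above. Roughly, as a sketch (the helper below is illustrative, not the hook's real internals):

```python
# rough model of the division behind check_available_tpm:
# a model's TPM budget is shared evenly among currently-active projects
def available_tpm_per_project(model_tpm: int, active_projects: int) -> int:
    return model_tpm // max(active_projects, 1)


assert available_tpm_per_project(60, 2) == 30  # docs example: 1 request/minute each
```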

View file

@@ -2644,7 +2644,9 @@ async def startup_event():
         redis_cache=redis_usage_cache
     )  # used by parallel request limiter for rate limiting keys across instances

-    proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
+    proxy_logging_obj._init_litellm_callbacks(
+        llm_router=llm_router
+    )  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made

     if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types:
         asyncio.create_task(
@@ -3116,11 +3118,10 @@ async def chat_completion(
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
         verbose_proxy_logger.error(
-            "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}".format(
-                get_error_message_str(e=e)
+            "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}\n{}".format(
+                get_error_message_str(e=e), traceback.format_exc()
             )
         )
-        verbose_proxy_logger.debug(traceback.format_exc())
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

View file

@@ -229,31 +229,32 @@ class ProxyLogging:
         if redis_cache is not None:
             self.internal_usage_cache.redis_cache = redis_cache

-    def _init_litellm_callbacks(self):
-        print_verbose("INITIALIZING LITELLM CALLBACKS!")
+    def _init_litellm_callbacks(self, llm_router: Optional[litellm.Router] = None):
         self.service_logging_obj = ServiceLogging()
-        litellm.callbacks.append(self.max_parallel_request_limiter)
-        litellm.callbacks.append(self.max_budget_limiter)
-        litellm.callbacks.append(self.cache_control_check)
-        litellm.callbacks.append(self.service_logging_obj)
+        litellm.callbacks.append(self.max_parallel_request_limiter)  # type: ignore
+        litellm.callbacks.append(self.max_budget_limiter)  # type: ignore
+        litellm.callbacks.append(self.cache_control_check)  # type: ignore
+        litellm.callbacks.append(self.service_logging_obj)  # type: ignore
         litellm.success_callback.append(
             self.slack_alerting_instance.response_taking_too_long_callback
         )
         for callback in litellm.callbacks:
             if isinstance(callback, str):
-                callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(
-                    callback
+                callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
+                    callback,
+                    internal_usage_cache=self.internal_usage_cache,
+                    llm_router=llm_router,
                 )
             if callback not in litellm.input_callback:
-                litellm.input_callback.append(callback)
+                litellm.input_callback.append(callback)  # type: ignore
             if callback not in litellm.success_callback:
-                litellm.success_callback.append(callback)
+                litellm.success_callback.append(callback)  # type: ignore
             if callback not in litellm.failure_callback:
-                litellm.failure_callback.append(callback)
+                litellm.failure_callback.append(callback)  # type: ignore
             if callback not in litellm._async_success_callback:
-                litellm._async_success_callback.append(callback)
+                litellm._async_success_callback.append(callback)  # type: ignore
             if callback not in litellm._async_failure_callback:
-                litellm._async_failure_callback.append(callback)
+                litellm._async_failure_callback.append(callback)  # type: ignore

         if (
             len(litellm.input_callback) > 0
@@ -301,10 +302,19 @@ class ProxyLogging:

         try:
             for callback in litellm.callbacks:
-                if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
-                    callback.__class__
+                _callback: Optional[CustomLogger] = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        callback
+                    )
+                else:
+                    _callback = callback  # type: ignore
+                if (
+                    _callback is not None
+                    and isinstance(_callback, CustomLogger)
+                    and "async_pre_call_hook" in vars(_callback.__class__)
                 ):
-                    response = await callback.async_pre_call_hook(
+                    response = await _callback.async_pre_call_hook(
                         user_api_key_dict=user_api_key_dict,
                         cache=self.call_details["user_api_key_cache"],
                         data=data,
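
This `_callback` resolution means string-registered integrations (like `"dynamic_rate_limiter"`) now reach pre-call hooks as real `CustomLogger` instances. For orientation, a minimal sketch of the kind of subclass this dispatch targets (a hypothetical class; litellm types `call_type` as a Literal of call types, simplified here to `str`):

```python
from typing import Optional, Union

from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class MyPreCallHook(CustomLogger):  # hypothetical example subclass
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # simplified; litellm uses a Literal of call types
    ) -> Optional[Union[Exception, str, dict]]:
        # inspect or mutate the request; raise (e.g. RateLimitError) to block it
        return data
```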

View file

@@ -340,14 +340,15 @@ def function_setup(
     )
     try:
         global callback_list, add_breadcrumb, user_logger_fn, Logging

        function_id = kwargs["id"] if "id" in kwargs else None
        if len(litellm.callbacks) > 0:
            for callback in litellm.callbacks:
                # check if callback is a string - e.g. "lago", "openmeter"
                if isinstance(callback, str):
-                    callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(
-                        callback
+                    callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
+                        callback, internal_usage_cache=None, llm_router=None
                     )
                    if any(
                        isinstance(cb, type(callback))