From 8f95381276b5a89a6719c4635afe1cd9eb10c191 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 22 Jun 2024 00:32:29 -0700
Subject: [PATCH] refactor: instrument 'dynamic_rate_limiting' callback on proxy

---
 docs/my-website/docs/proxy/team_budgets.md    | 51 ++++++++++++++++
 litellm/__init__.py                           |  4 +-
 litellm/litellm_core_utils/litellm_logging.py | 63 ++++++++++++++++++-
 litellm/proxy/_super_secret_config.yaml       |  3 +-
 litellm/proxy/hooks/dynamic_rate_limiter.py   |  1 -
 litellm/proxy/proxy_server.py                 |  9 +--
 litellm/proxy/utils.py                        | 42 ++++++++-----
 litellm/utils.py                              |  5 +-
 8 files changed, 150 insertions(+), 28 deletions(-)

diff --git a/docs/my-website/docs/proxy/team_budgets.md b/docs/my-website/docs/proxy/team_budgets.md
index 9bfcb35d4c..d7f6d7fdc9 100644
--- a/docs/my-website/docs/proxy/team_budgets.md
+++ b/docs/my-website/docs/proxy/team_budgets.md
@@ -152,3 +152,54 @@ litellm_remaining_team_budget_metric{team_alias="QA Prod Bot",team_id="de35b29e-
 ```
 
 
+### Dynamic TPM Allocation
+
+Prevent a single team from consuming more than its fair share of a model's TPM quota.
+
+1. Set up config.yaml
+
+```yaml
+model_list:
+  - model_name: my-fake-model
+    litellm_params:
+      model: gpt-3.5-turbo
+      api_key: my-fake-key
+      mock_response: hello-world
+      tpm: 60
+
+litellm_settings:
+  callbacks: ["dynamic_rate_limiter"]
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+A minimal sketch (assumes one key per team, generated via `/key/generate`; the keys below are placeholders):
+
+```python
+"""
+- Run 2 concurrent teams calling same model
+- model has 60 TPM
+- Mock response returns 30 total tokens / request
+- Each team will only be able to make 1 request per minute
+"""
+import asyncio
+import openai
+
+async def call_model(key: str):  # key: placeholder team key from `/key/generate`
+    client = openai.AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
+    return await client.chat.completions.create(
+        model="my-fake-model", messages=[{"role": "user", "content": "Hey!"}]
+    )
+
+async def main():
+    # each team's 1st call succeeds; a 2nd call in the same minute is rate limited
+    print(await asyncio.gather(call_model("sk-team-1"), call_model("sk-team-2")))
+
+asyncio.run(main())
+```
\ No newline at end of file
diff --git a/litellm/__init__.py b/litellm/__init__.py
index a191d46bfd..09bc0ab5ff 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -37,7 +37,9 @@ input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 service_callback: List[Union[str, Callable]] = []
-_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter", "logfire"]
+_custom_logger_compatible_callbacks_literal = Literal[
+    "lago", "openmeter", "logfire", "dynamic_rate_limiter"
+]
 callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
 _langfuse_default_tags: Optional[
     List[
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 64e8dbfc9c..ad92140ee3 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -19,7 +19,7 @@ from litellm import (
     turn_off_message_logging,
     verbose_logger,
 )
-from litellm.caching import S3Cache
+from litellm.caching import DualCache, S3Cache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,
@@ -1840,7 +1840,11 @@ def set_callbacks(callback_list, function_id=None):
 
 def _init_custom_logger_compatible_class(
     logging_integration: litellm._custom_logger_compatible_callbacks_literal,
-) -> Callable:
+    internal_usage_cache: Optional[DualCache],
+    llm_router: Optional[
+        Any
+    ],  # expect litellm.Router, but typing errors due to circular import
+) -> CustomLogger:
     if logging_integration == "lago":
         for callback in _in_memory_loggers:
             if isinstance(callback, LagoLogger):
@@ -1876,3 +1880,58 @@ def _init_custom_logger_compatible_class(
         _otel_logger = OpenTelemetry(config=otel_config)
         _in_memory_loggers.append(_otel_logger)
         return _otel_logger  # type: ignore
+    elif logging_integration == "dynamic_rate_limiter":
+        from litellm.proxy.hooks.dynamic_rate_limiter import (
+            _PROXY_DynamicRateLimitHandler,
+        )
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, _PROXY_DynamicRateLimitHandler):
+                return callback  # type: ignore
+
+        if internal_usage_cache is None:
+            raise Exception(
+                "Internal Error: Cache cannot be empty - internal_usage_cache={}".format(
+                    internal_usage_cache
+                )
+            )
+
+        dynamic_rate_limiter_obj = _PROXY_DynamicRateLimitHandler(
+            internal_usage_cache=internal_usage_cache
+        )
+
+        if llm_router is not None and isinstance(llm_router, litellm.Router):
+            dynamic_rate_limiter_obj.update_variables(llm_router=llm_router)
+        _in_memory_loggers.append(dynamic_rate_limiter_obj)
+        return dynamic_rate_limiter_obj  # type: ignore
+
+
+def get_custom_logger_compatible_class(
+    logging_integration: litellm._custom_logger_compatible_callbacks_literal,
+) -> Optional[CustomLogger]:
+    if logging_integration == "lago":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, LagoLogger):
+                return callback
+    elif logging_integration == "openmeter":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, OpenMeterLogger):
+                return callback
+    elif logging_integration == "logfire":
+        if "LOGFIRE_TOKEN" not in os.environ:
+            raise ValueError("LOGFIRE_TOKEN not found in environment variables")
+        from litellm.integrations.opentelemetry import OpenTelemetry
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, OpenTelemetry):
+                return callback  # type: ignore
+
+    elif logging_integration == "dynamic_rate_limiter":
+        from litellm.proxy.hooks.dynamic_rate_limiter import (
+            _PROXY_DynamicRateLimitHandler,
+        )
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, _PROXY_DynamicRateLimitHandler):
+                return callback  # type: ignore
+    return None
diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index ac0aaca5c7..ef88f75f61 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -67,8 +67,7 @@ model_list:
       max_input_tokens: 80920
 
 litellm_settings:
-  success_callback: ["langfuse"]
-  failure_callback: ["langfuse"]
+  callbacks: ["dynamic_rate_limiter"]
 # default_team_settings:
 #   - team_id: proj1
 #     success_callback: ["langfuse"]
diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py
index dfe9632155..c0511e30c8 100644
--- a/litellm/proxy/hooks/dynamic_rate_limiter.py
+++ b/litellm/proxy/hooks/dynamic_rate_limiter.py
@@ -131,7 +131,6 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         - Check if tpm available
         - Raise RateLimitError if no tpm available
         """
-
         if "model" in data:
             available_tpm, model_tpm, active_projects = await self.check_available_tpm(
                 model=data["model"]
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 8eac72629a..b2eaea8f36 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2644,7 +2644,9 @@ async def startup_event():
             redis_cache=redis_usage_cache
         )  # used by parallel request limiter for rate limiting keys across instances
 
-    proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
+    proxy_logging_obj._init_litellm_callbacks(
+        llm_router=llm_router
+    )  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
 
"daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types: asyncio.create_task( @@ -3116,11 +3118,10 @@ async def chat_completion( except Exception as e: data["litellm_status"] = "fail" # used for alerting verbose_proxy_logger.error( - "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}".format( - get_error_message_str(e=e) + "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}\n{}".format( + get_error_message_str(e=e), traceback.format_exc() ) ) - verbose_proxy_logger.debug(traceback.format_exc()) await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index d1e1d3576f..5acba95c2e 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -229,31 +229,32 @@ class ProxyLogging: if redis_cache is not None: self.internal_usage_cache.redis_cache = redis_cache - def _init_litellm_callbacks(self): - print_verbose("INITIALIZING LITELLM CALLBACKS!") + def _init_litellm_callbacks(self, llm_router: Optional[litellm.Router] = None): self.service_logging_obj = ServiceLogging() - litellm.callbacks.append(self.max_parallel_request_limiter) - litellm.callbacks.append(self.max_budget_limiter) - litellm.callbacks.append(self.cache_control_check) - litellm.callbacks.append(self.service_logging_obj) + litellm.callbacks.append(self.max_parallel_request_limiter) # type: ignore + litellm.callbacks.append(self.max_budget_limiter) # type: ignore + litellm.callbacks.append(self.cache_control_check) # type: ignore + litellm.callbacks.append(self.service_logging_obj) # type: ignore litellm.success_callback.append( self.slack_alerting_instance.response_taking_too_long_callback ) for callback in litellm.callbacks: if isinstance(callback, str): - callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( - callback + callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore + callback, + internal_usage_cache=self.internal_usage_cache, + llm_router=llm_router, ) if callback not in litellm.input_callback: - litellm.input_callback.append(callback) + litellm.input_callback.append(callback) # type: ignore if callback not in litellm.success_callback: - litellm.success_callback.append(callback) + litellm.success_callback.append(callback) # type: ignore if callback not in litellm.failure_callback: - litellm.failure_callback.append(callback) + litellm.failure_callback.append(callback) # type: ignore if callback not in litellm._async_success_callback: - litellm._async_success_callback.append(callback) + litellm._async_success_callback.append(callback) # type: ignore if callback not in litellm._async_failure_callback: - litellm._async_failure_callback.append(callback) + litellm._async_failure_callback.append(callback) # type: ignore if ( len(litellm.input_callback) > 0 @@ -301,10 +302,19 @@ class ProxyLogging: try: for callback in litellm.callbacks: - if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars( - callback.__class__ + _callback: Optional[CustomLogger] = None + if isinstance(callback, str): + _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( + callback + ) + else: + _callback = callback # type: ignore + if ( + _callback is not None + and isinstance(_callback, CustomLogger) + and "async_pre_call_hook" in vars(_callback.__class__) ): - response = await callback.async_pre_call_hook( + response = 
                         user_api_key_dict=user_api_key_dict,
                         cache=self.call_details["user_api_key_cache"],
                         data=data,
diff --git a/litellm/utils.py b/litellm/utils.py
index 7a64200760..b734ae465a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -340,14 +340,15 @@ def function_setup(
     )
     try:
         global callback_list, add_breadcrumb, user_logger_fn, Logging
+        function_id = kwargs["id"] if "id" in kwargs else None
         if len(litellm.callbacks) > 0:
             for callback in litellm.callbacks:
                 # check if callback is a string - e.g. "lago", "openmeter"
                 if isinstance(callback, str):
-                    callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(
-                        callback
+                    callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
+                        callback, internal_usage_cache=None, llm_router=None
                     )
                 if any(
                     isinstance(cb, type(callback))
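
A quick sanity check of the fair-share math behind `check_available_tpm`, using the numbers from the docs example above (a sketch; the variable names mirror the hook's return tuple but are illustrative, not its actual internals):

```python
# 60 TPM model quota, 2 active teams, mock responses of 30 tokens each
model_tpm = 60
active_projects = 2
usage_this_minute = 30  # tokens one team has already used this minute

quota_per_project = model_tpm // active_projects  # 30 TPM per team
available_tpm = quota_per_project - usage_this_minute
print(available_tpm)  # 0 -> that team's 2nd request this minute is rate limited
```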