refactor: instrument 'dynamic_rate_limiting' callback on proxy
parent 6a7982fa40
commit 8f95381276

8 changed files with 136 additions and 28 deletions
@@ -152,3 +152,40 @@ litellm_remaining_team_budget_metric{team_alias="QA Prod Bot",team_id="de35b29e-
 ```
 
 
+### Dynamic TPM Allocation
+
+Prevent teams from gobbling too much quota.
+
+1. Setup config.yaml
+
+```yaml
+model_list:
+  - model_name: my-fake-model
+    litellm_params:
+      model: gpt-3.5-turbo
+      api_key: my-fake-key
+      mock_response: hello-world
+      tpm: 60
+
+general_settings:
+  callbacks: ["dynamic_rate_limiting"]
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```python
+"""
+- Run 2 concurrent teams calling same model
+- model has 60 TPM
+- Mock response returns 30 total tokens / request
+- Each team will only be able to make 1 request per minute
+"""
+
+
+```

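A minimal sketch of the test described by that docstring, assuming the proxy from step 2 is running on `http://0.0.0.0:4000` and that `sk-team-1` / `sk-team-2` are hypothetical virtual keys belonging to two different teams:

```python
import asyncio

import openai


async def two_requests(team_key: str) -> list:
    """Fire two back-to-back requests for one team and record each outcome."""
    client = openai.AsyncOpenAI(api_key=team_key, base_url="http://0.0.0.0:4000")
    outcomes = []
    for _ in range(2):
        try:
            await client.chat.completions.create(
                model="my-fake-model",
                messages=[{"role": "user", "content": "hey, how's it going?"}],
            )
            outcomes.append("ok")
        except openai.RateLimitError:
            outcomes.append("rate_limited")
    return outcomes


async def main():
    # Hypothetical virtual keys for two different teams (e.g. created via /key/generate).
    team_keys = ["sk-team-1", "sk-team-2"]
    results = await asyncio.gather(*[two_requests(k) for k in team_keys])
    # 60 TPM split across 2 active teams, ~30 tokens per mocked response:
    # each team should land roughly one successful call per minute.
    print(results)  # e.g. [['ok', 'rate_limited'], ['ok', 'rate_limited']]


asyncio.run(main())
```
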
@@ -37,7 +37,9 @@ input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 service_callback: List[Union[str, Callable]] = []
-_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter", "logfire"]
+_custom_logger_compatible_callbacks_literal = Literal[
+    "lago", "openmeter", "logfire", "dynamic_rate_limiter"
+]
 callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
 _langfuse_default_tags: Optional[
     List[

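With the literal extended, the new string is accepted anywhere `callbacks` takes a string-compatible callback name. A minimal sketch of SDK-side usage, mirroring the `litellm_settings` change further down:

```python
import litellm

# "dynamic_rate_limiter" now type-checks alongside "lago", "openmeter" and "logfire".
litellm.callbacks = ["dynamic_rate_limiter"]
```
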
@@ -19,7 +19,7 @@ from litellm import (
     turn_off_message_logging,
     verbose_logger,
 )
-from litellm.caching import S3Cache
+from litellm.caching import DualCache, S3Cache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,

@@ -1840,7 +1840,11 @@ def set_callbacks(callback_list, function_id=None):
 
 def _init_custom_logger_compatible_class(
     logging_integration: litellm._custom_logger_compatible_callbacks_literal,
-) -> Callable:
+    internal_usage_cache: Optional[DualCache],
+    llm_router: Optional[
+        Any
+    ],  # expect litellm.Router, but typing errors due to circular import
+) -> CustomLogger:
     if logging_integration == "lago":
         for callback in _in_memory_loggers:
             if isinstance(callback, LagoLogger):

@@ -1876,3 +1880,58 @@ def _init_custom_logger_compatible_class(
         _otel_logger = OpenTelemetry(config=otel_config)
         _in_memory_loggers.append(_otel_logger)
         return _otel_logger  # type: ignore
+    elif logging_integration == "dynamic_rate_limiter":
+        from litellm.proxy.hooks.dynamic_rate_limiter import (
+            _PROXY_DynamicRateLimitHandler,
+        )
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, _PROXY_DynamicRateLimitHandler):
+                return callback  # type: ignore
+
+        if internal_usage_cache is None:
+            raise Exception(
+                "Internal Error: Cache cannot be empty - internal_usage_cache={}".format(
+                    internal_usage_cache
+                )
+            )
+
+        dynamic_rate_limiter_obj = _PROXY_DynamicRateLimitHandler(
+            internal_usage_cache=internal_usage_cache
+        )
+
+        if llm_router is not None and isinstance(llm_router, litellm.Router):
+            dynamic_rate_limiter_obj.update_variables(llm_router=llm_router)
+        _in_memory_loggers.append(dynamic_rate_limiter_obj)
+        return dynamic_rate_limiter_obj  # type: ignore
+
+
+def get_custom_logger_compatible_class(
+    logging_integration: litellm._custom_logger_compatible_callbacks_literal,
+) -> Optional[CustomLogger]:
+    if logging_integration == "lago":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, LagoLogger):
+                return callback
+    elif logging_integration == "openmeter":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, OpenMeterLogger):
+                return callback
+    elif logging_integration == "logfire":
+        if "LOGFIRE_TOKEN" not in os.environ:
+            raise ValueError("LOGFIRE_TOKEN not found in environment variables")
+        from litellm.integrations.opentelemetry import OpenTelemetry
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, OpenTelemetry):
+                return callback  # type: ignore
+
+    elif logging_integration == "dynamic_rate_limiter":
+        from litellm.proxy.hooks.dynamic_rate_limiter import (
+            _PROXY_DynamicRateLimitHandler,
+        )
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, _PROXY_DynamicRateLimitHandler):
+                return callback  # type: ignore
+    return None

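Combined with the signature change above, the factory now needs the shared usage cache and, optionally, the router. A rough sketch of how a caller resolves the string into a handler instance, assuming the imports resolve as in this diff:

```python
from litellm.caching import DualCache
from litellm.litellm_core_utils.litellm_logging import (
    _init_custom_logger_compatible_class,
    get_custom_logger_compatible_class,
)

# First call constructs the handler and caches it in _in_memory_loggers ...
handler = _init_custom_logger_compatible_class(
    "dynamic_rate_limiter",
    internal_usage_cache=DualCache(),
    llm_router=None,  # the proxy attaches its Router; update_variables() can set it later
)

# ... later lookups return that same instance instead of building a new one.
assert get_custom_logger_compatible_class("dynamic_rate_limiter") is handler
```
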
@@ -67,8 +67,7 @@ model_list:
       max_input_tokens: 80920
 
 litellm_settings:
-  success_callback: ["langfuse"]
-  failure_callback: ["langfuse"]
+  callbacks: ["dynamic_rate_limiter"]
   # default_team_settings:
   #   - team_id: proj1
   #     success_callback: ["langfuse"]

@@ -131,7 +131,6 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         - Check if tpm available
         - Raise RateLimitError if no tpm available
         """
-
         if "model" in data:
             available_tpm, model_tpm, active_projects = await self.check_available_tpm(
                 model=data["model"]

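For intuition, here is the back-of-the-envelope math behind the docs example above (60 TPM, two active teams, ~30 tokens per mocked response); this is only an illustration of the expected behaviour, not the handler's actual implementation:

```python
# Illustration only: fair-share TPM as described in the Dynamic TPM Allocation docs.
model_tpm = 60            # tpm configured on the deployment
active_projects = 2       # teams currently using the model
tokens_per_request = 30   # tokens returned by the mock response

per_team_tpm = model_tpm // active_projects              # 30 tokens/minute per team
requests_per_team = per_team_tpm // tokens_per_request   # 1 request/minute per team
print(per_team_tpm, requests_per_team)
```
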
@@ -2644,7 +2644,9 @@ async def startup_event():
         redis_cache=redis_usage_cache
     )  # used by parallel request limiter for rate limiting keys across instances
 
-    proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
+    proxy_logging_obj._init_litellm_callbacks(
+        llm_router=llm_router
+    )  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
 
     if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types:
         asyncio.create_task(

@@ -3116,11 +3118,10 @@ async def chat_completion(
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
         verbose_proxy_logger.error(
-            "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}".format(
-                get_error_message_str(e=e)
+            "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}\n{}".format(
+                get_error_message_str(e=e), traceback.format_exc()
             )
         )
-        verbose_proxy_logger.debug(traceback.format_exc())
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -229,31 +229,32 @@ class ProxyLogging:
         if redis_cache is not None:
             self.internal_usage_cache.redis_cache = redis_cache
 
-    def _init_litellm_callbacks(self):
-        print_verbose("INITIALIZING LITELLM CALLBACKS!")
+    def _init_litellm_callbacks(self, llm_router: Optional[litellm.Router] = None):
         self.service_logging_obj = ServiceLogging()
-        litellm.callbacks.append(self.max_parallel_request_limiter)
-        litellm.callbacks.append(self.max_budget_limiter)
-        litellm.callbacks.append(self.cache_control_check)
-        litellm.callbacks.append(self.service_logging_obj)
+        litellm.callbacks.append(self.max_parallel_request_limiter)  # type: ignore
+        litellm.callbacks.append(self.max_budget_limiter)  # type: ignore
+        litellm.callbacks.append(self.cache_control_check)  # type: ignore
+        litellm.callbacks.append(self.service_logging_obj)  # type: ignore
        litellm.success_callback.append(
             self.slack_alerting_instance.response_taking_too_long_callback
         )
         for callback in litellm.callbacks:
             if isinstance(callback, str):
-                callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(
-                    callback
+                callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
+                    callback,
+                    internal_usage_cache=self.internal_usage_cache,
+                    llm_router=llm_router,
                 )
             if callback not in litellm.input_callback:
-                litellm.input_callback.append(callback)
+                litellm.input_callback.append(callback)  # type: ignore
             if callback not in litellm.success_callback:
-                litellm.success_callback.append(callback)
+                litellm.success_callback.append(callback)  # type: ignore
             if callback not in litellm.failure_callback:
-                litellm.failure_callback.append(callback)
+                litellm.failure_callback.append(callback)  # type: ignore
             if callback not in litellm._async_success_callback:
-                litellm._async_success_callback.append(callback)
+                litellm._async_success_callback.append(callback)  # type: ignore
             if callback not in litellm._async_failure_callback:
-                litellm._async_failure_callback.append(callback)
+                litellm._async_failure_callback.append(callback)  # type: ignore
 
         if (
             len(litellm.input_callback) > 0

@@ -301,10 +302,19 @@ class ProxyLogging:
 
         try:
             for callback in litellm.callbacks:
-                if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
-                    callback.__class__
+                _callback: Optional[CustomLogger] = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        callback
+                    )
+                else:
+                    _callback = callback  # type: ignore
+                if (
+                    _callback is not None
+                    and isinstance(_callback, CustomLogger)
+                    and "async_pre_call_hook" in vars(_callback.__class__)
                 ):
-                    response = await callback.async_pre_call_hook(
+                    response = await _callback.async_pre_call_hook(
                         user_api_key_dict=user_api_key_dict,
                         cache=self.call_details["user_api_key_cache"],
                         data=data,

@@ -340,14 +340,15 @@ def function_setup(
     )
     try:
         global callback_list, add_breadcrumb, user_logger_fn, Logging
 
         function_id = kwargs["id"] if "id" in kwargs else None
 
         if len(litellm.callbacks) > 0:
             for callback in litellm.callbacks:
+                # check if callback is a string - e.g. "lago", "openmeter"
                 if isinstance(callback, str):
-                    callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(
-                        callback
+                    callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
+                        callback, internal_usage_cache=None, llm_router=None
                     )
                 if any(
                     isinstance(cb, type(callback))