mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
feat(dynamic_rate_limiter.py): passing base case

parent a028600932
commit 068e8dff5b

5 changed files with 310 additions and 12 deletions
litellm/router.py

```diff
@@ -11,6 +11,7 @@ import asyncio
 import concurrent
 import copy
 import datetime as datetime_og
+import enum
 import hashlib
 import inspect
 import json
```
```diff
@@ -90,6 +91,10 @@ from litellm.utils import (
 )
 
 
+class RoutingArgs(enum.Enum):
+    ttl = 60  # 1min (RPM/TPM expire key)
+
+
 class Router:
     model_names: List = []
     cache_responses: Optional[bool] = False
```
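The new `RoutingArgs` enum centralizes the rate-limit bookkeeping TTL. Usage counters are bucketed per minute (see the `tpm` keys later in the diff), so a 60-second expiry lets each counter age out on its own once its minute window passes. A minimal sketch of that pattern, using a hypothetical in-memory cache in place of the router's cache client:

```python
import enum
import time
from datetime import datetime, timezone


class RoutingArgs(enum.Enum):
    ttl = 60  # 1min (RPM/TPM expire key)


class InMemoryCache:
    """Hypothetical stand-in for the router's cache client."""

    def __init__(self):
        self._store = {}  # key -> (value, expires_at)

    def increment(self, key: str, value: int, ttl: int) -> int:
        now = time.time()
        current, expires_at = self._store.get(key, (0, now + ttl))
        if now >= expires_at:  # window expired, start a fresh counter
            current, expires_at = 0, now + ttl
        current += value
        self._store[key] = (current, expires_at)
        return current


cache = InMemoryCache()
minute = datetime.now(timezone.utc).strftime("%H-%M")
key = f"global_router:deployment-1:tpm:{minute}"  # hypothetical deployment id
print(cache.increment(key, value=150, ttl=RoutingArgs.ttl.value))  # -> 150
```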
```diff
@@ -387,6 +392,11 @@ class Router:
             routing_strategy=routing_strategy,
             routing_strategy_args=routing_strategy_args,
         )
+        ## USAGE TRACKING ##
+        if isinstance(litellm._async_success_callback, list):
+            litellm._async_success_callback.append(self.deployment_callback_on_success)
+        else:
+            litellm._async_success_callback.append(self.deployment_callback_on_success)
         ## COOLDOWNS ##
         if isinstance(litellm.failure_callback, list):
             litellm.failure_callback.append(self.deployment_callback_on_failure)
```
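This registers the router's new success handler on litellm's global async success-callback list, so every successful completion flows through `deployment_callback_on_success` (defined in the next hunk). As written, both branches of the `isinstance` check append identically, so the check is effectively a no-op. Below is a hypothetical miniature of the dispatch pattern, not litellm's actual internals:

```python
import asyncio

# Hypothetical miniature: the router appends a coroutine, and the
# library awaits every registered callback after each successful call.
_async_success_callback = []


async def _dispatch_success(kwargs, response, start_time, end_time):
    for cb in _async_success_callback:
        await cb(kwargs, response, start_time, end_time)


async def my_tracker(kwargs, response, start_time, end_time):
    print("tracked:", response["usage"]["total_tokens"])


_async_success_callback.append(my_tracker)
asyncio.run(_dispatch_success({}, {"usage": {"total_tokens": 42}}, 0, 1))
```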
```diff
@@ -2636,13 +2646,69 @@ class Router:
                     time.sleep(_timeout)
 
             if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
-                original_exception.max_retries = num_retries
-                original_exception.num_retries = current_attempt
+                setattr(original_exception, "max_retries", num_retries)
+                setattr(original_exception, "num_retries", current_attempt)
 
             raise original_exception
 
     ### HELPER FUNCTIONS
 
+    async def deployment_callback_on_success(
+        self,
+        kwargs,  # kwargs to completion
+        completion_response,  # response from completion
+        start_time,
+        end_time,  # start/end time
+    ):
+        """
+        Track remaining tpm/rpm quota for model in model_list
+        """
+        try:
+            """
+            Update TPM usage on success
+            """
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                total_tokens = completion_response["usage"]["total_tokens"]
+
+                # ------------
+                # Setup values
+                # ------------
+                dt = get_utc_datetime()
+                current_minute = dt.strftime(
+                    "%H-%M"
+                )  # use the same timezone regardless of system clock
+
+                tpm_key = f"global_router:{id}:tpm:{current_minute}"
+                # ------------
+                # Update usage
+                # ------------
+                # update cache
+
+                ## TPM
+                await self.cache.async_increment_cache(
+                    key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
+                )
+
+        except Exception as e:
+            verbose_router_logger.error(
+                "litellm.router.py::deployment_callback_on_success(): Exception occurred - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            pass
+
     def deployment_callback_on_failure(
         self,
         kwargs,  # kwargs to completion
```
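The `setattr` calls are behaviorally identical to the attribute assignments they replace. The new `deployment_callback_on_success` is the write side of the rate limiter: after each successful completion it buckets the response's `total_tokens` under a per-deployment, per-minute key and increments that counter in the shared cache with the 60-second TTL from `RoutingArgs`. A self-contained sketch of the key construction, using hypothetical stand-in values for the callback's arguments:

```python
from datetime import datetime, timezone

# Hypothetical inputs shaped like the callback's arguments.
kwargs = {
    "litellm_params": {
        "metadata": {"model_group": "gpt-3.5-turbo"},
        "model_info": {"id": 123},
    }
}
completion_response = {"usage": {"total_tokens": 250}}

metadata = kwargs["litellm_params"].get("metadata")
if metadata is not None:
    model_group = metadata.get("model_group", None)
    id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
    if model_group is not None and id is not None:
        if isinstance(id, int):
            id = str(id)  # cache keys are strings
        minute = datetime.now(timezone.utc).strftime("%H-%M")
        tpm_key = f"global_router:{id}:tpm:{minute}"
        total_tokens = completion_response["usage"]["total_tokens"]
        print(tpm_key, total_tokens)  # -> global_router:123:tpm:<HH-MM> 250
```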
```diff
@@ -3963,6 +4029,35 @@ class Router:
 
         return model_group_info
 
+    async def get_model_group_usage(self, model_group: str) -> Optional[int]:
+        """
+        Returns current TPM usage for the model group
+        """
+        dt = get_utc_datetime()
+        current_minute = dt.strftime(
+            "%H-%M"
+        )  # use the same timezone regardless of system clock
+        tpm_keys: List[str] = []
+        for model in self.model_list:
+            if "model_name" in model and model["model_name"] == model_group:
+                tpm_keys.append(
+                    f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
+                )
+
+        ## TPM
+        tpm_usage_list: Optional[List] = await self.cache.async_batch_get_cache(
+            keys=tpm_keys
+        )
+        tpm_usage: Optional[int] = None
+        if tpm_usage_list is not None:
+            for t in tpm_usage_list:
+                if isinstance(t, int):
+                    if tpm_usage is None:
+                        tpm_usage = 0
+                    tpm_usage += t
+
+        return tpm_usage
+
     def get_model_ids(self) -> List[str]:
         """
         Returns list of model id's.
```
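`get_model_group_usage` is the read side: it rebuilds the same per-minute keys for every deployment in the requested group, batch-fetches them, and sums whatever integer counters come back. Note the `None`-versus-`0` distinction: the method returns `None` when no deployment has recorded usage this minute, and `0` only if a counter actually holds zero. The summing logic in isolation:

```python
from typing import List, Optional

# Cache misses come back as None and are skipped; if every key
# misses, the aggregate stays None rather than becoming 0.
def sum_usage(tpm_usage_list: Optional[List]) -> Optional[int]:
    tpm_usage: Optional[int] = None
    if tpm_usage_list is not None:
        for t in tpm_usage_list:
            if isinstance(t, int):
                if tpm_usage is None:
                    tpm_usage = 0
                tpm_usage += t
    return tpm_usage


assert sum_usage([100, None, 250]) == 350
assert sum_usage([None, None]) is None
```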
```diff
@@ -4890,7 +4985,7 @@ class Router:
     def reset(self):
         ## clean up on close
         litellm.success_callback = []
-        litellm.__async_success_callback = []
+        litellm._async_success_callback = []
         litellm.failure_callback = []
         litellm._async_failure_callback = []
         self.retry_policy = None
```
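The one-character fix in `reset` deserves a note: because `reset` is a method of `Router`, the old `litellm.__async_success_callback` was name-mangled by the compiler to `litellm._Router__async_success_callback`, so resetting silently never cleared the real callback list. A self-contained demonstration with a hypothetical stand-in for the litellm module:

```python
class _FakeLitellm:
    """Hypothetical stand-in for the litellm module's globals."""

    _async_success_callback = ["registered_callback"]


litellm = _FakeLitellm()


class Router:
    def reset(self):
        # Old line: inside a class body, the leading double underscore
        # triggers name mangling, so this sets
        # `litellm._Router__async_success_callback` and never touches
        # the real list.
        litellm.__async_success_callback = []
        assert litellm._async_success_callback == ["registered_callback"]

        # Fixed line: clears the list the callbacks were appended to.
        litellm._async_success_callback = []
        assert litellm._async_success_callback == []


Router().reset()
```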