fix(dynamic_rate_limiter.py): support dynamic rate limiting on rpm

This commit is contained in:
Krrish Dholakia 2024-07-01 17:45:10 -07:00
parent b6f509a745
commit d528e263c2
2 changed files with 101 additions and 28 deletions

View file

@ -81,28 +81,36 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
def update_variables(self, llm_router: Router):
self.llm_router = llm_router
async def check_available_tpm(
async def check_available_usage(
self, model: str
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
) -> Tuple[
Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]
]:
"""
For a given model, get its available tpm
Returns
- Tuple[available_tpm, model_tpm, active_projects]
- Tuple[available_tpm, available_tpm, model_tpm, model_rpm, active_projects]
- available_tpm: int or null - always 0 or positive.
- available_tpm: int or null - always 0 or positive.
- remaining_model_tpm: int or null. If available tpm is int, then this will be too.
- remaining_model_rpm: int or null. If available rpm is int, then this will be too.
- active_projects: int or null
"""
active_projects = await self.internal_usage_cache.async_get_cache(model=model)
current_model_tpm: Optional[int] = await self.llm_router.get_model_group_usage(
model_group=model
current_model_tpm, current_model_rpm = (
await self.llm_router.get_model_group_usage(model_group=model)
)
model_group_info: Optional[ModelGroupInfo] = (
self.llm_router.get_model_group_info(model_group=model)
)
total_model_tpm: Optional[int] = None
if model_group_info is not None and model_group_info.tpm is not None:
total_model_tpm = model_group_info.tpm
total_model_rpm: Optional[int] = None
if model_group_info is not None:
if model_group_info.tpm is not None:
total_model_tpm = model_group_info.tpm
if model_group_info.rpm is not None:
total_model_rpm = model_group_info.rpm
remaining_model_tpm: Optional[int] = None
if total_model_tpm is not None and current_model_tpm is not None:
@ -110,6 +118,12 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
elif total_model_tpm is not None:
remaining_model_tpm = total_model_tpm
remaining_model_rpm: Optional[int] = None
if total_model_rpm is not None and current_model_rpm is not None:
remaining_model_rpm = total_model_rpm - current_model_rpm
elif total_model_rpm is not None:
remaining_model_rpm = total_model_rpm
available_tpm: Optional[int] = None
if remaining_model_tpm is not None:
@ -120,7 +134,24 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
if available_tpm is not None and available_tpm < 0:
available_tpm = 0
return available_tpm, remaining_model_tpm, active_projects
available_rpm: Optional[int] = None
if remaining_model_rpm is not None:
if active_projects is not None:
available_rpm = int(remaining_model_rpm / active_projects)
else:
available_rpm = remaining_model_rpm
if available_rpm is not None and available_rpm < 0:
available_rpm = 0
return (
available_tpm,
available_rpm,
remaining_model_tpm,
remaining_model_rpm,
active_projects,
)
async def async_pre_call_hook(
self,
@ -140,13 +171,14 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
"""
- For a model group
- Check if tpm available
- Raise RateLimitError if no tpm available
- Check if tpm/rpm available
- Raise RateLimitError if no tpm/rpm available
"""
if "model" in data:
available_tpm, model_tpm, active_projects = await self.check_available_tpm(
model=data["model"]
available_tpm, available_rpm, model_tpm, model_rpm, active_projects = (
await self.check_available_usage(model=data["model"])
)
### CHECK TPM ###
if available_tpm is not None and available_tpm == 0:
raise HTTPException(
status_code=429,
@ -159,7 +191,20 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
)
},
)
elif available_tpm is not None:
### CHECK RPM ###
elif available_rpm is not None and available_rpm == 0:
raise HTTPException(
status_code=429,
detail={
"error": "Key={} over available RPM={}. Model RPM={}, Active keys={}".format(
user_api_key_dict.api_key,
available_rpm,
model_rpm,
active_projects,
)
},
)
elif available_rpm is not None or available_tpm is not None:
## UPDATE CACHE WITH ACTIVE PROJECT
asyncio.create_task(
self.internal_usage_cache.async_set_cache_sadd( # this is a set
@ -182,15 +227,19 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
), "Model info for model with id={} is None".format(
response._hidden_params["model_id"]
)
available_tpm, remaining_model_tpm, active_projects = (
await self.check_available_tpm(model=model_info["model_name"])
available_tpm, available_rpm, model_tpm, model_rpm, active_projects = (
await self.check_available_usage(model=model_info["model_name"])
)
response._hidden_params["additional_headers"] = (
{ # Add additional response headers - easier debugging
"x-litellm-model_group": model_info["model_name"],
"x-ratelimit-remaining-litellm-project-tokens": available_tpm,
"x-ratelimit-remaining-litellm-project-requests": available_rpm,
"x-ratelimit-remaining-model-tokens": model_tpm,
"x-ratelimit-remaining-model-requests": model_tpm,
"x-ratelimit-current-active-projects": active_projects,
}
)
response._hidden_params["additional_headers"] = {
"x-litellm-model_group": model_info["model_name"],
"x-ratelimit-remaining-litellm-project-tokens": available_tpm,
"x-ratelimit-remaining-model-tokens": remaining_model_tpm,
"x-ratelimit-current-active-projects": active_projects,
}
return response
return await super().async_post_call_success_hook(

View file

@ -4191,25 +4191,42 @@ class Router:
return model_group_info
async def get_model_group_usage(self, model_group: str) -> Optional[int]:
async def get_model_group_usage(
self, model_group: str
) -> Tuple[Optional[int], Optional[int]]:
"""
Returns remaining tpm quota for model group
Returns remaining tpm/rpm quota for model group
Returns:
- usage: Tuple[tpm, rpm]
"""
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_keys: List[str] = []
rpm_keys: List[str] = []
for model in self.model_list:
if "model_name" in model and model["model_name"] == model_group:
tpm_keys.append(
f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
)
rpm_keys.append(
f"global_router:{model['model_info']['id']}:rpm:{current_minute}"
)
combined_tpm_rpm_keys = tpm_keys + rpm_keys
combined_tpm_rpm_values = await self.cache.async_batch_get_cache(
keys=combined_tpm_rpm_keys
)
if combined_tpm_rpm_values is None:
return None, None
tpm_usage_list: Optional[List] = combined_tpm_rpm_values[: len(tpm_keys)]
rpm_usage_list: Optional[List] = combined_tpm_rpm_values[len(tpm_keys) :]
## TPM
tpm_usage_list: Optional[List] = await self.cache.async_batch_get_cache(
keys=tpm_keys
)
tpm_usage: Optional[int] = None
if tpm_usage_list is not None:
for t in tpm_usage_list:
@ -4217,8 +4234,15 @@ class Router:
if tpm_usage is None:
tpm_usage = 0
tpm_usage += t
return tpm_usage
## RPM
rpm_usage: Optional[int] = None
if rpm_usage_list is not None:
for t in rpm_usage_list:
if isinstance(t, int):
if rpm_usage is None:
rpm_usage = 0
rpm_usage += t
return tpm_usage, rpm_usage
def get_model_ids(self) -> List[str]:
"""