fix(dynamic_rate_limiter.py): support dynamic rate limiting on rpm

Krrish Dholakia 2024-07-01 17:45:10 -07:00
parent b6f509a745
commit d528e263c2
2 changed files with 101 additions and 28 deletions

litellm/proxy/hooks/dynamic_rate_limiter.py

@@ -81,28 +81,36 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
     def update_variables(self, llm_router: Router):
         self.llm_router = llm_router
 
-    async def check_available_tpm(
+    async def check_available_usage(
         self, model: str
-    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+    ) -> Tuple[
+        Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]
+    ]:
         """
         For a given model, get its available tpm
 
         Returns
-        - Tuple[available_tpm, model_tpm, active_projects]
+        - Tuple[available_tpm, available_rpm, model_tpm, model_rpm, active_projects]
         - available_tpm: int or null - always 0 or positive.
+        - available_rpm: int or null - always 0 or positive.
         - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
+        - remaining_model_rpm: int or null. If available rpm is int, then this will be too.
         - active_projects: int or null
         """
         active_projects = await self.internal_usage_cache.async_get_cache(model=model)
-        current_model_tpm: Optional[int] = await self.llm_router.get_model_group_usage(
-            model_group=model
+        current_model_tpm, current_model_rpm = (
+            await self.llm_router.get_model_group_usage(model_group=model)
         )
         model_group_info: Optional[ModelGroupInfo] = (
             self.llm_router.get_model_group_info(model_group=model)
         )
         total_model_tpm: Optional[int] = None
-        if model_group_info is not None and model_group_info.tpm is not None:
-            total_model_tpm = model_group_info.tpm
+        total_model_rpm: Optional[int] = None
+        if model_group_info is not None:
+            if model_group_info.tpm is not None:
+                total_model_tpm = model_group_info.tpm
+            if model_group_info.rpm is not None:
+                total_model_rpm = model_group_info.rpm
 
         remaining_model_tpm: Optional[int] = None
         if total_model_tpm is not None and current_model_tpm is not None:
@@ -110,6 +118,12 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         elif total_model_tpm is not None:
             remaining_model_tpm = total_model_tpm
 
+        remaining_model_rpm: Optional[int] = None
+        if total_model_rpm is not None and current_model_rpm is not None:
+            remaining_model_rpm = total_model_rpm - current_model_rpm
+        elif total_model_rpm is not None:
+            remaining_model_rpm = total_model_rpm
+
         available_tpm: Optional[int] = None
         if remaining_model_tpm is not None:
@@ -120,7 +134,24 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         if available_tpm is not None and available_tpm < 0:
             available_tpm = 0
-        return available_tpm, remaining_model_tpm, active_projects
+
+        available_rpm: Optional[int] = None
+        if remaining_model_rpm is not None:
+            if active_projects is not None:
+                available_rpm = int(remaining_model_rpm / active_projects)
+            else:
+                available_rpm = remaining_model_rpm
+
+        if available_rpm is not None and available_rpm < 0:
+            available_rpm = 0
+        return (
+            available_tpm,
+            available_rpm,
+            remaining_model_tpm,
+            remaining_model_rpm,
+            active_projects,
+        )
 
     async def async_pre_call_hook(
         self,
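The rpm computation mirrors the tpm one above: the remaining model-wide budget is divided evenly across active projects, and a negative remainder clamps to zero. A standalone sketch of that fair-share arithmetic with made-up numbers (the helper name fair_share is invented for illustration):

    from typing import Optional

    def fair_share(remaining: Optional[int], active_projects: Optional[int]) -> Optional[int]:
        # Same shape as the available_rpm logic: split what is left of the
        # model-wide budget across the projects currently using the model.
        if remaining is None:
            return None
        share = int(remaining / active_projects) if active_projects else remaining
        return max(share, 0)  # a negative remainder means nothing is available

    # e.g. model rpm limit 100, 40 requests already used -> 60 remaining;
    # 3 active projects -> each project may send int(60 / 3) = 20 requests.
    assert fair_share(60, 3) == 20
    assert fair_share(-5, 2) == 0       # over budget clamps to zero
    assert fair_share(60, None) == 60   # no project count cached yet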
@@ -140,13 +171,14 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
     ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
         """
         - For a model group
-        - Check if tpm available
-        - Raise RateLimitError if no tpm available
+        - Check if tpm/rpm available
+        - Raise RateLimitError if no tpm/rpm available
         """
         if "model" in data:
-            available_tpm, model_tpm, active_projects = await self.check_available_tpm(
-                model=data["model"]
+            available_tpm, available_rpm, model_tpm, model_rpm, active_projects = (
+                await self.check_available_usage(model=data["model"])
             )
+            ### CHECK TPM ###
             if available_tpm is not None and available_tpm == 0:
                 raise HTTPException(
                     status_code=429,
@@ -159,7 +191,20 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
                         )
                     },
                 )
-            elif available_tpm is not None:
+            ### CHECK RPM ###
+            elif available_rpm is not None and available_rpm == 0:
+                raise HTTPException(
+                    status_code=429,
+                    detail={
+                        "error": "Key={} over available RPM={}. Model RPM={}, Active keys={}".format(
+                            user_api_key_dict.api_key,
+                            available_rpm,
+                            model_rpm,
+                            active_projects,
+                        )
+                    },
+                )
+            elif available_rpm is not None or available_tpm is not None:
                 ## UPDATE CACHE WITH ACTIVE PROJECT
                 asyncio.create_task(
                     self.internal_usage_cache.async_set_cache_sadd(  # this is a set
@@ -182,15 +227,19 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
                 ), "Model info for model with id={} is None".format(
                     response._hidden_params["model_id"]
                 )
-                available_tpm, remaining_model_tpm, active_projects = (
-                    await self.check_available_tpm(model=model_info["model_name"])
+                available_tpm, available_rpm, model_tpm, model_rpm, active_projects = (
+                    await self.check_available_usage(model=model_info["model_name"])
                 )
-                response._hidden_params["additional_headers"] = {
-                    "x-litellm-model_group": model_info["model_name"],
-                    "x-ratelimit-remaining-litellm-project-tokens": available_tpm,
-                    "x-ratelimit-remaining-model-tokens": remaining_model_tpm,
-                    "x-ratelimit-current-active-projects": active_projects,
-                }
+                response._hidden_params["additional_headers"] = (
+                    {  # Add additional response headers - easier debugging
+                        "x-litellm-model_group": model_info["model_name"],
+                        "x-ratelimit-remaining-litellm-project-tokens": available_tpm,
+                        "x-ratelimit-remaining-litellm-project-requests": available_rpm,
+                        "x-ratelimit-remaining-model-tokens": model_tpm,
+                        "x-ratelimit-remaining-model-requests": model_rpm,
+                        "x-ratelimit-current-active-projects": active_projects,
+                    }
+                )
 
                 return response
         return await super().async_post_call_success_hook(
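On success, both token and request budgets are attached as response headers for easier debugging. A sketch of reading them from a proxy response, assuming the proxy forwards additional_headers onto the HTTP response (the URL and key below are placeholders, not from this commit):

    import requests  # illustrative client; any HTTP client works

    resp = requests.post(
        "http://localhost:4000/v1/chat/completions",  # assumed proxy address
        headers={"Authorization": "Bearer sk-1234"},  # placeholder key
        json={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
    )
    for h in (
        "x-ratelimit-remaining-litellm-project-tokens",    # available_tpm
        "x-ratelimit-remaining-litellm-project-requests",  # available_rpm
        "x-ratelimit-remaining-model-tokens",              # model_tpm
        "x-ratelimit-remaining-model-requests",            # model_rpm
        "x-ratelimit-current-active-projects",             # active_projects
    ):
        print(h, resp.headers.get(h))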

litellm/router.py

@@ -4191,25 +4191,42 @@ class Router:
         return model_group_info
 
-    async def get_model_group_usage(self, model_group: str) -> Optional[int]:
+    async def get_model_group_usage(
+        self, model_group: str
+    ) -> Tuple[Optional[int], Optional[int]]:
         """
-        Returns remaining tpm quota for model group
+        Returns remaining tpm/rpm quota for model group
+
+        Returns:
+        - usage: Tuple[tpm, rpm]
         """
         dt = get_utc_datetime()
         current_minute = dt.strftime(
             "%H-%M"
         )  # use the same timezone regardless of system clock
 
         tpm_keys: List[str] = []
+        rpm_keys: List[str] = []
         for model in self.model_list:
             if "model_name" in model and model["model_name"] == model_group:
                 tpm_keys.append(
                     f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
                 )
+                rpm_keys.append(
+                    f"global_router:{model['model_info']['id']}:rpm:{current_minute}"
+                )
+        combined_tpm_rpm_keys = tpm_keys + rpm_keys
+
+        combined_tpm_rpm_values = await self.cache.async_batch_get_cache(
+            keys=combined_tpm_rpm_keys
+        )
+
+        if combined_tpm_rpm_values is None:
+            return None, None
+
+        tpm_usage_list: Optional[List] = combined_tpm_rpm_values[: len(tpm_keys)]
+        rpm_usage_list: Optional[List] = combined_tpm_rpm_values[len(tpm_keys) :]
+
         ## TPM
-        tpm_usage_list: Optional[List] = await self.cache.async_batch_get_cache(
-            keys=tpm_keys
-        )
         tpm_usage: Optional[int] = None
         if tpm_usage_list is not None:
             for t in tpm_usage_list:
@@ -4217,8 +4234,15 @@ class Router:
                     if tpm_usage is None:
                         tpm_usage = 0
                     tpm_usage += t
-
-        return tpm_usage
+        ## RPM
+        rpm_usage: Optional[int] = None
+        if rpm_usage_list is not None:
+            for t in rpm_usage_list:
+                if isinstance(t, int):
+                    if rpm_usage is None:
+                        rpm_usage = 0
+                    rpm_usage += t
+        return tpm_usage, rpm_usage
 
     def get_model_ids(self) -> List[str]:
         """