fix(router.py): handle setting response headers during retries

This commit is contained in:
Krrish Dholakia 2024-09-28 18:10:54 -07:00
parent d64e971d8c
commit b0eff0b84f
3 changed files with 130 additions and 4 deletions

View file

@ -110,6 +110,7 @@ from litellm.types.router import (
updateDeployment,
updateLiteLLMParams,
)
from litellm.types.utils import OPENAI_RESPONSE_HEADERS
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.utils import (
CustomStreamWrapper,
@ -3083,7 +3084,8 @@ class Router:
message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
)
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = await original_function(*args, **kwargs)
response = await self.make_call(original_function, *args, **kwargs)
return response
except Exception as e:
current_attempt = None
@ -3136,7 +3138,7 @@ class Router:
for current_attempt in range(num_retries):
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = await original_function(*args, **kwargs)
response = await self.make_call(original_function, *args, **kwargs)
if inspect.iscoroutinefunction(
response
): # async errors are often returned as coroutines
@ -3170,6 +3172,17 @@ class Router:
raise original_exception
async def make_call(self, original_function: Any, *args, **kwargs):
    """
    Handler for making a call to the .completion()/.embeddings() functions.

    Invokes `original_function(*args, **kwargs)`, awaits the result when it is
    awaitable, and then passes the response through `set_response_headers`
    using the `model` keyword argument (if any) as the model group.

    Args:
        original_function: sync or async callable that performs the request.
        *args: positional arguments forwarded unchanged to `original_function`.
        **kwargs: keyword arguments forwarded unchanged to `original_function`;
            `kwargs["model"]` is read (not removed) to pick the model group.

    Returns:
        The response returned by `set_response_headers`.

    Raises:
        Any exception raised by `original_function` or by
        `set_response_headers` propagates to the caller.
    """
    model_group = kwargs.get("model")

    # Only await when there is something to await, so a plain (sync) callable
    # can be routed through the same handler without raising TypeError.
    response = original_function(*args, **kwargs)
    if inspect.isawaitable(response):
        response = await response

    ## PROCESS RESPONSE HEADERS
    # Use the returned object instead of relying on in-place mutation.
    response = await self.set_response_headers(
        response=response, model_group=model_group
    )
    return response
def should_retry_this_error(
self,
error: Exception,
@ -3828,7 +3841,15 @@ class Router:
return healthy_deployments, _all_deployments
async def _async_get_healthy_deployments(self, model: str):
async def _async_get_healthy_deployments(
self, model: str
) -> Tuple[List[Dict], List[Dict]]:
"""
Returns Tuple of:
- Tuple[List[Dict], List[Dict]]:
1. healthy_deployments: list of healthy deployments
2. all_deployments: list of all deployments
"""
_all_deployments: list = []
try:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
@ -3836,7 +3857,7 @@ class Router:
)
if type(_all_deployments) == dict:
return [], _all_deployments
except:
except Exception:
pass
unhealthy_deployments = await _async_get_cooldown_deployments(
@ -4637,6 +4658,66 @@ class Router:
rpm_usage += t
return tpm_usage, rpm_usage
async def set_response_headers(
    self, response: Any, model_group: Optional[str] = None
) -> Any:
    """
    Add the most accurate rate limit headers for a given model response.

    - if `model_group` is None, or the group has at most one healthy
      deployment, the response is returned untouched (setting response
      headers for that case is handled in wrappers in utils.py)
    - else the model group's rpm/tpm limits and remaining budget are written
      to `response._hidden_params["additional_headers"]`

    Header decoration is best-effort: a response whose `_hidden_params` is not
    a dict, or which rejects attribute assignment (e.g. a plain dict), is
    returned unchanged instead of raising.

    Args:
        response: object returned by the underlying call.
        model_group: name of the model group the call was routed to.

    Returns:
        The same `response` object, possibly with updated `_hidden_params`.
    """
    if model_group is None:
        return response

    healthy_deployments, _ = await self._async_get_healthy_deployments(
        model=model_group
    )
    if len(healthy_deployments) <= 1:
        # setting response headers is handled in wrappers in utils.py
        return response

    hidden_params = getattr(response, "_hidden_params", None) or {}
    if not isinstance(hidden_params, dict):
        # Unknown container type; do not guess how to write headers into it.
        return response
    additional_headers = hidden_params.get("additional_headers") or {}

    # Model group level limits and current usage.
    model_group_info = self.get_model_group_info(model_group=model_group)
    tpm_usage, rpm_usage = await self.get_model_group_usage(
        model_group=model_group
    )

    rpm_limit: Optional[int] = None
    tpm_limit: Optional[int] = None
    if model_group_info is not None:
        rpm_limit = model_group_info.rpm
        tpm_limit = model_group_info.tpm

    # A "remaining" header is only emitted when both limit and usage are known.
    if rpm_limit is not None:
        if rpm_usage is not None:
            additional_headers["x-ratelimit-remaining-requests"] = (
                rpm_limit - rpm_usage
            )
        additional_headers["x-ratelimit-limit-requests"] = rpm_limit
    if tpm_limit is not None:
        if tpm_usage is not None:
            additional_headers["x-ratelimit-remaining-tokens"] = (
                tpm_limit - tpm_usage
            )
        additional_headers["x-ratelimit-limit-tokens"] = tpm_limit

    hidden_params["additional_headers"] = additional_headers
    try:
        response._hidden_params = hidden_params
    except (AttributeError, TypeError):
        # The response type does not accept new attributes; a successful call
        # must not be turned into a failure just because headers can't be set.
        pass
    return response
def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
"""
if 'model_name' is none, returns all.