diff --git a/litellm/router.py b/litellm/router.py
index 2a177a2a20..d00b60deeb 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -110,6 +110,7 @@ from litellm.types.router import (
     updateDeployment,
     updateLiteLLMParams,
 )
+from litellm.types.utils import OPENAI_RESPONSE_HEADERS
 from litellm.types.utils import ModelInfo as ModelMapInfo
 from litellm.utils import (
     CustomStreamWrapper,
@@ -3083,7 +3084,8 @@ class Router:
                     message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
                 )
             # if the function call is successful, no exception will be raised and we'll break out of the loop
-            response = await original_function(*args, **kwargs)
+            response = await self.make_call(original_function, *args, **kwargs)
+
             return response
         except Exception as e:
             current_attempt = None
@@ -3136,7 +3138,7 @@ class Router:
             for current_attempt in range(num_retries):
                 try:
                     # if the function call is successful, no exception will be raised and we'll break out of the loop
-                    response = await original_function(*args, **kwargs)
+                    response = await self.make_call(original_function, *args, **kwargs)
                     if inspect.iscoroutinefunction(
                         response
                     ):  # async errors are often returned as coroutines
@@ -3170,6 +3172,17 @@ class Router:
 
             raise original_exception
 
+    async def make_call(self, original_function: Any, *args, **kwargs):
+        """
+        Handler for making a call to the .completion()/.embeddings() functions.
+        """
+        model_group = kwargs.get("model")
+        response = await original_function(*args, **kwargs)
+        ## PROCESS RESPONSE HEADERS
+        await self.set_response_headers(response=response, model_group=model_group)
+
+        return response
+
     def should_retry_this_error(
         self,
         error: Exception,
@@ -3828,7 +3841,15 @@ class Router:
 
         return healthy_deployments, _all_deployments
 
-    async def _async_get_healthy_deployments(self, model: str):
+    async def _async_get_healthy_deployments(
+        self, model: str
+    ) -> Tuple[List[Dict], List[Dict]]:
+        """
+        Returns Tuple of:
+        - Tuple[List[Dict], List[Dict]]:
+            1. healthy_deployments: list of healthy deployments
+            2. all_deployments: list of all deployments
+        """
         _all_deployments: list = []
         try:
             _, _all_deployments = self._common_checks_available_deployment(  # type: ignore
@@ -3836,7 +3857,7 @@ class Router:
             )
             if type(_all_deployments) == dict:
                 return [], _all_deployments
-        except:
+        except Exception:
             pass
 
         unhealthy_deployments = await _async_get_cooldown_deployments(
@@ -4637,6 +4658,66 @@ class Router:
                 rpm_usage += t
         return tpm_usage, rpm_usage
 
+    async def set_response_headers(
+        self, response: Any, model_group: Optional[str] = None
+    ) -> Any:
+        """
+        Add the most accurate rate limit headers for a given model response.
+
+        - if healthy_deployments > 1, return model group rate limit headers
+        - else return the model's rate limit headers
+        """
+        if model_group is None:
+            return response
+
+        healthy_deployments, all_deployments = (
+            await self._async_get_healthy_deployments(model=model_group)
+        )
+
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        additional_headers = hidden_params.get("additional_headers", {}) or {}
+
+        if len(healthy_deployments) <= 1:
+            return (
+                response  # setting response headers is handled in wrappers in utils.py
+            )
+        else:
+            # return model group rate limit headers
+            model_group_info = self.get_model_group_info(model_group=model_group)
+            tpm_usage, rpm_usage = await self.get_model_group_usage(
+                model_group=model_group
+            )
+            model_group_remaining_rpm_limit: Optional[int] = None
+            model_group_rpm_limit: Optional[int] = None
+            model_group_remaining_tpm_limit: Optional[int] = None
+            model_group_tpm_limit: Optional[int] = None
+
+            if model_group_info is not None and model_group_info.rpm is not None:
+                model_group_rpm_limit = model_group_info.rpm
+                if rpm_usage is not None:
+                    model_group_remaining_rpm_limit = model_group_info.rpm - rpm_usage
+            if model_group_info is not None and model_group_info.tpm is not None:
+                model_group_tpm_limit = model_group_info.tpm
+                if tpm_usage is not None:
+                    model_group_remaining_tpm_limit = model_group_info.tpm - tpm_usage
+
+            if model_group_remaining_rpm_limit is not None:
+                additional_headers["x-ratelimit-remaining-requests"] = (
+                    model_group_remaining_rpm_limit
+                )
+            if model_group_rpm_limit is not None:
+                additional_headers["x-ratelimit-limit-requests"] = model_group_rpm_limit
+            if model_group_remaining_tpm_limit is not None:
+                additional_headers["x-ratelimit-remaining-tokens"] = (
+                    model_group_remaining_tpm_limit
+                )
+            if model_group_tpm_limit is not None:
+                additional_headers["x-ratelimit-limit-tokens"] = model_group_tpm_limit
+
+            hidden_params["additional_headers"] = additional_headers
+            setattr(response, "_hidden_params", hidden_params)
+            return response
+
     def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
         """
         if 'model_name' is none, returns all.
diff --git a/litellm/utils.py b/litellm/utils.py
index 524abe3635..b1f5a25ba2 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9263,6 +9263,7 @@ def process_response_headers(response_headers: Union[httpx.Headers, dict]) -> di
     openai_headers = {}
     processed_headers = {}
    additional_headers = {}
+
     for k, v in response_headers.items():
         if k in OPENAI_RESPONSE_HEADERS:  # return openai-compatible headers
             openai_headers[k] = v
diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py
index 76277874fc..7649e88858 100644
--- a/tests/local_testing/test_router.py
+++ b/tests/local_testing/test_router.py
@@ -2566,3 +2566,47 @@ def test_model_group_alias(hidden):
     else:
         assert len(models) == len(_model_list) + 1
         assert len(model_names) == len(_model_list) + 1
+
+
+@pytest.mark.parametrize("on_error", [True, False])
+@pytest.mark.asyncio
+async def test_router_response_headers(on_error):
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "tpm": 100000,
+                    "rpm": 100000,
+                },
+            },
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "tpm": 500,
+                    "rpm": 500,
+                },
+            },
+        ]
+    )
+
+    response = await router.acompletion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        mock_testing_rate_limit_error=on_error,
+    )
+
+    response_headers = response._hidden_params["additional_headers"]
+
+    print(response_headers)
+
+    assert response_headers["x-ratelimit-limit-requests"] == 100500
+    assert int(response_headers["x-ratelimit-remaining-requests"]) > 0
+    assert response_headers["x-ratelimit-limit-tokens"] == 100500
+    assert int(response_headers["x-ratelimit-remaining-tokens"]) > 0
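For context, a minimal usage sketch mirroring the new `test_router_response_headers` test: with more than one healthy deployment in a model group, the router now reports aggregated model-group rate limits under `response._hidden_params["additional_headers"]`. The `azure/chatgpt-v-2` deployments and `AZURE_API_KEY` / `AZURE_API_BASE` environment variables below are assumptions taken from the test setup, not part of this patch.

```python
# Sketch only: exercises Router.set_response_headers() via a normal acompletion() call.
import asyncio
import os

from litellm import Router


async def main() -> None:
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                    "tpm": 100000,
                    "rpm": 100000,
                },
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                    "tpm": 500,
                    "rpm": 500,
                },
            },
        ]
    )

    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello world!"}],
    )

    # With two healthy deployments, these are the aggregated model-group limits
    # (100000 + 500), not a single deployment's upstream headers.
    headers = response._hidden_params["additional_headers"]
    print(headers["x-ratelimit-limit-requests"])  # 100500
    print(headers["x-ratelimit-remaining-requests"])


asyncio.run(main())
```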