forked from phoenix/litellm-mirror

fix(router.py): handle setting response headers during retries

parent d64e971d8c · commit b0eff0b84f

3 changed files with 130 additions and 4 deletions
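In short, the change routes every request attempt (including retried ones) through a helper that stamps rate-limit headers onto the response's hidden params. Below is a minimal sketch of how the resulting headers are meant to be read, mirroring the test added at the bottom of this diff; the `_azure_deployment` helper and the env-var credentials are illustrative assumptions, not part of the commit.

```python
import asyncio
import os

from litellm import Router


def _azure_deployment(rpm: int, tpm: int) -> dict:
    # Hypothetical helper, only to keep this sketch short.
    return {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "rpm": rpm,
            "tpm": tpm,
        },
    }


# Two deployments in the same model group, as in the new test, so the router
# takes the "model group rate limit headers" path added by this commit.
router = Router(model_list=[_azure_deployment(100000, 100000), _azure_deployment(500, 500)])


async def show_rate_limit_headers() -> None:
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello world!"}],
    )
    # The headers are stamped into the response's hidden params (see set_response_headers below).
    headers = response._hidden_params["additional_headers"]
    print(headers.get("x-ratelimit-limit-requests"))      # aggregated rpm limit, assuming both deployments are healthy
    print(headers.get("x-ratelimit-remaining-requests"))  # limit minus current usage


if __name__ == "__main__":
    asyncio.run(show_rate_limit_headers())
```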
```diff
@@ -110,6 +110,7 @@ from litellm.types.router import (
     updateDeployment,
     updateLiteLLMParams,
 )
+from litellm.types.utils import OPENAI_RESPONSE_HEADERS
 from litellm.types.utils import ModelInfo as ModelMapInfo
 from litellm.utils import (
     CustomStreamWrapper,
```
```diff
@@ -3083,7 +3084,8 @@ class Router:
                     message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
                 )
             # if the function call is successful, no exception will be raised and we'll break out of the loop
-            response = await original_function(*args, **kwargs)
+            response = await self.make_call(original_function, *args, **kwargs)
+
             return response
         except Exception as e:
             current_attempt = None
```
```diff
@@ -3136,7 +3138,7 @@ class Router:
         for current_attempt in range(num_retries):
             try:
                 # if the function call is successful, no exception will be raised and we'll break out of the loop
-                response = await original_function(*args, **kwargs)
+                response = await self.make_call(original_function, *args, **kwargs)
                 if inspect.iscoroutinefunction(
                     response
                 ):  # async errors are often returned as coroutines
```
```diff
@@ -3170,6 +3172,17 @@ class Router:

         raise original_exception

+    async def make_call(self, original_function: Any, *args, **kwargs):
+        """
+        Handler for making a call to the .completion()/.embeddings() functions.
+        """
+        model_group = kwargs.get("model")
+        response = await original_function(*args, **kwargs)
+        ## PROCESS RESPONSE HEADERS
+        await self.set_response_headers(response=response, model_group=model_group)
+
+        return response
+
     def should_retry_this_error(
         self,
         error: Exception,
```
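Both retry loops above now go through make_call instead of calling original_function directly, so the header post-processing also runs on a response produced by a retried attempt. A stripped-down sketch of that pattern follows; the names and the backoff are hypothetical, not the actual Router internals.

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional


async def call_with_retries(
    func: Callable[..., Awaitable[Any]],
    num_retries: int,
    post_process: Callable[[Any], Awaitable[Any]],
    *args,
    **kwargs,
) -> Any:
    """Retry `func`; every attempt, first or retried, flows through one helper
    so post-processing (e.g. stamping rate-limit headers) is never skipped."""
    last_exc: Optional[Exception] = None
    for attempt in range(num_retries + 1):
        try:
            response = await func(*args, **kwargs)
            return await post_process(response)  # e.g. set response headers
        except Exception as e:  # the real router inspects the error before retrying
            last_exc = e
            await asyncio.sleep(min(2**attempt, 8))  # simple backoff for the sketch
    assert last_exc is not None
    raise last_exc
```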
```diff
@@ -3828,7 +3841,15 @@ class Router:

         return healthy_deployments, _all_deployments

-    async def _async_get_healthy_deployments(self, model: str):
+    async def _async_get_healthy_deployments(
+        self, model: str
+    ) -> Tuple[List[Dict], List[Dict]]:
+        """
+        Returns Tuple of:
+        - Tuple[List[Dict], List[Dict]]:
+            1. healthy_deployments: list of healthy deployments
+            2. all_deployments: list of all deployments
+        """
         _all_deployments: list = []
         try:
             _, _all_deployments = self._common_checks_available_deployment(  # type: ignore
```
```diff
@@ -3836,7 +3857,7 @@ class Router:
             )
             if type(_all_deployments) == dict:
                 return [], _all_deployments
-        except:
+        except Exception:
             pass

         unhealthy_deployments = await _async_get_cooldown_deployments(
```
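With the annotated return type, callers can unpack the result directly. A hedged usage sketch (the `count_healthy` wrapper is hypothetical; this private helper is normally invoked internally, e.g. by set_response_headers below):

```python
async def count_healthy(router, model_group: str) -> int:
    # Returns (healthy_deployments, all_deployments), both List[Dict].
    healthy, _all = await router._async_get_healthy_deployments(model=model_group)
    # set_response_headers (added below) only aggregates model-group headers
    # when more than one healthy deployment is available.
    return len(healthy)
```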
```diff
@@ -4637,6 +4658,66 @@ class Router:
                     rpm_usage += t
         return tpm_usage, rpm_usage

+    async def set_response_headers(
+        self, response: Any, model_group: Optional[str] = None
+    ) -> Any:
+        """
+        Add the most accurate rate limit headers for a given model response.
+
+        - if healthy_deployments > 1, return model group rate limit headers
+        - else return the model's rate limit headers
+        """
+        if model_group is None:
+            return response
+
+        healthy_deployments, all_deployments = (
+            await self._async_get_healthy_deployments(model=model_group)
+        )
+
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        additional_headers = hidden_params.get("additional_headers", {}) or {}
+
+        if len(healthy_deployments) <= 1:
+            return (
+                response  # setting response headers is handled in wrappers in utils.py
+            )
+        else:
+            # return model group rate limit headers
+            model_group_info = self.get_model_group_info(model_group=model_group)
+            tpm_usage, rpm_usage = await self.get_model_group_usage(
+                model_group=model_group
+            )
+            model_group_remaining_rpm_limit: Optional[int] = None
+            model_group_rpm_limit: Optional[int] = None
+            model_group_remaining_tpm_limit: Optional[int] = None
+            model_group_tpm_limit: Optional[int] = None
+
+            if model_group_info is not None and model_group_info.rpm is not None:
+                model_group_rpm_limit = model_group_info.rpm
+                if rpm_usage is not None:
+                    model_group_remaining_rpm_limit = model_group_info.rpm - rpm_usage
+            if model_group_info is not None and model_group_info.tpm is not None:
+                model_group_tpm_limit = model_group_info.tpm
+                if tpm_usage is not None:
+                    model_group_remaining_tpm_limit = model_group_info.tpm - tpm_usage
+
+            if model_group_remaining_rpm_limit is not None:
+                additional_headers["x-ratelimit-remaining-requests"] = (
+                    model_group_remaining_rpm_limit
+                )
+            if model_group_rpm_limit is not None:
+                additional_headers["x-ratelimit-limit-requests"] = model_group_rpm_limit
+            if model_group_remaining_tpm_limit is not None:
+                additional_headers["x-ratelimit-remaining-tokens"] = (
+                    model_group_remaining_tpm_limit
+                )
+            if model_group_tpm_limit is not None:
+                additional_headers["x-ratelimit-limit-tokens"] = model_group_tpm_limit
+
+            hidden_params["additional_headers"] = additional_headers
+            setattr(response, "_hidden_params", hidden_params)
+            return response
+
     def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
         """
        if 'model_name' is none, returns all.
```
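The arithmetic behind the new headers is simply remaining = configured limit minus current usage, computed per model group, and a header is only emitted when the corresponding value is known. A tiny self-contained illustration of that math (hypothetical helper, not part of the diff):

```python
from typing import Dict, Optional


def model_group_headers(
    rpm_limit: Optional[int],
    tpm_limit: Optional[int],
    rpm_usage: Optional[int],
    tpm_usage: Optional[int],
) -> Dict[str, int]:
    """Mirror of the header math in set_response_headers: emit a header only
    when the corresponding limit/usage is known."""
    headers: Dict[str, int] = {}
    if rpm_limit is not None:
        headers["x-ratelimit-limit-requests"] = rpm_limit
        if rpm_usage is not None:
            headers["x-ratelimit-remaining-requests"] = rpm_limit - rpm_usage
    if tpm_limit is not None:
        headers["x-ratelimit-limit-tokens"] = tpm_limit
        if tpm_usage is not None:
            headers["x-ratelimit-remaining-tokens"] = tpm_limit - tpm_usage
    return headers


# e.g. an aggregated rpm limit of 100500 with 3 requests already used this window:
assert model_group_headers(100500, 100500, 3, 42)["x-ratelimit-remaining-requests"] == 100497
```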
```diff
@@ -9263,6 +9263,7 @@ def process_response_headers(response_headers: Union[httpx.Headers, dict]) -> dict:
     openai_headers = {}
     processed_headers = {}
     additional_headers = {}
+
     for k, v in response_headers.items():
         if k in OPENAI_RESPONSE_HEADERS:  # return openai-compatible headers
             openai_headers[k] = v
```
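The loop above partitions provider headers: OpenAI-compatible names pass through as-is, everything else is kept separately so it can still be surfaced to the caller. A generic sketch of that filtering pattern, with an assumed (partial) set of header names rather than the library's OPENAI_RESPONSE_HEADERS list:

```python
from typing import Dict, Tuple

# Assumed subset of OpenAI-style header names, for illustration only.
OPENAI_STYLE_HEADERS = {"x-ratelimit-limit-requests", "x-ratelimit-remaining-requests"}


def split_headers(response_headers: Dict[str, str]) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Partition response headers into OpenAI-compatible ones and the rest."""
    openai_headers: Dict[str, str] = {}
    other_headers: Dict[str, str] = {}
    for k, v in response_headers.items():
        (openai_headers if k in OPENAI_STYLE_HEADERS else other_headers)[k] = v
    return openai_headers, other_headers
```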
```diff
@@ -2566,3 +2566,47 @@ def test_model_group_alias(hidden):
     else:
         assert len(models) == len(_model_list) + 1
         assert len(model_names) == len(_model_list) + 1
+
+
+@pytest.mark.parametrize("on_error", [True, False])
+@pytest.mark.asyncio
+async def test_router_response_headers(on_error):
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "tpm": 100000,
+                    "rpm": 100000,
+                },
+            },
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "tpm": 500,
+                    "rpm": 500,
+                },
+            },
+        ]
+    )
+
+    response = await router.acompletion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        mock_testing_rate_limit_error=on_error,
+    )
+
+    response_headers = response._hidden_params["additional_headers"]
+
+    print(response_headers)
+
+    assert response_headers["x-ratelimit-limit-requests"] == 100500
+    assert int(response_headers["x-ratelimit-remaining-requests"]) > 0
+    assert response_headers["x-ratelimit-limit-tokens"] == 100500
+    assert int(response_headers["x-ratelimit-remaining-tokens"]) > 0
```
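The expected value 100500 is presumably the two deployments' configured limits summed (100000 + 500), i.e. the aggregated model-group limit, while the remaining-* headers are only checked to be positive because actual usage varies per run; with on_error=True the request first hits the mocked rate limit error, so the assertions also cover the retried-response path this commit targets. A quick arithmetic check of that reading (illustrative only; both parametrizations can be run with e.g. `pytest -k test_router_response_headers`):

```python
# rpm/tpm limits configured on the two gpt-3.5-turbo deployments in the test above
deployment_rpm_limits = [100000, 500]
deployment_tpm_limits = [100000, 500]

# the aggregated model-group limits the test expects in the x-ratelimit-limit-* headers
assert sum(deployment_rpm_limits) == 100500
assert sum(deployment_tpm_limits) == 100500
```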