Litellm dev 12 14 2024 p1 (#7231)

* fix(router.py): fix reading + using deployment-specific num retries on router

Fixes https://github.com/BerriAI/litellm/issues/7001

* fix(router.py): ensure 'timeout' in litellm_params overrides any value in router settings

Refactors all routes to use common '_update_kwargs_with_deployment' which has the timeout handling

* fix(router.py): fix timeout check
This commit is contained in:
Krish Dholakia 2024-12-14 22:22:29 -08:00 committed by GitHub
parent 2459f9735d
commit 194acfa95c
5 changed files with 117 additions and 165 deletions

View file

@ -5,11 +5,12 @@ model_list:
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
temperature: 0.2
- model_name: "*"
- model_name: gpt-4o
litellm_params:
model: "*"
model_info:
access_groups: ["default"]
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
num_retries: 3
litellm_settings:
success_callback: ["langsmith"]
num_retries: 0

View file

@ -813,7 +813,6 @@ class Router:
kwargs["messages"] = messages
kwargs["stream"] = stream
kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
request_priority = kwargs.get("priority") or self.default_priority
@ -899,17 +898,12 @@ class Router:
)
self.total_calls[model_name] += 1
timeout: Optional[Union[float, int]] = self._get_timeout(
kwargs=kwargs, data=data
)
_response = litellm.acompletion(
**{
**data,
"messages": messages,
"caching": self.cache_responses,
"client": model_client,
"timeout": timeout,
**kwargs,
}
)
@ -1015,6 +1009,10 @@ class Router:
}
)
kwargs["model_info"] = deployment.get("model_info", {})
kwargs["timeout"] = self._get_timeout(
kwargs=kwargs, data=deployment["litellm_params"]
)
self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
@ -1046,16 +1044,16 @@ class Router:
def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
"""Helper to get timeout from kwargs or deployment params"""
timeout = (
data.get(
kwargs.get("timeout", None) # the params dynamically set by user
or kwargs.get("request_timeout", None) # the params dynamically set by user
or data.get(
"timeout", None
) # timeout set on litellm_params for this deployment
or data.get(
"request_timeout", None
) # timeout set on litellm_params for this deployment
or self.timeout # timeout set on router
or kwargs.get(
"timeout", None
) # this uses default_litellm_params when nothing is set
or self.default_litellm_params.get("timeout", None)
)
return timeout
@ -1378,7 +1376,6 @@ class Router:
kwargs["prompt"] = prompt
kwargs["original_function"] = self._image_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
@ -1438,7 +1435,6 @@ class Router:
kwargs["prompt"] = prompt
kwargs["original_function"] = self._aimage_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@ -1768,17 +1764,11 @@ class Router:
)
self.total_calls[model_name] += 1
timeout: Optional[Union[float, int]] = self._get_timeout(
kwargs=kwargs,
data=data,
)
response = await litellm.arerank(
**{
**data,
"caching": self.cache_responses,
"client": model_client,
"timeout": timeout,
**kwargs,
}
)
@ -1800,7 +1790,6 @@ class Router:
messages = [{"role": "user", "content": "dummy-text"}]
try:
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
# pick the one that is available (lowest TPM/RPM)
@ -1826,7 +1815,7 @@ class Router:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_function"] = self._arealtime
return self.function_with_retries(**kwargs)
return await self.async_function_with_retries(**kwargs)
else:
raise e
@ -1844,7 +1833,6 @@ class Router:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
# pick the one that is available (lowest TPM/RPM)
@ -1924,7 +1912,6 @@ class Router:
"prompt": prompt,
"caching": self.cache_responses,
"client": model_client,
"timeout": self.timeout,
**kwargs,
}
)
@ -1980,7 +1967,6 @@ class Router:
kwargs["adapter_id"] = adapter_id
kwargs["original_function"] = self._aadapter_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@ -2024,7 +2010,6 @@ class Router:
"adapter_id": adapter_id,
"caching": self.cache_responses,
"client": model_client,
"timeout": self.timeout,
**kwargs,
}
)
@ -2078,7 +2063,6 @@ class Router:
kwargs["input"] = input
kwargs["original_function"] = self._embedding
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
return response
@ -2245,7 +2229,6 @@ class Router:
kwargs["model"] = model
kwargs["original_function"] = self._acreate_file
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@ -2301,7 +2284,6 @@ class Router:
"custom_llm_provider": custom_llm_provider,
"caching": self.cache_responses,
"client": model_client,
"timeout": self.timeout,
**kwargs,
}
)
@ -2352,7 +2334,6 @@ class Router:
kwargs["model"] = model
kwargs["original_function"] = self._acreate_batch
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@ -2397,13 +2378,7 @@ class Router:
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == metadata_variable_name:
kwargs[k].update(v)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
model_client = self._get_async_openai_model_client(
deployment=deployment,
@ -2420,7 +2395,6 @@ class Router:
"custom_llm_provider": custom_llm_provider,
"caching": self.cache_responses,
"client": model_client,
"timeout": self.timeout,
**kwargs,
}
)
@ -2913,7 +2887,11 @@ class Router:
except Exception as e:
current_attempt = None
original_exception = e
deployment_num_retries = getattr(e, "num_retries", None)
if deployment_num_retries is not None and isinstance(
deployment_num_retries, int
):
num_retries = deployment_num_retries
"""
Retry Logic
"""
@ -3210,106 +3188,6 @@ class Router:
return timeout
def function_with_retries(self, *args, **kwargs):
"""
Try calling the model 3 times. Shuffle-between available deployments.
"""
verbose_router_logger.debug(
f"Inside function with retries: args - {args}; kwargs - {kwargs}"
)
original_function = kwargs.pop("original_function")
num_retries = kwargs.pop("num_retries")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop(
"context_window_fallbacks", self.context_window_fallbacks
)
content_policy_fallbacks = kwargs.pop(
"content_policy_fallbacks", self.content_policy_fallbacks
)
model_group = kwargs.get("model")
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
self._handle_mock_testing_rate_limit_error(
kwargs=kwargs, model_group=model_group
)
response = original_function(*args, **kwargs)
return response
except Exception as e:
current_attempt = None
original_exception = e
_model: Optional[str] = kwargs.get("model") # type: ignore
if _model is None:
raise e # re-raise error, if model can't be determined for loadbalancing
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
_healthy_deployments, _all_deployments = self._get_healthy_deployments(
model=_model,
parent_otel_span=parent_otel_span,
)
# raises an exception if this error should not be retries
self.should_retry_this_error(
error=e,
healthy_deployments=_healthy_deployments,
all_deployments=_all_deployments,
context_window_fallbacks=context_window_fallbacks,
regular_fallbacks=fallbacks,
content_policy_fallbacks=content_policy_fallbacks,
)
# decides how long to sleep before retry
_timeout = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
all_deployments=_all_deployments,
)
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
time.sleep(_timeout)
for current_attempt in range(num_retries):
verbose_router_logger.debug(
f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
)
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = original_function(*args, **kwargs)
return response
except Exception as e:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
_model: Optional[str] = kwargs.get("model") # type: ignore
if _model is None:
raise e # re-raise error, if model can't be determined for loadbalancing
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
_healthy_deployments, _ = self._get_healthy_deployments(
model=_model,
parent_otel_span=parent_otel_span,
)
remaining_retries = num_retries - current_attempt
_timeout = self._time_to_sleep_before_retry(
e=e,
remaining_retries=remaining_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
all_deployments=_all_deployments,
)
time.sleep(_timeout)
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
setattr(original_exception, "max_retries", num_retries)
setattr(original_exception, "num_retries", current_attempt)
raise original_exception
### HELPER FUNCTIONS
async def deployment_callback_on_success(

View file

@ -689,6 +689,19 @@ def client(original_function): # noqa: PLR0915
except Exception as e:
raise e
def _get_num_retries(
kwargs: Dict[str, Any], exception: Exception
) -> Tuple[Optional[int], Dict[str, Any]]:
num_retries = kwargs.get("num_retries", None) or litellm.num_retries or None
if kwargs.get("retry_policy", None):
num_retries = get_num_retries_from_retry_policy(
exception=exception,
retry_policy=kwargs.get("retry_policy"),
)
kwargs["retry_policy"] = reset_retry_policy()
return num_retries, kwargs
@wraps(original_function)
def wrapper(*args, **kwargs): # noqa: PLR0915
# DO NOT MOVE THIS. It always needs to run first
@ -1159,20 +1172,8 @@ def client(original_function): # noqa: PLR0915
raise e
call_type = original_function.__name__
num_retries, kwargs = _get_num_retries(kwargs=kwargs, exception=e)
if call_type == CallTypes.acompletion.value:
num_retries = (
kwargs.get("num_retries", None) or litellm.num_retries or None
)
if kwargs.get("retry_policy", None):
num_retries = get_num_retries_from_retry_policy(
exception=e,
retry_policy=kwargs.get("retry_policy"),
)
kwargs["retry_policy"] = reset_retry_policy()
litellm.num_retries = (
None # set retries to None to prevent infinite loops
)
context_window_fallback_dict = kwargs.get(
"context_window_fallback_dict", {}
)
@ -1184,6 +1185,9 @@ def client(original_function): # noqa: PLR0915
num_retries and not _is_litellm_router_call
): # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying
try:
litellm.num_retries = (
None # set retries to None to prevent infinite loops
)
kwargs["num_retries"] = num_retries
kwargs["original_function"] = original_function
if isinstance(
@ -1205,6 +1209,10 @@ def client(original_function): # noqa: PLR0915
else:
kwargs["model"] = context_window_fallback_dict[model]
return await original_function(*args, **kwargs)
setattr(
e, "num_retries", num_retries
) ## IMPORTANT: returns the deployment's num_retries to the router
raise e
is_coroutine = inspect.iscoroutinefunction(original_function)

View file

@ -733,3 +733,73 @@ def test_no_retry_when_no_healthy_deployments():
)
except Exception as e:
print("got exception", e)
@pytest.mark.asyncio
async def test_router_retries_model_specific_and_global():
from unittest.mock import patch, MagicMock
litellm.num_retries = 0
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
"num_retries": 1,
},
}
]
)
with patch.object(
router, "_time_to_sleep_before_retry"
) as mock_async_function_with_retries:
try:
await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="litellm.RateLimitError",
)
except Exception as e:
print("got exception", e)
mock_async_function_with_retries.assert_called_once()
assert mock_async_function_with_retries.call_args.kwargs["num_retries"] == 1
@pytest.mark.asyncio
async def test_router_timeout_model_specific_and_global():
from unittest.mock import patch, MagicMock
from litellm.llms.custom_httpx.http_handler import HTTPHandler
router = Router(
model_list=[
{
"model_name": "anthropic-claude",
"litellm_params": {
"model": "anthropic/claude-3-5-sonnet-20240620",
"timeout": 1,
},
}
],
timeout=10,
)
client = HTTPHandler()
with patch.object(client, "post") as mock_client:
try:
await router.acompletion(
model="anthropic-claude",
messages=[{"role": "user", "content": "Hello, how are you?"}],
client=client,
)
except Exception as e:
print("got exception", e)
mock_client.assert_called()
assert mock_client.call_args.kwargs["timeout"] == 1

View file

@ -217,16 +217,11 @@ async def test_router_function_with_retries(model_list, sync_mode):
"mock_response": "I'm fine, thank you!",
"num_retries": 0,
}
if sync_mode:
response = router.function_with_retries(
original_function=router._completion,
**data,
)
else:
response = await router.async_function_with_retries(
original_function=router._acompletion,
**data,
)
assert response.choices[0].message.content == "I'm fine, thank you!"