mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Litellm dev 12 14 2024 p1 (#7231)
* fix(router.py): fix reading + using deployment-specific num retries on router Fixes https://github.com/BerriAI/litellm/issues/7001 * fix(router.py): ensure 'timeout' in litellm_params overrides any value in router settings Refactors all routes to use common '_update_kwargs_with_deployment' which has the timeout handling * fix(router.py): fix timeout check
This commit is contained in:
parent
2459f9735d
commit
194acfa95c
5 changed files with 117 additions and 165 deletions
|
@ -5,11 +5,12 @@ model_list:
|
|||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
temperature: 0.2
|
||||
- model_name: "*"
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: "*"
|
||||
model_info:
|
||||
access_groups: ["default"]
|
||||
model: openai/gpt-4o
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
num_retries: 3
|
||||
|
||||
litellm_settings:
|
||||
success_callback: ["langsmith"]
|
||||
num_retries: 0
|
|
@ -813,7 +813,6 @@ class Router:
|
|||
kwargs["messages"] = messages
|
||||
kwargs["stream"] = stream
|
||||
kwargs["original_function"] = self._acompletion
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||
|
||||
request_priority = kwargs.get("priority") or self.default_priority
|
||||
|
@ -899,17 +898,12 @@ class Router:
|
|||
)
|
||||
self.total_calls[model_name] += 1
|
||||
|
||||
timeout: Optional[Union[float, int]] = self._get_timeout(
|
||||
kwargs=kwargs, data=data
|
||||
)
|
||||
|
||||
_response = litellm.acompletion(
|
||||
**{
|
||||
**data,
|
||||
"messages": messages,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -1015,6 +1009,10 @@ class Router:
|
|||
}
|
||||
)
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
kwargs["timeout"] = self._get_timeout(
|
||||
kwargs=kwargs, data=deployment["litellm_params"]
|
||||
)
|
||||
|
||||
self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
|
||||
|
||||
def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
|
||||
|
@ -1046,16 +1044,16 @@ class Router:
|
|||
def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
|
||||
"""Helper to get timeout from kwargs or deployment params"""
|
||||
timeout = (
|
||||
data.get(
|
||||
kwargs.get("timeout", None) # the params dynamically set by user
|
||||
or kwargs.get("request_timeout", None) # the params dynamically set by user
|
||||
or data.get(
|
||||
"timeout", None
|
||||
) # timeout set on litellm_params for this deployment
|
||||
or data.get(
|
||||
"request_timeout", None
|
||||
) # timeout set on litellm_params for this deployment
|
||||
or self.timeout # timeout set on router
|
||||
or kwargs.get(
|
||||
"timeout", None
|
||||
) # this uses default_litellm_params when nothing is set
|
||||
or self.default_litellm_params.get("timeout", None)
|
||||
)
|
||||
|
||||
return timeout
|
||||
|
@ -1378,7 +1376,6 @@ class Router:
|
|||
kwargs["prompt"] = prompt
|
||||
kwargs["original_function"] = self._image_generation
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = self.function_with_fallbacks(**kwargs)
|
||||
|
||||
|
@ -1438,7 +1435,6 @@ class Router:
|
|||
kwargs["prompt"] = prompt
|
||||
kwargs["original_function"] = self._aimage_generation
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = await self.async_function_with_fallbacks(**kwargs)
|
||||
|
||||
|
@ -1768,17 +1764,11 @@ class Router:
|
|||
)
|
||||
self.total_calls[model_name] += 1
|
||||
|
||||
timeout: Optional[Union[float, int]] = self._get_timeout(
|
||||
kwargs=kwargs,
|
||||
data=data,
|
||||
)
|
||||
|
||||
response = await litellm.arerank(
|
||||
**{
|
||||
**data,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -1800,7 +1790,6 @@ class Router:
|
|||
messages = [{"role": "user", "content": "dummy-text"}]
|
||||
try:
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
|
@ -1826,7 +1815,7 @@ class Router:
|
|||
kwargs["model"] = model
|
||||
kwargs["messages"] = messages
|
||||
kwargs["original_function"] = self._arealtime
|
||||
return self.function_with_retries(**kwargs)
|
||||
return await self.async_function_with_retries(**kwargs)
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
@ -1844,7 +1833,6 @@ class Router:
|
|||
kwargs["model"] = model
|
||||
kwargs["prompt"] = prompt
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
|
@ -1924,7 +1912,6 @@ class Router:
|
|||
"prompt": prompt,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": self.timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -1980,7 +1967,6 @@ class Router:
|
|||
kwargs["adapter_id"] = adapter_id
|
||||
kwargs["original_function"] = self._aadapter_completion
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = await self.async_function_with_fallbacks(**kwargs)
|
||||
|
||||
|
@ -2024,7 +2010,6 @@ class Router:
|
|||
"adapter_id": adapter_id,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": self.timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -2078,7 +2063,6 @@ class Router:
|
|||
kwargs["input"] = input
|
||||
kwargs["original_function"] = self._embedding
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = self.function_with_fallbacks(**kwargs)
|
||||
return response
|
||||
|
@ -2245,7 +2229,6 @@ class Router:
|
|||
kwargs["model"] = model
|
||||
kwargs["original_function"] = self._acreate_file
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = await self.async_function_with_fallbacks(**kwargs)
|
||||
|
||||
|
@ -2301,7 +2284,6 @@ class Router:
|
|||
"custom_llm_provider": custom_llm_provider,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": self.timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -2352,7 +2334,6 @@ class Router:
|
|||
kwargs["model"] = model
|
||||
kwargs["original_function"] = self._acreate_batch
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = await self.async_function_with_fallbacks(**kwargs)
|
||||
|
||||
|
@ -2397,13 +2378,7 @@ class Router:
|
|||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
data = deployment["litellm_params"].copy()
|
||||
model_name = data["model"]
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if (
|
||||
k not in kwargs
|
||||
): # prioritize model-specific params > default router params
|
||||
kwargs[k] = v
|
||||
elif k == metadata_variable_name:
|
||||
kwargs[k].update(v)
|
||||
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
|
||||
|
||||
model_client = self._get_async_openai_model_client(
|
||||
deployment=deployment,
|
||||
|
@ -2420,7 +2395,6 @@ class Router:
|
|||
"custom_llm_provider": custom_llm_provider,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": self.timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -2913,7 +2887,11 @@ class Router:
|
|||
except Exception as e:
|
||||
current_attempt = None
|
||||
original_exception = e
|
||||
|
||||
deployment_num_retries = getattr(e, "num_retries", None)
|
||||
if deployment_num_retries is not None and isinstance(
|
||||
deployment_num_retries, int
|
||||
):
|
||||
num_retries = deployment_num_retries
|
||||
"""
|
||||
Retry Logic
|
||||
"""
|
||||
|
@ -3210,106 +3188,6 @@ class Router:
|
|||
|
||||
return timeout
|
||||
|
||||
def function_with_retries(self, *args, **kwargs):
|
||||
"""
|
||||
Try calling the model 3 times. Shuffle-between available deployments.
|
||||
"""
|
||||
verbose_router_logger.debug(
|
||||
f"Inside function with retries: args - {args}; kwargs - {kwargs}"
|
||||
)
|
||||
original_function = kwargs.pop("original_function")
|
||||
num_retries = kwargs.pop("num_retries")
|
||||
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.pop(
|
||||
"context_window_fallbacks", self.context_window_fallbacks
|
||||
)
|
||||
content_policy_fallbacks = kwargs.pop(
|
||||
"content_policy_fallbacks", self.content_policy_fallbacks
|
||||
)
|
||||
model_group = kwargs.get("model")
|
||||
|
||||
try:
|
||||
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||
self._handle_mock_testing_rate_limit_error(
|
||||
kwargs=kwargs, model_group=model_group
|
||||
)
|
||||
response = original_function(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
current_attempt = None
|
||||
original_exception = e
|
||||
_model: Optional[str] = kwargs.get("model") # type: ignore
|
||||
|
||||
if _model is None:
|
||||
raise e # re-raise error, if model can't be determined for loadbalancing
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
_healthy_deployments, _all_deployments = self._get_healthy_deployments(
|
||||
model=_model,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
# raises an exception if this error should not be retries
|
||||
self.should_retry_this_error(
|
||||
error=e,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
all_deployments=_all_deployments,
|
||||
context_window_fallbacks=context_window_fallbacks,
|
||||
regular_fallbacks=fallbacks,
|
||||
content_policy_fallbacks=content_policy_fallbacks,
|
||||
)
|
||||
|
||||
# decides how long to sleep before retry
|
||||
_timeout = self._time_to_sleep_before_retry(
|
||||
e=original_exception,
|
||||
remaining_retries=num_retries,
|
||||
num_retries=num_retries,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
all_deployments=_all_deployments,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
if num_retries > 0:
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
|
||||
time.sleep(_timeout)
|
||||
for current_attempt in range(num_retries):
|
||||
verbose_router_logger.debug(
|
||||
f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
|
||||
)
|
||||
try:
|
||||
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||
response = original_function(*args, **kwargs)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
||||
_model: Optional[str] = kwargs.get("model") # type: ignore
|
||||
|
||||
if _model is None:
|
||||
raise e # re-raise error, if model can't be determined for loadbalancing
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
_healthy_deployments, _ = self._get_healthy_deployments(
|
||||
model=_model,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
remaining_retries = num_retries - current_attempt
|
||||
_timeout = self._time_to_sleep_before_retry(
|
||||
e=e,
|
||||
remaining_retries=remaining_retries,
|
||||
num_retries=num_retries,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
all_deployments=_all_deployments,
|
||||
)
|
||||
time.sleep(_timeout)
|
||||
|
||||
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
setattr(original_exception, "max_retries", num_retries)
|
||||
setattr(original_exception, "num_retries", current_attempt)
|
||||
|
||||
raise original_exception
|
||||
|
||||
### HELPER FUNCTIONS
|
||||
|
||||
async def deployment_callback_on_success(
|
||||
|
|
|
@ -689,6 +689,19 @@ def client(original_function): # noqa: PLR0915
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _get_num_retries(
|
||||
kwargs: Dict[str, Any], exception: Exception
|
||||
) -> Tuple[Optional[int], Dict[str, Any]]:
|
||||
num_retries = kwargs.get("num_retries", None) or litellm.num_retries or None
|
||||
if kwargs.get("retry_policy", None):
|
||||
num_retries = get_num_retries_from_retry_policy(
|
||||
exception=exception,
|
||||
retry_policy=kwargs.get("retry_policy"),
|
||||
)
|
||||
kwargs["retry_policy"] = reset_retry_policy()
|
||||
|
||||
return num_retries, kwargs
|
||||
|
||||
@wraps(original_function)
|
||||
def wrapper(*args, **kwargs): # noqa: PLR0915
|
||||
# DO NOT MOVE THIS. It always needs to run first
|
||||
|
@ -1159,20 +1172,8 @@ def client(original_function): # noqa: PLR0915
|
|||
raise e
|
||||
|
||||
call_type = original_function.__name__
|
||||
num_retries, kwargs = _get_num_retries(kwargs=kwargs, exception=e)
|
||||
if call_type == CallTypes.acompletion.value:
|
||||
num_retries = (
|
||||
kwargs.get("num_retries", None) or litellm.num_retries or None
|
||||
)
|
||||
if kwargs.get("retry_policy", None):
|
||||
num_retries = get_num_retries_from_retry_policy(
|
||||
exception=e,
|
||||
retry_policy=kwargs.get("retry_policy"),
|
||||
)
|
||||
kwargs["retry_policy"] = reset_retry_policy()
|
||||
|
||||
litellm.num_retries = (
|
||||
None # set retries to None to prevent infinite loops
|
||||
)
|
||||
context_window_fallback_dict = kwargs.get(
|
||||
"context_window_fallback_dict", {}
|
||||
)
|
||||
|
@ -1184,6 +1185,9 @@ def client(original_function): # noqa: PLR0915
|
|||
num_retries and not _is_litellm_router_call
|
||||
): # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying
|
||||
try:
|
||||
litellm.num_retries = (
|
||||
None # set retries to None to prevent infinite loops
|
||||
)
|
||||
kwargs["num_retries"] = num_retries
|
||||
kwargs["original_function"] = original_function
|
||||
if isinstance(
|
||||
|
@ -1205,6 +1209,10 @@ def client(original_function): # noqa: PLR0915
|
|||
else:
|
||||
kwargs["model"] = context_window_fallback_dict[model]
|
||||
return await original_function(*args, **kwargs)
|
||||
|
||||
setattr(
|
||||
e, "num_retries", num_retries
|
||||
) ## IMPORTANT: returns the deployment's num_retries to the router
|
||||
raise e
|
||||
|
||||
is_coroutine = inspect.iscoroutinefunction(original_function)
|
||||
|
|
|
@ -733,3 +733,73 @@ def test_no_retry_when_no_healthy_deployments():
|
|||
)
|
||||
except Exception as e:
|
||||
print("got exception", e)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_retries_model_specific_and_global():
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
litellm.num_retries = 0
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
"num_retries": 1,
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
router, "_time_to_sleep_before_retry"
|
||||
) as mock_async_function_with_retries:
|
||||
try:
|
||||
await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
mock_response="litellm.RateLimitError",
|
||||
)
|
||||
except Exception as e:
|
||||
print("got exception", e)
|
||||
|
||||
mock_async_function_with_retries.assert_called_once()
|
||||
|
||||
assert mock_async_function_with_retries.call_args.kwargs["num_retries"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_timeout_model_specific_and_global():
|
||||
from unittest.mock import patch, MagicMock
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "anthropic-claude",
|
||||
"litellm_params": {
|
||||
"model": "anthropic/claude-3-5-sonnet-20240620",
|
||||
"timeout": 1,
|
||||
},
|
||||
}
|
||||
],
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post") as mock_client:
|
||||
try:
|
||||
await router.acompletion(
|
||||
model="anthropic-claude",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print("got exception", e)
|
||||
|
||||
mock_client.assert_called()
|
||||
|
||||
assert mock_client.call_args.kwargs["timeout"] == 1
|
||||
|
|
|
@ -217,16 +217,11 @@ async def test_router_function_with_retries(model_list, sync_mode):
|
|||
"mock_response": "I'm fine, thank you!",
|
||||
"num_retries": 0,
|
||||
}
|
||||
if sync_mode:
|
||||
response = router.function_with_retries(
|
||||
original_function=router._completion,
|
||||
**data,
|
||||
)
|
||||
else:
|
||||
response = await router.async_function_with_retries(
|
||||
original_function=router._acompletion,
|
||||
**data,
|
||||
)
|
||||
|
||||
assert response.choices[0].message.content == "I'm fine, thank you!"
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue