Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
Litellm dev 12 14 2024 p1 (#7231)
* fix(router.py): fix reading and using deployment-specific num_retries on the router

  Fixes https://github.com/BerriAI/litellm/issues/7001

* fix(router.py): ensure 'timeout' in litellm_params overrides any value in router settings

  Refactors all routes to use the common '_update_kwargs_with_deployment', which has the timeout handling

* fix(router.py): fix timeout check
commit 194acfa95c (parent 2459f9735d)

5 changed files with 117 additions and 165 deletions
@@ -5,11 +5,12 @@ model_list:
       api_key: os.environ/AZURE_API_KEY
       api_base: os.environ/AZURE_API_BASE
       temperature: 0.2
-  - model_name: "*"
+  - model_name: gpt-4o
     litellm_params:
-      model: "*"
-    model_info:
-      access_groups: ["default"]
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
+      num_retries: 3

 litellm_settings:
   success_callback: ["langsmith"]
+  num_retries: 0
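For context, the config change above gives the gpt-4o deployment its own num_retries: 3 while litellm_settings pins the global default to 0. A minimal programmatic sketch of the same setup (mirroring the new test_router_retries_model_specific_and_global test further down; the api_key value is a placeholder):

```python
import litellm
from litellm import Router

litellm.num_retries = 0  # global default: no retries

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "api_key": "sk-...",  # placeholder
                "num_retries": 3,  # deployment-specific; should win over the global 0
            },
        }
    ]
)
```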
@@ -813,7 +813,6 @@ class Router:
         kwargs["messages"] = messages
         kwargs["stream"] = stream
         kwargs["original_function"] = self._acompletion
-        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
         self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)

         request_priority = kwargs.get("priority") or self.default_priority
@@ -899,17 +898,12 @@ class Router:
             )
             self.total_calls[model_name] += 1

-            timeout: Optional[Union[float, int]] = self._get_timeout(
-                kwargs=kwargs, data=data
-            )
-
             _response = litellm.acompletion(
                 **{
                     **data,
                     "messages": messages,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": timeout,
                     **kwargs,
                 }
             )
@@ -1015,6 +1009,10 @@ class Router:
             }
         )
         kwargs["model_info"] = deployment.get("model_info", {})
+        kwargs["timeout"] = self._get_timeout(
+            kwargs=kwargs, data=deployment["litellm_params"]
+        )
+
         self._update_kwargs_with_default_litellm_params(kwargs=kwargs)

     def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
@@ -1046,16 +1044,16 @@ class Router:
     def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
         """Helper to get timeout from kwargs or deployment params"""
         timeout = (
-            data.get(
+            kwargs.get("timeout", None)  # the params dynamically set by user
+            or kwargs.get("request_timeout", None)  # the params dynamically set by user
+            or data.get(
                 "timeout", None
             )  # timeout set on litellm_params for this deployment
             or data.get(
                 "request_timeout", None
             )  # timeout set on litellm_params for this deployment
             or self.timeout  # timeout set on router
-            or kwargs.get(
-                "timeout", None
-            )  # this uses default_litellm_params when nothing is set
+            or self.default_litellm_params.get("timeout", None)
         )

         return timeout
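The reordered chain above is the heart of the timeout fix: values set dynamically on the request now take precedence over the deployment's litellm_params, which in turn beat the router-wide setting and default_litellm_params. A standalone sketch of the same resolution order (simplified names and a free function; not the actual Router method):

```python
from typing import Optional, Union


def resolve_timeout(
    kwargs: dict,
    deployment_params: dict,
    router_timeout: Optional[Union[float, int]] = None,
    default_params: Optional[dict] = None,
) -> Optional[Union[float, int]]:
    default_params = default_params or {}
    return (
        kwargs.get("timeout")  # 1. set dynamically by the caller
        or kwargs.get("request_timeout")  # 2. dynamic alias
        or deployment_params.get("timeout")  # 3. deployment litellm_params
        or deployment_params.get("request_timeout")  # 4. deployment alias
        or router_timeout  # 5. router-wide setting
        or default_params.get("timeout")  # 6. default_litellm_params fallback
    )


assert resolve_timeout({"timeout": 5}, {"timeout": 1}, router_timeout=10) == 5
assert resolve_timeout({}, {"timeout": 1}, router_timeout=10) == 1
assert resolve_timeout({}, {}, router_timeout=10) == 10
```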
@@ -1378,7 +1376,6 @@ class Router:
         kwargs["prompt"] = prompt
         kwargs["original_function"] = self._image_generation
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.get("request_timeout", self.timeout)
         kwargs.setdefault("metadata", {}).update({"model_group": model})
         response = self.function_with_fallbacks(**kwargs)

@@ -1438,7 +1435,6 @@ class Router:
         kwargs["prompt"] = prompt
         kwargs["original_function"] = self._aimage_generation
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.get("request_timeout", self.timeout)
         kwargs.setdefault("metadata", {}).update({"model_group": model})
         response = await self.async_function_with_fallbacks(**kwargs)

@@ -1768,17 +1764,11 @@ class Router:
             )
             self.total_calls[model_name] += 1

-            timeout: Optional[Union[float, int]] = self._get_timeout(
-                kwargs=kwargs,
-                data=data,
-            )
-
             response = await litellm.arerank(
                 **{
                     **data,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": timeout,
                     **kwargs,
                 }
             )
@@ -1800,7 +1790,6 @@ class Router:
         messages = [{"role": "user", "content": "dummy-text"}]
         try:
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})

             # pick the one that is available (lowest TPM/RPM)
@@ -1826,7 +1815,7 @@ class Router:
                 kwargs["model"] = model
                 kwargs["messages"] = messages
                 kwargs["original_function"] = self._arealtime
-                return self.function_with_retries(**kwargs)
+                return await self.async_function_with_retries(**kwargs)
             else:
                 raise e

@@ -1844,7 +1833,6 @@ class Router:
         kwargs["model"] = model
         kwargs["prompt"] = prompt
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.get("request_timeout", self.timeout)
         kwargs.setdefault("metadata", {}).update({"model_group": model})

         # pick the one that is available (lowest TPM/RPM)
@@ -1924,7 +1912,6 @@ class Router:
                     "prompt": prompt,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": self.timeout,
                     **kwargs,
                 }
             )
@@ -1980,7 +1967,6 @@ class Router:
         kwargs["adapter_id"] = adapter_id
         kwargs["original_function"] = self._aadapter_completion
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.get("request_timeout", self.timeout)
         kwargs.setdefault("metadata", {}).update({"model_group": model})
         response = await self.async_function_with_fallbacks(**kwargs)

@@ -2024,7 +2010,6 @@ class Router:
                     "adapter_id": adapter_id,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": self.timeout,
                     **kwargs,
                 }
             )
@@ -2078,7 +2063,6 @@ class Router:
         kwargs["input"] = input
         kwargs["original_function"] = self._embedding
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.get("request_timeout", self.timeout)
         kwargs.setdefault("metadata", {}).update({"model_group": model})
         response = self.function_with_fallbacks(**kwargs)
         return response
@@ -2245,7 +2229,6 @@ class Router:
         kwargs["model"] = model
         kwargs["original_function"] = self._acreate_file
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.get("request_timeout", self.timeout)
         kwargs.setdefault("metadata", {}).update({"model_group": model})
         response = await self.async_function_with_fallbacks(**kwargs)

@@ -2301,7 +2284,6 @@ class Router:
                     "custom_llm_provider": custom_llm_provider,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": self.timeout,
                     **kwargs,
                 }
             )
@@ -2352,7 +2334,6 @@ class Router:
         kwargs["model"] = model
         kwargs["original_function"] = self._acreate_batch
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.get("request_timeout", self.timeout)
         kwargs.setdefault("metadata", {}).update({"model_group": model})
         response = await self.async_function_with_fallbacks(**kwargs)

@@ -2397,13 +2378,7 @@ class Router:
         kwargs["model_info"] = deployment.get("model_info", {})
         data = deployment["litellm_params"].copy()
         model_name = data["model"]
-        for k, v in self.default_litellm_params.items():
-            if (
-                k not in kwargs
-            ):  # prioritize model-specific params > default router params
-                kwargs[k] = v
-            elif k == metadata_variable_name:
-                kwargs[k].update(v)
+        self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

         model_client = self._get_async_openai_model_client(
             deployment=deployment,
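This hunk is the refactor named in the commit message: instead of each route re-implementing the default-param merge inline, they all call the shared _update_kwargs_with_deployment, which (per the earlier hunk) now also resolves the per-deployment timeout via _get_timeout.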
@@ -2420,7 +2395,6 @@ class Router:
                     "custom_llm_provider": custom_llm_provider,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": self.timeout,
                     **kwargs,
                 }
             )
@@ -2913,7 +2887,11 @@ class Router:
         except Exception as e:
             current_attempt = None
             original_exception = e
+            deployment_num_retries = getattr(e, "num_retries", None)
+            if deployment_num_retries is not None and isinstance(
+                deployment_num_retries, int
+            ):
+                num_retries = deployment_num_retries
             """
             Retry Logic
             """
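This is where the router adopts a deployment-specific retry count: the client wrapper in the utils hunks below pins num_retries onto the exception before re-raising, and the router's retry loop reads it back with getattr. The handoff pattern in isolation (a self-contained sketch, not litellm code):

```python
def inner_call():
    # the inner layer annotates the exception with its own retry budget
    err = TimeoutError("deployment timed out")
    setattr(err, "num_retries", 3)
    raise err


num_retries = 0  # router/global default
try:
    inner_call()
except Exception as e:
    deployment_num_retries = getattr(e, "num_retries", None)
    if isinstance(deployment_num_retries, int):
        num_retries = deployment_num_retries  # deployment value wins

assert num_retries == 3
```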
@@ -3210,106 +3188,6 @@ class Router:

         return timeout

-    def function_with_retries(self, *args, **kwargs):
-        """
-        Try calling the model 3 times. Shuffle-between available deployments.
-        """
-        verbose_router_logger.debug(
-            f"Inside function with retries: args - {args}; kwargs - {kwargs}"
-        )
-        original_function = kwargs.pop("original_function")
-        num_retries = kwargs.pop("num_retries")
-        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
-        context_window_fallbacks = kwargs.pop(
-            "context_window_fallbacks", self.context_window_fallbacks
-        )
-        content_policy_fallbacks = kwargs.pop(
-            "content_policy_fallbacks", self.content_policy_fallbacks
-        )
-        model_group = kwargs.get("model")
-
-        try:
-            # if the function call is successful, no exception will be raised and we'll break out of the loop
-            self._handle_mock_testing_rate_limit_error(
-                kwargs=kwargs, model_group=model_group
-            )
-            response = original_function(*args, **kwargs)
-            return response
-        except Exception as e:
-            current_attempt = None
-            original_exception = e
-            _model: Optional[str] = kwargs.get("model")  # type: ignore
-
-            if _model is None:
-                raise e  # re-raise error, if model can't be determined for loadbalancing
-            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
-            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
-            _healthy_deployments, _all_deployments = self._get_healthy_deployments(
-                model=_model,
-                parent_otel_span=parent_otel_span,
-            )
-
-            # raises an exception if this error should not be retries
-            self.should_retry_this_error(
-                error=e,
-                healthy_deployments=_healthy_deployments,
-                all_deployments=_all_deployments,
-                context_window_fallbacks=context_window_fallbacks,
-                regular_fallbacks=fallbacks,
-                content_policy_fallbacks=content_policy_fallbacks,
-            )
-
-            # decides how long to sleep before retry
-            _timeout = self._time_to_sleep_before_retry(
-                e=original_exception,
-                remaining_retries=num_retries,
-                num_retries=num_retries,
-                healthy_deployments=_healthy_deployments,
-                all_deployments=_all_deployments,
-            )
-
-            ## LOGGING
-            if num_retries > 0:
-                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
-
-            time.sleep(_timeout)
-            for current_attempt in range(num_retries):
-                verbose_router_logger.debug(
-                    f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
-                )
-                try:
-                    # if the function call is successful, no exception will be raised and we'll break out of the loop
-                    response = original_function(*args, **kwargs)
-                    return response
-
-                except Exception as e:
-                    ## LOGGING
-                    kwargs = self.log_retry(kwargs=kwargs, e=e)
-                    _model: Optional[str] = kwargs.get("model")  # type: ignore
-
-                    if _model is None:
-                        raise e  # re-raise error, if model can't be determined for loadbalancing
-                    parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
-                    _healthy_deployments, _ = self._get_healthy_deployments(
-                        model=_model,
-                        parent_otel_span=parent_otel_span,
-                    )
-                    remaining_retries = num_retries - current_attempt
-                    _timeout = self._time_to_sleep_before_retry(
-                        e=e,
-                        remaining_retries=remaining_retries,
-                        num_retries=num_retries,
-                        healthy_deployments=_healthy_deployments,
-                        all_deployments=_all_deployments,
-                    )
-                    time.sleep(_timeout)
-
-            if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
-                setattr(original_exception, "max_retries", num_retries)
-                setattr(original_exception, "num_retries", current_attempt)
-
-            raise original_exception
-
     ### HELPER FUNCTIONS

     async def deployment_callback_on_success(
@@ -689,6 +689,19 @@ def client(original_function):  # noqa: PLR0915
         except Exception as e:
             raise e

+    def _get_num_retries(
+        kwargs: Dict[str, Any], exception: Exception
+    ) -> Tuple[Optional[int], Dict[str, Any]]:
+        num_retries = kwargs.get("num_retries", None) or litellm.num_retries or None
+        if kwargs.get("retry_policy", None):
+            num_retries = get_num_retries_from_retry_policy(
+                exception=exception,
+                retry_policy=kwargs.get("retry_policy"),
+            )
+            kwargs["retry_policy"] = reset_retry_policy()
+
+        return num_retries, kwargs
+
     @wraps(original_function)
     def wrapper(*args, **kwargs):  # noqa: PLR0915
         # DO NOT MOVE THIS. It always needs to run first
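_get_num_retries is a straight extraction of the inline logic removed in the next hunk: it takes the plain num_retries value first, then lets a retry_policy override it per exception type. A sketch of how such a per-exception retry policy is configured (RetryPolicy and its field names are assumed from litellm.types.router; verify against your installed version, and the api_key is a placeholder):

```python
from litellm import Router
from litellm.types.router import RetryPolicy

retry_policy = RetryPolicy(
    RateLimitErrorRetries=3,  # retry rate limits up to 3 times
    TimeoutErrorRetries=2,  # retry timeouts up to 2 times
    ContentPolicyViolationErrorRetries=0,  # fail fast on content policy errors
)

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "api_key": "sk-...",  # placeholder
            },
        }
    ],
    retry_policy=retry_policy,
)
```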
@@ -1159,20 +1172,8 @@ def client(original_function):  # noqa: PLR0915
                 raise e

             call_type = original_function.__name__
+            num_retries, kwargs = _get_num_retries(kwargs=kwargs, exception=e)
             if call_type == CallTypes.acompletion.value:
-                num_retries = (
-                    kwargs.get("num_retries", None) or litellm.num_retries or None
-                )
-                if kwargs.get("retry_policy", None):
-                    num_retries = get_num_retries_from_retry_policy(
-                        exception=e,
-                        retry_policy=kwargs.get("retry_policy"),
-                    )
-                    kwargs["retry_policy"] = reset_retry_policy()
-
-                litellm.num_retries = (
-                    None  # set retries to None to prevent infinite loops
-                )
                 context_window_fallback_dict = kwargs.get(
                     "context_window_fallback_dict", {}
                 )
@@ -1184,6 +1185,9 @@ def client(original_function):  # noqa: PLR0915
                     num_retries and not _is_litellm_router_call
                 ):  # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying
                     try:
+                        litellm.num_retries = (
+                            None  # set retries to None to prevent infinite loops
+                        )
                         kwargs["num_retries"] = num_retries
                         kwargs["original_function"] = original_function
                         if isinstance(
@@ -1205,6 +1209,10 @@ def client(original_function):  # noqa: PLR0915
                     else:
                         kwargs["model"] = context_window_fallback_dict[model]
                         return await original_function(*args, **kwargs)

+            setattr(
+                e, "num_retries", num_retries
+            )  ## IMPORTANT: returns the deployment's num_retries to the router
             raise e

         is_coroutine = inspect.iscoroutinefunction(original_function)
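This setattr is the other half of the handoff sketched after the Router hunk above: the wrapper re-raises with the resolved num_retries attached, so the router's retry loop can adopt the deployment's value via getattr instead of falling back to its own default.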
@@ -733,3 +733,73 @@ def test_no_retry_when_no_healthy_deployments():
         )
     except Exception as e:
         print("got exception", e)
+
+
+@pytest.mark.asyncio
+async def test_router_retries_model_specific_and_global():
+    from unittest.mock import patch, MagicMock
+
+    litellm.num_retries = 0
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                    "num_retries": 1,
+                },
+            }
+        ]
+    )
+
+    with patch.object(
+        router, "_time_to_sleep_before_retry"
+    ) as mock_async_function_with_retries:
+        try:
+            await router.acompletion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                mock_response="litellm.RateLimitError",
+            )
+        except Exception as e:
+            print("got exception", e)
+
+        mock_async_function_with_retries.assert_called_once()
+
+        assert mock_async_function_with_retries.call_args.kwargs["num_retries"] == 1
+
+
+@pytest.mark.asyncio
+async def test_router_timeout_model_specific_and_global():
+    from unittest.mock import patch, MagicMock
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "anthropic-claude",
+                "litellm_params": {
+                    "model": "anthropic/claude-3-5-sonnet-20240620",
+                    "timeout": 1,
+                },
+            }
+        ],
+        timeout=10,
+    )
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_client:
+        try:
+            await router.acompletion(
+                model="anthropic-claude",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                client=client,
+            )
+        except Exception as e:
+            print("got exception", e)
+
+        mock_client.assert_called()
+
+        assert mock_client.call_args.kwargs["timeout"] == 1
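Together these tests pin both behaviors end to end: a deployment-level num_retries: 1 must beat litellm.num_retries = 0, and a deployment-level timeout: 1 must reach the HTTP layer even when the router-wide timeout is 10. They can be run directly, e.g. `pytest -k "test_router_retries_model_specific_and_global or test_router_timeout_model_specific_and_global"`.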
@@ -217,16 +217,11 @@ async def test_router_function_with_retries(model_list, sync_mode):
         "mock_response": "I'm fine, thank you!",
         "num_retries": 0,
     }
-    if sync_mode:
-        response = router.function_with_retries(
-            original_function=router._completion,
-            **data,
-        )
-    else:
-        response = await router.async_function_with_retries(
-            original_function=router._acompletion,
-            **data,
-        )
+    response = await router.async_function_with_retries(
+        original_function=router._acompletion,
+        **data,
+    )

     assert response.choices[0].message.content == "I'm fine, thank you!"