mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Litellm dev 12 14 2024 p1 (#7231)
* fix(router.py): fix reading + using deployment-specific num retries on router Fixes https://github.com/BerriAI/litellm/issues/7001 * fix(router.py): ensure 'timeout' in litellm_params overrides any value in router settings Refactors all routes to use common '_update_kwargs_with_deployment' which has the timeout handling * fix(router.py): fix timeout check
This commit is contained in:
parent
2459f9735d
commit
194acfa95c
5 changed files with 117 additions and 165 deletions
|
@ -813,7 +813,6 @@ class Router:
|
|||
kwargs["messages"] = messages
|
||||
kwargs["stream"] = stream
|
||||
kwargs["original_function"] = self._acompletion
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||
|
||||
request_priority = kwargs.get("priority") or self.default_priority
|
||||
|
@ -899,17 +898,12 @@ class Router:
|
|||
)
|
||||
self.total_calls[model_name] += 1
|
||||
|
||||
timeout: Optional[Union[float, int]] = self._get_timeout(
|
||||
kwargs=kwargs, data=data
|
||||
)
|
||||
|
||||
_response = litellm.acompletion(
|
||||
**{
|
||||
**data,
|
||||
"messages": messages,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -1015,6 +1009,10 @@ class Router:
|
|||
}
|
||||
)
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
kwargs["timeout"] = self._get_timeout(
|
||||
kwargs=kwargs, data=deployment["litellm_params"]
|
||||
)
|
||||
|
||||
self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
|
||||
|
||||
def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
|
||||
|
@ -1046,16 +1044,16 @@ class Router:
|
|||
def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
|
||||
"""Helper to get timeout from kwargs or deployment params"""
|
||||
timeout = (
|
||||
data.get(
|
||||
kwargs.get("timeout", None) # the params dynamically set by user
|
||||
or kwargs.get("request_timeout", None) # the params dynamically set by user
|
||||
or data.get(
|
||||
"timeout", None
|
||||
) # timeout set on litellm_params for this deployment
|
||||
or data.get(
|
||||
"request_timeout", None
|
||||
) # timeout set on litellm_params for this deployment
|
||||
or self.timeout # timeout set on router
|
||||
or kwargs.get(
|
||||
"timeout", None
|
||||
) # this uses default_litellm_params when nothing is set
|
||||
or self.default_litellm_params.get("timeout", None)
|
||||
)
|
||||
|
||||
return timeout
|
||||
|
@ -1378,7 +1376,6 @@ class Router:
|
|||
kwargs["prompt"] = prompt
|
||||
kwargs["original_function"] = self._image_generation
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = self.function_with_fallbacks(**kwargs)
|
||||
|
||||
|
@ -1438,7 +1435,6 @@ class Router:
|
|||
kwargs["prompt"] = prompt
|
||||
kwargs["original_function"] = self._aimage_generation
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = await self.async_function_with_fallbacks(**kwargs)
|
||||
|
||||
|
@ -1768,17 +1764,11 @@ class Router:
|
|||
)
|
||||
self.total_calls[model_name] += 1
|
||||
|
||||
timeout: Optional[Union[float, int]] = self._get_timeout(
|
||||
kwargs=kwargs,
|
||||
data=data,
|
||||
)
|
||||
|
||||
response = await litellm.arerank(
|
||||
**{
|
||||
**data,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -1800,7 +1790,6 @@ class Router:
|
|||
messages = [{"role": "user", "content": "dummy-text"}]
|
||||
try:
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
|
@ -1826,7 +1815,7 @@ class Router:
|
|||
kwargs["model"] = model
|
||||
kwargs["messages"] = messages
|
||||
kwargs["original_function"] = self._arealtime
|
||||
return self.function_with_retries(**kwargs)
|
||||
return await self.async_function_with_retries(**kwargs)
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
@ -1844,7 +1833,6 @@ class Router:
|
|||
kwargs["model"] = model
|
||||
kwargs["prompt"] = prompt
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
|
@ -1924,7 +1912,6 @@ class Router:
|
|||
"prompt": prompt,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": self.timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -1980,7 +1967,6 @@ class Router:
|
|||
kwargs["adapter_id"] = adapter_id
|
||||
kwargs["original_function"] = self._aadapter_completion
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = await self.async_function_with_fallbacks(**kwargs)
|
||||
|
||||
|
@ -2024,7 +2010,6 @@ class Router:
|
|||
"adapter_id": adapter_id,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": self.timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -2078,7 +2063,6 @@ class Router:
|
|||
kwargs["input"] = input
|
||||
kwargs["original_function"] = self._embedding
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = self.function_with_fallbacks(**kwargs)
|
||||
return response
|
||||
|
@ -2245,7 +2229,6 @@ class Router:
|
|||
kwargs["model"] = model
|
||||
kwargs["original_function"] = self._acreate_file
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = await self.async_function_with_fallbacks(**kwargs)
|
||||
|
||||
|
@ -2301,7 +2284,6 @@ class Router:
|
|||
"custom_llm_provider": custom_llm_provider,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": self.timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -2352,7 +2334,6 @@ class Router:
|
|||
kwargs["model"] = model
|
||||
kwargs["original_function"] = self._acreate_batch
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = await self.async_function_with_fallbacks(**kwargs)
|
||||
|
||||
|
@ -2397,13 +2378,7 @@ class Router:
|
|||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
data = deployment["litellm_params"].copy()
|
||||
model_name = data["model"]
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if (
|
||||
k not in kwargs
|
||||
): # prioritize model-specific params > default router params
|
||||
kwargs[k] = v
|
||||
elif k == metadata_variable_name:
|
||||
kwargs[k].update(v)
|
||||
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
|
||||
|
||||
model_client = self._get_async_openai_model_client(
|
||||
deployment=deployment,
|
||||
|
@ -2420,7 +2395,6 @@ class Router:
|
|||
"custom_llm_provider": custom_llm_provider,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": self.timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
@ -2913,7 +2887,11 @@ class Router:
|
|||
except Exception as e:
|
||||
current_attempt = None
|
||||
original_exception = e
|
||||
|
||||
deployment_num_retries = getattr(e, "num_retries", None)
|
||||
if deployment_num_retries is not None and isinstance(
|
||||
deployment_num_retries, int
|
||||
):
|
||||
num_retries = deployment_num_retries
|
||||
"""
|
||||
Retry Logic
|
||||
"""
|
||||
|
@ -3210,106 +3188,6 @@ class Router:
|
|||
|
||||
return timeout
|
||||
|
||||
def function_with_retries(self, *args, **kwargs):
|
||||
"""
|
||||
Try calling the model 3 times. Shuffle-between available deployments.
|
||||
"""
|
||||
verbose_router_logger.debug(
|
||||
f"Inside function with retries: args - {args}; kwargs - {kwargs}"
|
||||
)
|
||||
original_function = kwargs.pop("original_function")
|
||||
num_retries = kwargs.pop("num_retries")
|
||||
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.pop(
|
||||
"context_window_fallbacks", self.context_window_fallbacks
|
||||
)
|
||||
content_policy_fallbacks = kwargs.pop(
|
||||
"content_policy_fallbacks", self.content_policy_fallbacks
|
||||
)
|
||||
model_group = kwargs.get("model")
|
||||
|
||||
try:
|
||||
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||
self._handle_mock_testing_rate_limit_error(
|
||||
kwargs=kwargs, model_group=model_group
|
||||
)
|
||||
response = original_function(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
current_attempt = None
|
||||
original_exception = e
|
||||
_model: Optional[str] = kwargs.get("model") # type: ignore
|
||||
|
||||
if _model is None:
|
||||
raise e # re-raise error, if model can't be determined for loadbalancing
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
_healthy_deployments, _all_deployments = self._get_healthy_deployments(
|
||||
model=_model,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
# raises an exception if this error should not be retries
|
||||
self.should_retry_this_error(
|
||||
error=e,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
all_deployments=_all_deployments,
|
||||
context_window_fallbacks=context_window_fallbacks,
|
||||
regular_fallbacks=fallbacks,
|
||||
content_policy_fallbacks=content_policy_fallbacks,
|
||||
)
|
||||
|
||||
# decides how long to sleep before retry
|
||||
_timeout = self._time_to_sleep_before_retry(
|
||||
e=original_exception,
|
||||
remaining_retries=num_retries,
|
||||
num_retries=num_retries,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
all_deployments=_all_deployments,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
if num_retries > 0:
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
|
||||
time.sleep(_timeout)
|
||||
for current_attempt in range(num_retries):
|
||||
verbose_router_logger.debug(
|
||||
f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
|
||||
)
|
||||
try:
|
||||
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||
response = original_function(*args, **kwargs)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
||||
_model: Optional[str] = kwargs.get("model") # type: ignore
|
||||
|
||||
if _model is None:
|
||||
raise e # re-raise error, if model can't be determined for loadbalancing
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
_healthy_deployments, _ = self._get_healthy_deployments(
|
||||
model=_model,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
remaining_retries = num_retries - current_attempt
|
||||
_timeout = self._time_to_sleep_before_retry(
|
||||
e=e,
|
||||
remaining_retries=remaining_retries,
|
||||
num_retries=num_retries,
|
||||
healthy_deployments=_healthy_deployments,
|
||||
all_deployments=_all_deployments,
|
||||
)
|
||||
time.sleep(_timeout)
|
||||
|
||||
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
setattr(original_exception, "max_retries", num_retries)
|
||||
setattr(original_exception, "num_retries", current_attempt)
|
||||
|
||||
raise original_exception
|
||||
|
||||
### HELPER FUNCTIONS
|
||||
|
||||
async def deployment_callback_on_success(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue