Litellm dev 12 14 2024 p1 (#7231)

* fix(router.py): fix reading + using deployment-specific num retries on router

Fixes https://github.com/BerriAI/litellm/issues/7001
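
As context for the fix, a minimal sketch of the configuration this affects (model names and values are illustrative): a num_retries set in a deployment's litellm_params should take precedence over the router-level default.

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "num_retries": 0,  # deployment-specific value; previously ignored
            },
        }
    ],
    num_retries=2,  # router-level default; now only a fallback
)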

* fix(router.py): ensure 'timeout' in litellm_params overrides any value in router settings

Refactors all routes to use a common '_update_kwargs_with_deployment', which contains the timeout handling
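
A minimal sketch of the precedence this enforces (values illustrative): a per-request timeout beats the deployment's litellm_params, which beat the router-wide setting.

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "timeout": 5,  # deployment-specific; now overrides the router setting
            },
        }
    ],
    timeout=30,  # router-wide setting; used when nothing more specific is set
)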

* fix(router.py): fix timeout check
Krish Dholakia, 2024-12-14 22:22:29 -08:00 (committed by GitHub)
parent 2459f9735d
commit 194acfa95c
5 changed files with 117 additions and 165 deletions

litellm/router.py

@@ -813,7 +813,6 @@ class Router:
         kwargs["messages"] = messages
         kwargs["stream"] = stream
         kwargs["original_function"] = self._acompletion
-        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
         self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
         request_priority = kwargs.get("priority") or self.default_priority
@@ -899,17 +898,12 @@ class Router:
             )
             self.total_calls[model_name] += 1
-            timeout: Optional[Union[float, int]] = self._get_timeout(
-                kwargs=kwargs, data=data
-            )
             _response = litellm.acompletion(
                 **{
                     **data,
                     "messages": messages,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": timeout,
                     **kwargs,
                 }
             )
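
Dropping the explicit "timeout" key at these call sites is safe because of dict-unpacking order: kwargs, which now carries the timeout resolved by _update_kwargs_with_deployment, is unpacked last and therefore wins. A standalone illustration:

data = {"model": "gpt-3.5-turbo", "timeout": 30}   # deployment litellm_params
kwargs = {"timeout": 5}                            # resolved per-request value
merged = {**data, "caching": False, **kwargs}      # later unpacks override earlier ones
assert merged["timeout"] == 5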
@@ -1015,6 +1009,10 @@ class Router:
             }
         )
         kwargs["model_info"] = deployment.get("model_info", {})
+        kwargs["timeout"] = self._get_timeout(
+            kwargs=kwargs, data=deployment["litellm_params"]
+        )
         self._update_kwargs_with_default_litellm_params(kwargs=kwargs)

     def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
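
With the timeout resolved here, a value passed on the call itself still ranks first; e.g., reusing the router sketched under the commit message (illustrative call):

response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    timeout=10,  # dynamically set by the user; wins over deployment and router values
)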
@@ -1046,16 +1044,16 @@ class Router:
     def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
         """Helper to get timeout from kwargs or deployment params"""
         timeout = (
-            data.get(
+            kwargs.get("timeout", None)  # the params dynamically set by user
+            or kwargs.get("request_timeout", None)  # the params dynamically set by user
+            or data.get(
                 "timeout", None
             )  # timeout set on litellm_params for this deployment
             or data.get(
                 "request_timeout", None
             )  # timeout set on litellm_params for this deployment
             or self.timeout  # timeout set on router
-            or kwargs.get(
-                "timeout", None
-            )  # this uses default_litellm_params when nothing is set
+            or self.default_litellm_params.get("timeout", None)
         )
         return timeout
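
Restating the fall-through above as a standalone function (a hypothetical helper, not part of the diff) makes the ordering easier to test; note that, as in the original, a falsy value such as 0 falls through to the next source:

from typing import Optional, Union

def resolve_timeout(
    kwargs: dict,
    data: dict,
    router_timeout: Optional[Union[float, int]] = None,
    default_litellm_params: Optional[dict] = None,
) -> Optional[Union[float, int]]:
    return (
        kwargs.get("timeout")                      # dynamically set by user
        or kwargs.get("request_timeout")           # dynamically set by user
        or data.get("timeout")                     # deployment litellm_params
        or data.get("request_timeout")             # deployment litellm_params
        or router_timeout                          # router-wide setting
        or (default_litellm_params or {}).get("timeout")
    )

assert resolve_timeout({"timeout": 5}, {"timeout": 30}) == 5            # user value wins
assert resolve_timeout({}, {"timeout": 30}, router_timeout=60) == 30    # deployment wins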
@@ -1378,7 +1376,6 @@ class Router:
             kwargs["prompt"] = prompt
             kwargs["original_function"] = self._image_generation
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             response = self.function_with_fallbacks(**kwargs)
@@ -1438,7 +1435,6 @@ class Router:
             kwargs["prompt"] = prompt
             kwargs["original_function"] = self._aimage_generation
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             response = await self.async_function_with_fallbacks(**kwargs)
@@ -1768,17 +1764,11 @@ class Router:
             )
             self.total_calls[model_name] += 1
-            timeout: Optional[Union[float, int]] = self._get_timeout(
-                kwargs=kwargs,
-                data=data,
-            )
             response = await litellm.arerank(
                 **{
                     **data,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": timeout,
                     **kwargs,
                 }
             )
@@ -1800,7 +1790,6 @@ class Router:
         messages = [{"role": "user", "content": "dummy-text"}]
         try:
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             # pick the one that is available (lowest TPM/RPM)
@@ -1826,7 +1815,7 @@ class Router:
                 kwargs["model"] = model
                 kwargs["messages"] = messages
                 kwargs["original_function"] = self._arealtime
-                return self.function_with_retries(**kwargs)
+                return await self.async_function_with_retries(**kwargs)
             else:
                 raise e
@@ -1844,7 +1833,6 @@ class Router:
             kwargs["model"] = model
             kwargs["prompt"] = prompt
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             # pick the one that is available (lowest TPM/RPM)
@@ -1924,7 +1912,6 @@ class Router:
                     "prompt": prompt,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": self.timeout,
                     **kwargs,
                 }
             )
@@ -1980,7 +1967,6 @@ class Router:
             kwargs["adapter_id"] = adapter_id
             kwargs["original_function"] = self._aadapter_completion
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             response = await self.async_function_with_fallbacks(**kwargs)
@@ -2024,7 +2010,6 @@ class Router:
                     "adapter_id": adapter_id,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": self.timeout,
                     **kwargs,
                 }
             )
@@ -2078,7 +2063,6 @@ class Router:
             kwargs["input"] = input
             kwargs["original_function"] = self._embedding
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             response = self.function_with_fallbacks(**kwargs)
             return response
@@ -2245,7 +2229,6 @@ class Router:
             kwargs["model"] = model
             kwargs["original_function"] = self._acreate_file
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             response = await self.async_function_with_fallbacks(**kwargs)
@@ -2301,7 +2284,6 @@ class Router:
                     "custom_llm_provider": custom_llm_provider,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": self.timeout,
                     **kwargs,
                 }
             )
@@ -2352,7 +2334,6 @@ class Router:
             kwargs["model"] = model
             kwargs["original_function"] = self._acreate_batch
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             response = await self.async_function_with_fallbacks(**kwargs)
@@ -2397,13 +2378,7 @@ class Router:
             kwargs["model_info"] = deployment.get("model_info", {})
             data = deployment["litellm_params"].copy()
             model_name = data["model"]
-            for k, v in self.default_litellm_params.items():
-                if (
-                    k not in kwargs
-                ):  # prioritize model-specific params > default router params
-                    kwargs[k] = v
-                elif k == metadata_variable_name:
-                    kwargs[k].update(v)
+            self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

             model_client = self._get_async_openai_model_client(
                 deployment=deployment,
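
The removed loop's behavior, restated for reference (hypothetical helper names): request-level params take priority over router defaults, except metadata dicts, which are merged rather than replaced. This is the logic that the common '_update_kwargs_with_deployment' path now centralizes:

def apply_default_params(kwargs: dict, defaults: dict, metadata_key: str = "metadata") -> None:
    for k, v in defaults.items():
        if k not in kwargs:       # prioritize model-specific params > default router params
            kwargs[k] = v
        elif k == metadata_key:   # merge metadata instead of overwriting it
            kwargs[k].update(v)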
@@ -2420,7 +2395,6 @@ class Router:
                     "custom_llm_provider": custom_llm_provider,
                     "caching": self.cache_responses,
                     "client": model_client,
-                    "timeout": self.timeout,
                     **kwargs,
                 }
             )
@@ -2913,7 +2887,11 @@ class Router:
         except Exception as e:
             current_attempt = None
             original_exception = e
+            deployment_num_retries = getattr(e, "num_retries", None)
+            if deployment_num_retries is not None and isinstance(
+                deployment_num_retries, int
+            ):
+                num_retries = deployment_num_retries
             """
             Retry Logic
             """
@@ -3210,106 +3188,6 @@ class Router:
         return timeout

-    def function_with_retries(self, *args, **kwargs):
-        """
-        Try calling the model 3 times. Shuffle-between available deployments.
-        """
-        verbose_router_logger.debug(
-            f"Inside function with retries: args - {args}; kwargs - {kwargs}"
-        )
-        original_function = kwargs.pop("original_function")
-        num_retries = kwargs.pop("num_retries")
-        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
-        context_window_fallbacks = kwargs.pop(
-            "context_window_fallbacks", self.context_window_fallbacks
-        )
-        content_policy_fallbacks = kwargs.pop(
-            "content_policy_fallbacks", self.content_policy_fallbacks
-        )
-        model_group = kwargs.get("model")
-        try:
-            # if the function call is successful, no exception will be raised and we'll break out of the loop
-            self._handle_mock_testing_rate_limit_error(
-                kwargs=kwargs, model_group=model_group
-            )
-            response = original_function(*args, **kwargs)
-            return response
-        except Exception as e:
-            current_attempt = None
-            original_exception = e
-            _model: Optional[str] = kwargs.get("model")  # type: ignore
-            if _model is None:
-                raise e  # re-raise error, if model can't be determined for loadbalancing
-            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
-            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
-            _healthy_deployments, _all_deployments = self._get_healthy_deployments(
-                model=_model,
-                parent_otel_span=parent_otel_span,
-            )
-            # raises an exception if this error should not be retries
-            self.should_retry_this_error(
-                error=e,
-                healthy_deployments=_healthy_deployments,
-                all_deployments=_all_deployments,
-                context_window_fallbacks=context_window_fallbacks,
-                regular_fallbacks=fallbacks,
-                content_policy_fallbacks=content_policy_fallbacks,
-            )
-            # decides how long to sleep before retry
-            _timeout = self._time_to_sleep_before_retry(
-                e=original_exception,
-                remaining_retries=num_retries,
-                num_retries=num_retries,
-                healthy_deployments=_healthy_deployments,
-                all_deployments=_all_deployments,
-            )
-            ## LOGGING
-            if num_retries > 0:
-                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
-            time.sleep(_timeout)
-            for current_attempt in range(num_retries):
-                verbose_router_logger.debug(
-                    f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
-                )
-                try:
-                    # if the function call is successful, no exception will be raised and we'll break out of the loop
-                    response = original_function(*args, **kwargs)
-                    return response
-                except Exception as e:
-                    ## LOGGING
-                    kwargs = self.log_retry(kwargs=kwargs, e=e)
-                    _model: Optional[str] = kwargs.get("model")  # type: ignore
-                    if _model is None:
-                        raise e  # re-raise error, if model can't be determined for loadbalancing
-                    parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
-                    _healthy_deployments, _ = self._get_healthy_deployments(
-                        model=_model,
-                        parent_otel_span=parent_otel_span,
-                    )
-                    remaining_retries = num_retries - current_attempt
-                    _timeout = self._time_to_sleep_before_retry(
-                        e=e,
-                        remaining_retries=remaining_retries,
-                        num_retries=num_retries,
-                        healthy_deployments=_healthy_deployments,
-                        all_deployments=_all_deployments,
-                    )
-                    time.sleep(_timeout)
-            if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
-                setattr(original_exception, "max_retries", num_retries)
-                setattr(original_exception, "num_retries", current_attempt)
-            raise original_exception
### HELPER FUNCTIONS
async def deployment_callback_on_success(
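
With the synchronous copy gone, retries flow through async_function_with_retries; its general shape is a bounded retry loop with a computed sleep between attempts. A simplified, self-contained sketch (not the litellm implementation):

import asyncio

async def call_with_retries(fn, num_retries: int, base_backoff_s: float = 1.0):
    last_exc = None
    for attempt in range(num_retries + 1):
        try:
            return await fn()
        except Exception as e:  # the real code first checks the error is retryable
            last_exc = e
            if attempt < num_retries:
                await asyncio.sleep(base_backoff_s * (2 ** attempt))  # back off before retrying
    raise last_exc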