forked from phoenix/litellm-mirror
LiteLLM Minor Fixes and Improvements (11/09/2024) (#5634)
* fix(caching.py): set ttl for async_increment cache fixes issue where ttl for redis client was not being set on increment_cache Fixes https://github.com/BerriAI/litellm/issues/5609 * fix(caching.py): fix increment cache w/ ttl for sync increment cache on redis Fixes https://github.com/BerriAI/litellm/issues/5609 * fix(router.py): support adding retry policy + allowed fails policy via config.yaml * fix(router.py): don't cooldown single deployments No point, as there's no other deployment to loadbalance with. * fix(user_api_key_auth.py): support setting allowed email domains on jwt tokens Closes https://github.com/BerriAI/litellm/issues/5605 * docs(token_auth.md): add user upsert + allowed email domain to jwt auth docs * fix(litellm_pre_call_utils.py): fix dynamic key logging when team id is set Fixes issue where key logging would not be set if team metadata was not none * fix(secret_managers/main.py): load environment variables correctly Fixes issue where os.environ/ was not being loaded correctly * test(test_router.py): fix test * feat(spend_tracking_utils.py): support logging additional usage params - e.g. prompt caching values for deepseek * test: fix tests * test: fix test * test: fix test * test: fix test * test: fix test
This commit is contained in:
parent
70100d716b
commit
98c34a7e27
25 changed files with 745 additions and 114 deletions
|
@ -161,10 +161,10 @@ class Router:
|
|||
enable_tag_filtering: bool = False,
|
||||
retry_after: int = 0, # min time to wait before retrying a failed request
|
||||
retry_policy: Optional[
|
||||
RetryPolicy
|
||||
Union[RetryPolicy, dict]
|
||||
] = None, # set custom retries for different exceptions
|
||||
model_group_retry_policy: Optional[
|
||||
Dict[str, RetryPolicy]
|
||||
model_group_retry_policy: Dict[
|
||||
str, RetryPolicy
|
||||
] = {}, # set custom retry policies based on model group
|
||||
allowed_fails: Optional[
|
||||
int
|
||||
|
@ -263,7 +263,7 @@ class Router:
|
|||
self.debug_level = debug_level
|
||||
self.enable_pre_call_checks = enable_pre_call_checks
|
||||
self.enable_tag_filtering = enable_tag_filtering
|
||||
if self.set_verbose == True:
|
||||
if self.set_verbose is True:
|
||||
if debug_level == "INFO":
|
||||
verbose_router_logger.setLevel(logging.INFO)
|
||||
elif debug_level == "DEBUG":
|
||||
|
@ -454,11 +454,35 @@ class Router:
|
|||
)
|
||||
|
||||
self.routing_strategy_args = routing_strategy_args
|
||||
self.retry_policy: Optional[RetryPolicy] = retry_policy
|
||||
self.retry_policy: Optional[RetryPolicy] = None
|
||||
if retry_policy is not None:
|
||||
if isinstance(retry_policy, dict):
|
||||
self.retry_policy = RetryPolicy(**retry_policy)
|
||||
elif isinstance(retry_policy, RetryPolicy):
|
||||
self.retry_policy = retry_policy
|
||||
verbose_router_logger.info(
|
||||
"\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format(
|
||||
self.retry_policy.model_dump(exclude_none=True)
|
||||
)
|
||||
)
|
||||
|
||||
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
|
||||
model_group_retry_policy
|
||||
)
|
||||
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
|
||||
|
||||
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
|
||||
if allowed_fails_policy is not None:
|
||||
if isinstance(allowed_fails_policy, dict):
|
||||
self.allowed_fails_policy = AllowedFailsPolicy(**allowed_fails_policy)
|
||||
elif isinstance(allowed_fails_policy, AllowedFailsPolicy):
|
||||
self.allowed_fails_policy = allowed_fails_policy
|
||||
|
||||
verbose_router_logger.info(
|
||||
"\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format(
|
||||
self.allowed_fails_policy.model_dump(exclude_none=True)
|
||||
)
|
||||
)
|
||||
|
||||
self.alerting_config: Optional[AlertingConfig] = alerting_config
|
||||
if self.alerting_config is not None:
|
||||
self._initialize_alerting()
|
||||
|
@ -3003,6 +3027,13 @@ class Router:
|
|||
model_group = kwargs.get("model")
|
||||
num_retries = kwargs.pop("num_retries")
|
||||
|
||||
## ADD MODEL GROUP SIZE TO METADATA - used for model_group_rate_limit_error tracking
|
||||
_metadata: dict = kwargs.get("metadata") or {}
|
||||
if "model_group" in _metadata and isinstance(_metadata["model_group"], str):
|
||||
model_list = self.get_model_list(model_name=_metadata["model_group"])
|
||||
if model_list is not None:
|
||||
_metadata.update({"model_group_size": len(model_list)})
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
|
||||
)
|
||||
|
@ -3165,6 +3196,7 @@ class Router:
|
|||
If it fails after num_retries, fall back to another model group
|
||||
"""
|
||||
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
|
||||
|
||||
model_group = kwargs.get("model")
|
||||
fallbacks = kwargs.get("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.get(
|
||||
|
@ -3173,6 +3205,7 @@ class Router:
|
|||
content_policy_fallbacks = kwargs.get(
|
||||
"content_policy_fallbacks", self.content_policy_fallbacks
|
||||
)
|
||||
|
||||
try:
|
||||
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
|
||||
raise Exception(
|
||||
|
@ -3324,6 +3357,9 @@ class Router:
|
|||
f"Inside function with retries: args - {args}; kwargs - {kwargs}"
|
||||
)
|
||||
original_function = kwargs.pop("original_function")
|
||||
mock_testing_rate_limit_error = kwargs.pop(
|
||||
"mock_testing_rate_limit_error", None
|
||||
)
|
||||
num_retries = kwargs.pop("num_retries")
|
||||
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.pop(
|
||||
|
@ -3332,9 +3368,22 @@ class Router:
|
|||
content_policy_fallbacks = kwargs.pop(
|
||||
"content_policy_fallbacks", self.content_policy_fallbacks
|
||||
)
|
||||
model_group = kwargs.get("model")
|
||||
|
||||
try:
|
||||
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||
if (
|
||||
mock_testing_rate_limit_error is not None
|
||||
and mock_testing_rate_limit_error is True
|
||||
):
|
||||
verbose_router_logger.info(
|
||||
"litellm.router.py::async_function_with_retries() - mock_testing_rate_limit_error=True. Raising litellm.RateLimitError."
|
||||
)
|
||||
raise litellm.RateLimitError(
|
||||
model=model_group,
|
||||
llm_provider="",
|
||||
message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
|
||||
)
|
||||
response = original_function(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
|
@ -3571,17 +3620,26 @@ class Router:
|
|||
) # don't change existing ttl
|
||||
|
||||
def _is_cooldown_required(
|
||||
self, exception_status: Union[str, int], exception_str: Optional[str] = None
|
||||
):
|
||||
self,
|
||||
model_id: str,
|
||||
exception_status: Union[str, int],
|
||||
exception_str: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
A function to determine if a cooldown is required based on the exception status.
|
||||
|
||||
Parameters:
|
||||
model_id (str) The id of the model in the model list
|
||||
exception_status (Union[str, int]): The status of the exception.
|
||||
|
||||
Returns:
|
||||
bool: True if a cooldown is required, False otherwise.
|
||||
"""
|
||||
## BASE CASE - single deployment
|
||||
model_group = self.get_model_group(id=model_id)
|
||||
if model_group is not None and len(model_group) == 1:
|
||||
return False
|
||||
|
||||
try:
|
||||
ignored_strings = ["APIConnectionError"]
|
||||
if (
|
||||
|
@ -3677,7 +3735,9 @@ class Router:
|
|||
|
||||
if (
|
||||
self._is_cooldown_required(
|
||||
exception_status=exception_status, exception_str=str(original_exception)
|
||||
model_id=deployment,
|
||||
exception_status=exception_status,
|
||||
exception_str=str(original_exception),
|
||||
)
|
||||
is False
|
||||
):
|
||||
|
@ -3690,7 +3750,9 @@ class Router:
|
|||
exception=original_exception,
|
||||
)
|
||||
|
||||
allowed_fails = _allowed_fails if _allowed_fails is not None else self.allowed_fails
|
||||
allowed_fails = (
|
||||
_allowed_fails if _allowed_fails is not None else self.allowed_fails
|
||||
)
|
||||
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
|
@ -4298,6 +4360,18 @@ class Router:
|
|||
return model
|
||||
return None
|
||||
|
||||
def get_model_group(self, id: str) -> Optional[List]:
|
||||
"""
|
||||
Return list of all models in the same model group as that model id
|
||||
"""
|
||||
|
||||
model_info = self.get_model_info(id=id)
|
||||
if model_info is None:
|
||||
return None
|
||||
|
||||
model_name = model_info["model_name"]
|
||||
return self.get_model_list(model_name=model_name)
|
||||
|
||||
def _set_model_group_info(
|
||||
self, model_group: str, user_facing_model_group_name: str
|
||||
) -> Optional[ModelGroupInfo]:
|
||||
|
@ -5528,18 +5602,19 @@ class Router:
|
|||
ContentPolicyViolationErrorRetries: Optional[int] = None
|
||||
"""
|
||||
# if we can find the exception then in the retry policy -> return the number of retries
|
||||
retry_policy = self.retry_policy
|
||||
retry_policy: Optional[RetryPolicy] = self.retry_policy
|
||||
if (
|
||||
self.model_group_retry_policy is not None
|
||||
and model_group is not None
|
||||
and model_group in self.model_group_retry_policy
|
||||
):
|
||||
retry_policy = self.model_group_retry_policy.get(model_group, None)
|
||||
retry_policy = self.model_group_retry_policy.get(model_group, None) # type: ignore
|
||||
|
||||
if retry_policy is None:
|
||||
return None
|
||||
if isinstance(retry_policy, dict):
|
||||
retry_policy = RetryPolicy(**retry_policy)
|
||||
|
||||
if (
|
||||
isinstance(exception, litellm.BadRequestError)
|
||||
and retry_policy.BadRequestErrorRetries is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue