LiteLLM Minor Fixes and Improvements (11/09/2024) (#5634)

* fix(caching.py): set ttl for async_increment cache

Fixes an issue where the TTL was not being set on the Redis client in increment_cache.

Fixes https://github.com/BerriAI/litellm/issues/5609

* fix(caching.py): set ttl for sync increment cache on redis

Fixes https://github.com/BerriAI/litellm/issues/5609
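
A minimal sketch of the shape of both cache fixes, assuming a redis-py asyncio client; the function and key/ttl names are illustrative, the real change lives in caching.py's increment helpers:

    import redis.asyncio as redis
    from typing import Optional

    async def async_increment_with_ttl(
        client: redis.Redis, key: str, value: float, ttl: Optional[int]
    ) -> float:
        # increment first, as before
        result = await client.incrbyfloat(name=key, amount=value)
        # the fix: also apply a ttl, so increment-only keys expire
        if ttl is not None and await client.ttl(key) == -1:
            # redis returns -1 when the key exists but has no expiry set
            await client.expire(key, ttl)
        return float(result)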

* fix(router.py): support adding retry policy + allowed fails policy via config.yaml
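
A sketch of what this allows in the proxy's config.yaml. The retry-policy field names (BadRequestErrorRetries, ContentPolicyViolationErrorRetries) appear in the router.py diff below; the allowed-fails field name is an assumption:

    router_settings:
      retry_policy:                # now accepted as a dict, parsed via RetryPolicy(**retry_policy)
        BadRequestErrorRetries: 0
        ContentPolicyViolationErrorRetries: 3
      allowed_fails_policy:        # parsed via AllowedFailsPolicy(**allowed_fails_policy)
        ContentPolicyViolationErrorAllowedFails: 100   # assumed field name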

* fix(router.py): don't cooldown single deployments

No point cooling down a lone deployment, as there's no other deployment to load-balance with.
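
A minimal sketch of the guard, with placeholder status-code checks standing in for the router's real exception logic (the diff below adds a get_model_group() helper and an early return in _is_cooldown_required()):

    def needs_cooldown(deployments_in_group: list, exception_status: int) -> bool:
        if len(deployments_in_group) == 1:
            # a lone deployment has nothing to fail over to; never cool it down
            return False
        # placeholder for the router's existing exception-based checks
        return exception_status in (429, 500, 502, 503)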

* fix(user_api_key_auth.py): support setting allowed email domains on jwt tokens

Closes https://github.com/BerriAI/litellm/issues/5605
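
A minimal sketch of the domain check this enables, assuming the JWT payload carries an email claim (names here are illustrative, not litellm's actual settings):

    def email_in_allowed_domain(email: str, allowed_email_domain: str) -> bool:
        # "user@berri.ai" -> "berri.ai"
        domain = email.rsplit("@", 1)[-1].lower()
        return domain == allowed_email_domain.lower()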

* docs(token_auth.md): add user upsert + allowed email domain to jwt auth docs

* fix(litellm_pre_call_utils.py): fix dynamic key logging when team id is set

Fixes an issue where key-level logging settings were not applied when team metadata was not None.
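
A minimal sketch of the corrected behavior, assuming key metadata carries a "logging" entry and team settings may already have populated the request metadata (all names here are hypothetical):

    def attach_key_logging(request_metadata: dict, key_metadata: dict) -> dict:
        # previously this step was skipped whenever team metadata had already
        # filled request_metadata; now the key's logging settings always apply
        key_logging = key_metadata.get("logging")
        if key_logging is not None:
            request_metadata["logging"] = key_logging
        return request_metadata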

* fix(secret_managers/main.py): load environment variables correctly

Fixes an issue where os.environ/-prefixed values were not being resolved from the environment correctly.
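
litellm config values can reference environment variables with an os.environ/ prefix; a minimal sketch of the resolution being fixed (the helper name is illustrative):

    import os

    def resolve_os_environ_ref(value: str) -> str:
        # "os.environ/AZURE_API_KEY" -> contents of $AZURE_API_KEY
        prefix = "os.environ/"
        if value.startswith(prefix):
            return os.environ.get(value[len(prefix):], "")
        return value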

* test(test_router.py): fix test

* feat(spend_tracking_utils.py): support logging additional usage params - e.g. prompt caching values for deepseek
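
A minimal sketch of surfacing the extra usage fields for spend logs, assuming deepseek-style prompt-cache counters on the usage object (the helper is illustrative):

    def extract_additional_usage_values(usage: dict) -> dict:
        # deepseek reports prompt-cache hits/misses alongside token counts
        return {
            key: usage[key]
            for key in ("prompt_cache_hit_tokens", "prompt_cache_miss_tokens")
            if usage.get(key) is not None
        }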

* test: fix tests

* test: fix test

* test: fix test

* test: fix test

* test: fix test

Krish Dholakia, 2024-09-11 22:36:06 -07:00, committed by GitHub
parent c7e299d213
commit dec53961f7
25 changed files with 745 additions and 114 deletions

litellm/router.py

@@ -161,10 +161,10 @@ class Router:
         enable_tag_filtering: bool = False,
         retry_after: int = 0,  # min time to wait before retrying a failed request
         retry_policy: Optional[
-            RetryPolicy
+            Union[RetryPolicy, dict]
         ] = None,  # set custom retries for different exceptions
-        model_group_retry_policy: Optional[
-            Dict[str, RetryPolicy]
+        model_group_retry_policy: Dict[
+            str, RetryPolicy
         ] = {},  # set custom retry policies based on model group
         allowed_fails: Optional[
             int
@@ -263,7 +263,7 @@ class Router:
         self.debug_level = debug_level
         self.enable_pre_call_checks = enable_pre_call_checks
         self.enable_tag_filtering = enable_tag_filtering
-        if self.set_verbose == True:
+        if self.set_verbose is True:
             if debug_level == "INFO":
                 verbose_router_logger.setLevel(logging.INFO)
             elif debug_level == "DEBUG":
@@ -454,11 +454,35 @@ class Router:
             )
         self.routing_strategy_args = routing_strategy_args
-        self.retry_policy: Optional[RetryPolicy] = retry_policy
+        self.retry_policy: Optional[RetryPolicy] = None
+        if retry_policy is not None:
+            if isinstance(retry_policy, dict):
+                self.retry_policy = RetryPolicy(**retry_policy)
+            elif isinstance(retry_policy, RetryPolicy):
+                self.retry_policy = retry_policy
+            verbose_router_logger.info(
+                "\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format(
+                    self.retry_policy.model_dump(exclude_none=True)
+                )
+            )
+
         self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
             model_group_retry_policy
         )
-        self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
+        self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
+        if allowed_fails_policy is not None:
+            if isinstance(allowed_fails_policy, dict):
+                self.allowed_fails_policy = AllowedFailsPolicy(**allowed_fails_policy)
+            elif isinstance(allowed_fails_policy, AllowedFailsPolicy):
+                self.allowed_fails_policy = allowed_fails_policy
+            verbose_router_logger.info(
+                "\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format(
+                    self.allowed_fails_policy.model_dump(exclude_none=True)
+                )
+            )
+
         self.alerting_config: Optional[AlertingConfig] = alerting_config
         if self.alerting_config is not None:
             self._initialize_alerting()
@@ -3003,6 +3027,13 @@ class Router:
         model_group = kwargs.get("model")
         num_retries = kwargs.pop("num_retries")

+        ## ADD MODEL GROUP SIZE TO METADATA - used for model_group_rate_limit_error tracking
+        _metadata: dict = kwargs.get("metadata") or {}
+        if "model_group" in _metadata and isinstance(_metadata["model_group"], str):
+            model_list = self.get_model_list(model_name=_metadata["model_group"])
+            if model_list is not None:
+                _metadata.update({"model_group_size": len(model_list)})
+
         verbose_router_logger.debug(
             f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
         )
@@ -3165,6 +3196,7 @@ class Router:
         If it fails after num_retries, fall back to another model group
         """

         mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
+        model_group = kwargs.get("model")
         fallbacks = kwargs.get("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.get(
@@ -3173,6 +3205,7 @@ class Router:
         content_policy_fallbacks = kwargs.get(
             "content_policy_fallbacks", self.content_policy_fallbacks
         )
+
         try:
             if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
                 raise Exception(
@@ -3324,6 +3357,9 @@ class Router:
             f"Inside function with retries: args - {args}; kwargs - {kwargs}"
         )
         original_function = kwargs.pop("original_function")
+        mock_testing_rate_limit_error = kwargs.pop(
+            "mock_testing_rate_limit_error", None
+        )
         num_retries = kwargs.pop("num_retries")
         fallbacks = kwargs.pop("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.pop(
@@ -3332,9 +3368,22 @@ class Router:
         content_policy_fallbacks = kwargs.pop(
             "content_policy_fallbacks", self.content_policy_fallbacks
         )
+        model_group = kwargs.get("model")

         try:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
+            if (
+                mock_testing_rate_limit_error is not None
+                and mock_testing_rate_limit_error is True
+            ):
+                verbose_router_logger.info(
+                    "litellm.router.py::async_function_with_retries() - mock_testing_rate_limit_error=True. Raising litellm.RateLimitError."
+                )
+                raise litellm.RateLimitError(
+                    model=model_group,
+                    llm_provider="",
+                    message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
+                )
             response = original_function(*args, **kwargs)
             return response
         except Exception as e:
@@ -3571,17 +3620,26 @@ class Router:
         )  # don't change existing ttl

     def _is_cooldown_required(
-        self, exception_status: Union[str, int], exception_str: Optional[str] = None
-    ):
+        self,
+        model_id: str,
+        exception_status: Union[str, int],
+        exception_str: Optional[str] = None,
+    ) -> bool:
         """
         A function to determine if a cooldown is required based on the exception status.

         Parameters:
+            model_id (str) The id of the model in the model list
             exception_status (Union[str, int]): The status of the exception.

         Returns:
             bool: True if a cooldown is required, False otherwise.
         """
+        ## BASE CASE - single deployment
+        model_group = self.get_model_group(id=model_id)
+        if model_group is not None and len(model_group) == 1:
+            return False
+
         try:
             ignored_strings = ["APIConnectionError"]
             if (
@@ -3677,7 +3735,9 @@ class Router:
             if (
                 self._is_cooldown_required(
-                    exception_status=exception_status, exception_str=str(original_exception)
+                    model_id=deployment,
+                    exception_status=exception_status,
+                    exception_str=str(original_exception),
                 )
                 is False
             ):
@@ -3690,7 +3750,9 @@ class Router:
                 exception=original_exception,
             )
-            allowed_fails = _allowed_fails if _allowed_fails is not None else self.allowed_fails
+            allowed_fails = (
+                _allowed_fails if _allowed_fails is not None else self.allowed_fails
+            )

             dt = get_utc_datetime()
             current_minute = dt.strftime("%H-%M")
@@ -4298,6 +4360,18 @@ class Router:
                 return model
         return None

+    def get_model_group(self, id: str) -> Optional[List]:
+        """
+        Return list of all models in the same model group as that model id
+        """
+        model_info = self.get_model_info(id=id)
+        if model_info is None:
+            return None
+
+        model_name = model_info["model_name"]
+        return self.get_model_list(model_name=model_name)
+
     def _set_model_group_info(
         self, model_group: str, user_facing_model_group_name: str
     ) -> Optional[ModelGroupInfo]:
@@ -5528,18 +5602,19 @@ class Router:
         ContentPolicyViolationErrorRetries: Optional[int] = None
         """
         # if we can find the exception then in the retry policy -> return the number of retries
-        retry_policy = self.retry_policy
+        retry_policy: Optional[RetryPolicy] = self.retry_policy
         if (
             self.model_group_retry_policy is not None
             and model_group is not None
             and model_group in self.model_group_retry_policy
         ):
-            retry_policy = self.model_group_retry_policy.get(model_group, None)
+            retry_policy = self.model_group_retry_policy.get(model_group, None)  # type: ignore

         if retry_policy is None:
             return None

+        if isinstance(retry_policy, dict):
+            retry_policy = RetryPolicy(**retry_policy)
+
         if (
             isinstance(exception, litellm.BadRequestError)
             and retry_policy.BadRequestErrorRetries is not None