feat(router.py): support content policy fallbacks

Closes https://github.com/BerriAI/litellm/issues/2632
This commit is contained in:
Krrish Dholakia 2024-06-14 17:15:44 -07:00
parent 0404d30a9c
commit 6f715b4782
6 changed files with 197 additions and 33 deletions

View file

@ -108,6 +108,7 @@ class Router:
] = None, # generic fallbacks, works across all deployments
fallbacks: List = [],
context_window_fallbacks: List = [],
content_policy_fallbacks: List = [],
model_group_alias: Optional[dict] = {},
enable_pre_call_checks: bool = False,
retry_after: int = 0, # min time to wait before retrying a failed request
@ -311,6 +312,12 @@ class Router:
self.context_window_fallbacks = (
context_window_fallbacks or litellm.context_window_fallbacks
)
_content_policy_fallbacks = (
content_policy_fallbacks or litellm.content_policy_fallbacks
)
self.validate_fallbacks(fallback_param=_content_policy_fallbacks)
self.content_policy_fallbacks = _content_policy_fallbacks
self.total_calls: defaultdict = defaultdict(
int
) # dict to store total calls made to each model
@ -1998,6 +2005,9 @@ class Router:
context_window_fallbacks = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks
)
content_policy_fallbacks = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks
)
try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
@ -2016,7 +2026,10 @@ class Router:
if (
hasattr(e, "status_code")
and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError)
and not (
isinstance(e, litellm.ContextWindowExceededError)
or isinstance(e, litellm.ContentPolicyViolationError)
)
): # don't retry a malformed request
raise e
if (
@ -2034,6 +2047,39 @@ class Router:
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
"""
try:
kwargs["model"] = mg
kwargs.setdefault("metadata", {}).update(
{"model_group": mg}
) # update model_group used, if fallbacks are done
response = await self.async_function_with_retries(
*args, **kwargs
)
verbose_router_logger.info(
"Successful fallback b/w models."
)
return response
except Exception as e:
pass
elif (
isinstance(e, litellm.ContentPolicyViolationError)
and content_policy_fallbacks is not None
):
fallback_model_group = None
for (
item
) in content_policy_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
@ -2114,6 +2160,9 @@ class Router:
context_window_fallbacks = kwargs.pop(
"context_window_fallbacks", self.context_window_fallbacks
)
content_policy_fallbacks = kwargs.pop(
"content_policy_fallbacks", self.content_policy_fallbacks
)
num_retries = kwargs.pop("num_retries")
@ -2141,6 +2190,7 @@ class Router:
healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks,
regular_fallbacks=fallbacks,
content_policy_fallbacks=content_policy_fallbacks,
)
# decides how long to sleep before retry
@ -2206,10 +2256,12 @@ class Router:
error: Exception,
healthy_deployments: Optional[List] = None,
context_window_fallbacks: Optional[List] = None,
content_policy_fallbacks: Optional[List] = None,
regular_fallbacks: Optional[List] = None,
):
"""
1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
2. raise an exception for ContentPolicyViolationError if content_policy_fallbacks is not None
2. raise an exception for RateLimitError if
- there are no fallbacks
@ -2219,13 +2271,19 @@ class Router:
if healthy_deployments is not None and isinstance(healthy_deployments, list):
_num_healthy_deployments = len(healthy_deployments)
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR / CONTENT POLICY VIOLATION ERROR w/ fallbacks available / Bad Request Error
if (
isinstance(error, litellm.ContextWindowExceededError)
and context_window_fallbacks is not None
):
raise error
if (
isinstance(error, litellm.ContentPolicyViolationError)
and content_policy_fallbacks is not None
):
raise error
# Error we should only retry if there are other deployments
if isinstance(error, openai.RateLimitError):
if (
@ -2256,6 +2314,9 @@ class Router:
context_window_fallbacks = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks
)
content_policy_fallbacks = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks
)
try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
@ -2271,7 +2332,10 @@ class Router:
if (
hasattr(e, "status_code")
and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError)
and not (
isinstance(e, litellm.ContextWindowExceededError)
or isinstance(e, litellm.ContentPolicyViolationError)
)
): # don't retry a malformed request
raise e
@ -2294,6 +2358,37 @@ class Router:
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
"""
try:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
kwargs["model"] = mg
kwargs.setdefault("metadata", {}).update(
{"model_group": mg}
) # update model_group used, if fallbacks are done
response = self.function_with_fallbacks(*args, **kwargs)
return response
except Exception as e:
pass
elif (
isinstance(e, litellm.ContentPolicyViolationError)
and content_policy_fallbacks is not None
):
fallback_model_group = None
for (
item
) in content_policy_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
@ -2400,6 +2495,9 @@ class Router:
context_window_fallbacks = kwargs.pop(
"context_window_fallbacks", self.context_window_fallbacks
)
content_policy_fallbacks = kwargs.pop(
"content_policy_fallbacks", self.content_policy_fallbacks
)
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
@ -2419,6 +2517,7 @@ class Router:
healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks,
regular_fallbacks=fallbacks,
content_policy_fallbacks=content_policy_fallbacks,
)
# decides how long to sleep before retry