mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat(router.py): support content policy fallbacks
Closes https://github.com/BerriAI/litellm/issues/2632
This commit is contained in:
parent
0404d30a9c
commit
6f715b4782
6 changed files with 197 additions and 33 deletions
|
@ -108,6 +108,7 @@ class Router:
|
|||
] = None, # generic fallbacks, works across all deployments
|
||||
fallbacks: List = [],
|
||||
context_window_fallbacks: List = [],
|
||||
content_policy_fallbacks: List = [],
|
||||
model_group_alias: Optional[dict] = {},
|
||||
enable_pre_call_checks: bool = False,
|
||||
retry_after: int = 0, # min time to wait before retrying a failed request
|
||||
|
@ -311,6 +312,12 @@ class Router:
|
|||
self.context_window_fallbacks = (
|
||||
context_window_fallbacks or litellm.context_window_fallbacks
|
||||
)
|
||||
|
||||
_content_policy_fallbacks = (
|
||||
content_policy_fallbacks or litellm.content_policy_fallbacks
|
||||
)
|
||||
self.validate_fallbacks(fallback_param=_content_policy_fallbacks)
|
||||
self.content_policy_fallbacks = _content_policy_fallbacks
|
||||
self.total_calls: defaultdict = defaultdict(
|
||||
int
|
||||
) # dict to store total calls made to each model
|
||||
|
@ -1998,6 +2005,9 @@ class Router:
|
|||
context_window_fallbacks = kwargs.get(
|
||||
"context_window_fallbacks", self.context_window_fallbacks
|
||||
)
|
||||
content_policy_fallbacks = kwargs.get(
|
||||
"content_policy_fallbacks", self.content_policy_fallbacks
|
||||
)
|
||||
try:
|
||||
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
|
||||
raise Exception(
|
||||
|
@ -2016,7 +2026,10 @@ class Router:
|
|||
if (
|
||||
hasattr(e, "status_code")
|
||||
and e.status_code == 400 # type: ignore
|
||||
and not isinstance(e, litellm.ContextWindowExceededError)
|
||||
and not (
|
||||
isinstance(e, litellm.ContextWindowExceededError)
|
||||
or isinstance(e, litellm.ContentPolicyViolationError)
|
||||
)
|
||||
): # don't retry a malformed request
|
||||
raise e
|
||||
if (
|
||||
|
@ -2034,6 +2047,39 @@ class Router:
|
|||
if fallback_model_group is None:
|
||||
raise original_exception
|
||||
|
||||
for mg in fallback_model_group:
|
||||
"""
|
||||
Iterate through the model groups and try calling that deployment
|
||||
"""
|
||||
try:
|
||||
kwargs["model"] = mg
|
||||
kwargs.setdefault("metadata", {}).update(
|
||||
{"model_group": mg}
|
||||
) # update model_group used, if fallbacks are done
|
||||
response = await self.async_function_with_retries(
|
||||
*args, **kwargs
|
||||
)
|
||||
verbose_router_logger.info(
|
||||
"Successful fallback b/w models."
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
pass
|
||||
elif (
|
||||
isinstance(e, litellm.ContentPolicyViolationError)
|
||||
and content_policy_fallbacks is not None
|
||||
):
|
||||
fallback_model_group = None
|
||||
for (
|
||||
item
|
||||
) in content_policy_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
|
||||
if list(item.keys())[0] == model_group:
|
||||
fallback_model_group = item[model_group]
|
||||
break
|
||||
|
||||
if fallback_model_group is None:
|
||||
raise original_exception
|
||||
|
||||
for mg in fallback_model_group:
|
||||
"""
|
||||
Iterate through the model groups and try calling that deployment
|
||||
|
@ -2114,6 +2160,9 @@ class Router:
|
|||
context_window_fallbacks = kwargs.pop(
|
||||
"context_window_fallbacks", self.context_window_fallbacks
|
||||
)
|
||||
content_policy_fallbacks = kwargs.pop(
|
||||
"content_policy_fallbacks", self.content_policy_fallbacks
|
||||
)
|
||||
|
||||
num_retries = kwargs.pop("num_retries")
|
||||
|
||||
|
@ -2141,6 +2190,7 @@ class Router:
|
|||
healthy_deployments=_healthy_deployments,
|
||||
context_window_fallbacks=context_window_fallbacks,
|
||||
regular_fallbacks=fallbacks,
|
||||
content_policy_fallbacks=content_policy_fallbacks,
|
||||
)
|
||||
|
||||
# decides how long to sleep before retry
|
||||
|
@ -2206,10 +2256,12 @@ class Router:
|
|||
error: Exception,
|
||||
healthy_deployments: Optional[List] = None,
|
||||
context_window_fallbacks: Optional[List] = None,
|
||||
content_policy_fallbacks: Optional[List] = None,
|
||||
regular_fallbacks: Optional[List] = None,
|
||||
):
|
||||
"""
|
||||
1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
|
||||
2. raise an exception for ContentPolicyViolationError if content_policy_fallbacks is not None
|
||||
|
||||
2. raise an exception for RateLimitError if
|
||||
- there are no fallbacks
|
||||
|
@ -2219,13 +2271,19 @@ class Router:
|
|||
if healthy_deployments is not None and isinstance(healthy_deployments, list):
|
||||
_num_healthy_deployments = len(healthy_deployments)
|
||||
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR / CONTENT POLICY VIOLATION ERROR w/ fallbacks available / Bad Request Error
|
||||
if (
|
||||
isinstance(error, litellm.ContextWindowExceededError)
|
||||
and context_window_fallbacks is not None
|
||||
):
|
||||
raise error
|
||||
|
||||
if (
|
||||
isinstance(error, litellm.ContentPolicyViolationError)
|
||||
and content_policy_fallbacks is not None
|
||||
):
|
||||
raise error
|
||||
|
||||
# Error we should only retry if there are other deployments
|
||||
if isinstance(error, openai.RateLimitError):
|
||||
if (
|
||||
|
@ -2256,6 +2314,9 @@ class Router:
|
|||
context_window_fallbacks = kwargs.get(
|
||||
"context_window_fallbacks", self.context_window_fallbacks
|
||||
)
|
||||
content_policy_fallbacks = kwargs.get(
|
||||
"content_policy_fallbacks", self.content_policy_fallbacks
|
||||
)
|
||||
try:
|
||||
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
|
||||
raise Exception(
|
||||
|
@ -2271,7 +2332,10 @@ class Router:
|
|||
if (
|
||||
hasattr(e, "status_code")
|
||||
and e.status_code == 400 # type: ignore
|
||||
and not isinstance(e, litellm.ContextWindowExceededError)
|
||||
and not (
|
||||
isinstance(e, litellm.ContextWindowExceededError)
|
||||
or isinstance(e, litellm.ContentPolicyViolationError)
|
||||
)
|
||||
): # don't retry a malformed request
|
||||
raise e
|
||||
|
||||
|
@ -2294,6 +2358,37 @@ class Router:
|
|||
if fallback_model_group is None:
|
||||
raise original_exception
|
||||
|
||||
for mg in fallback_model_group:
|
||||
"""
|
||||
Iterate through the model groups and try calling that deployment
|
||||
"""
|
||||
try:
|
||||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
kwargs["model"] = mg
|
||||
kwargs.setdefault("metadata", {}).update(
|
||||
{"model_group": mg}
|
||||
) # update model_group used, if fallbacks are done
|
||||
response = self.function_with_fallbacks(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
pass
|
||||
elif (
|
||||
isinstance(e, litellm.ContentPolicyViolationError)
|
||||
and content_policy_fallbacks is not None
|
||||
):
|
||||
fallback_model_group = None
|
||||
|
||||
for (
|
||||
item
|
||||
) in content_policy_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
|
||||
if list(item.keys())[0] == model_group:
|
||||
fallback_model_group = item[model_group]
|
||||
break
|
||||
|
||||
if fallback_model_group is None:
|
||||
raise original_exception
|
||||
|
||||
for mg in fallback_model_group:
|
||||
"""
|
||||
Iterate through the model groups and try calling that deployment
|
||||
|
@ -2400,6 +2495,9 @@ class Router:
|
|||
context_window_fallbacks = kwargs.pop(
|
||||
"context_window_fallbacks", self.context_window_fallbacks
|
||||
)
|
||||
content_policy_fallbacks = kwargs.pop(
|
||||
"content_policy_fallbacks", self.content_policy_fallbacks
|
||||
)
|
||||
|
||||
try:
|
||||
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||
|
@ -2419,6 +2517,7 @@ class Router:
|
|||
healthy_deployments=_healthy_deployments,
|
||||
context_window_fallbacks=context_window_fallbacks,
|
||||
regular_fallbacks=fallbacks,
|
||||
content_policy_fallbacks=content_policy_fallbacks,
|
||||
)
|
||||
|
||||
# decides how long to sleep before retry
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue