Merge pull request #4207 from BerriAI/litellm_content_policy_fallbacks

feat(router.py): support content policy fallbacks
This commit is contained in:
Krish Dholakia 2024-06-14 18:55:11 -07:00 committed by GitHub
commit 28a52fe5fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 197 additions and 33 deletions

View file

@ -109,6 +109,7 @@ class Router:
] = None, # generic fallbacks, works across all deployments
fallbacks: List = [],
context_window_fallbacks: List = [],
content_policy_fallbacks: List = [],
model_group_alias: Optional[dict] = {},
enable_pre_call_checks: bool = False,
retry_after: int = 0, # min time to wait before retrying a failed request
@ -312,6 +313,12 @@ class Router:
self.context_window_fallbacks = (
context_window_fallbacks or litellm.context_window_fallbacks
)
_content_policy_fallbacks = (
content_policy_fallbacks or litellm.content_policy_fallbacks
)
self.validate_fallbacks(fallback_param=_content_policy_fallbacks)
self.content_policy_fallbacks = _content_policy_fallbacks
self.total_calls: defaultdict = defaultdict(
int
) # dict to store total calls made to each model
@ -2055,6 +2062,9 @@ class Router:
context_window_fallbacks = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks
)
content_policy_fallbacks = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks
)
try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
@ -2073,7 +2083,10 @@ class Router:
if (
hasattr(e, "status_code")
and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError)
and not (
isinstance(e, litellm.ContextWindowExceededError)
or isinstance(e, litellm.ContentPolicyViolationError)
)
): # don't retry a malformed request
raise e
if (
@ -2091,6 +2104,39 @@ class Router:
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
"""
try:
kwargs["model"] = mg
kwargs.setdefault("metadata", {}).update(
{"model_group": mg}
) # update model_group used, if fallbacks are done
response = await self.async_function_with_retries(
*args, **kwargs
)
verbose_router_logger.info(
"Successful fallback b/w models."
)
return response
except Exception as e:
pass
elif (
isinstance(e, litellm.ContentPolicyViolationError)
and content_policy_fallbacks is not None
):
fallback_model_group = None
for (
item
) in content_policy_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
@ -2171,6 +2217,9 @@ class Router:
context_window_fallbacks = kwargs.pop(
"context_window_fallbacks", self.context_window_fallbacks
)
content_policy_fallbacks = kwargs.pop(
"content_policy_fallbacks", self.content_policy_fallbacks
)
num_retries = kwargs.pop("num_retries")
@ -2198,6 +2247,7 @@ class Router:
healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks,
regular_fallbacks=fallbacks,
content_policy_fallbacks=content_policy_fallbacks,
)
# decides how long to sleep before retry
@ -2263,10 +2313,12 @@ class Router:
error: Exception,
healthy_deployments: Optional[List] = None,
context_window_fallbacks: Optional[List] = None,
content_policy_fallbacks: Optional[List] = None,
regular_fallbacks: Optional[List] = None,
):
"""
1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
2. raise an exception for ContentPolicyViolationError if content_policy_fallbacks is not None
2. raise an exception for RateLimitError if
- there are no fallbacks
@ -2276,13 +2328,19 @@ class Router:
if healthy_deployments is not None and isinstance(healthy_deployments, list):
_num_healthy_deployments = len(healthy_deployments)
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR / CONTENT POLICY VIOLATION ERROR w/ fallbacks available / Bad Request Error
if (
isinstance(error, litellm.ContextWindowExceededError)
and context_window_fallbacks is not None
):
raise error
if (
isinstance(error, litellm.ContentPolicyViolationError)
and content_policy_fallbacks is not None
):
raise error
# Error we should only retry if there are other deployments
if isinstance(error, openai.RateLimitError):
if (
@ -2313,6 +2371,9 @@ class Router:
context_window_fallbacks = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks
)
content_policy_fallbacks = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks
)
try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
@ -2328,7 +2389,10 @@ class Router:
if (
hasattr(e, "status_code")
and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError)
and not (
isinstance(e, litellm.ContextWindowExceededError)
or isinstance(e, litellm.ContentPolicyViolationError)
)
): # don't retry a malformed request
raise e
@ -2351,6 +2415,37 @@ class Router:
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
"""
try:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
kwargs["model"] = mg
kwargs.setdefault("metadata", {}).update(
{"model_group": mg}
) # update model_group used, if fallbacks are done
response = self.function_with_fallbacks(*args, **kwargs)
return response
except Exception as e:
pass
elif (
isinstance(e, litellm.ContentPolicyViolationError)
and content_policy_fallbacks is not None
):
fallback_model_group = None
for (
item
) in content_policy_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
@ -2457,6 +2552,9 @@ class Router:
context_window_fallbacks = kwargs.pop(
"context_window_fallbacks", self.context_window_fallbacks
)
content_policy_fallbacks = kwargs.pop(
"content_policy_fallbacks", self.content_policy_fallbacks
)
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
@ -2476,6 +2574,7 @@ class Router:
healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks,
regular_fallbacks=fallbacks,
content_policy_fallbacks=content_policy_fallbacks,
)
# decides how long to sleep before retry