feat(router.py): support content policy fallbacks

Closes https://github.com/BerriAI/litellm/issues/2632
Krrish Dholakia 2024-06-14 17:15:44 -07:00
parent 0404d30a9c
commit 6f715b4782
6 changed files with 197 additions and 33 deletions
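For reviewers, here is the feature end to end, mirroring the test added in this commit (model names, keys, and mock responses are illustrative; nothing here beyond content_policy_fallbacks is new):

from litellm import Router

# Router-level content policy fallbacks: when the "claude-2" group raises
# ContentPolicyViolationError, retry the request on "my-fallback-model".
router = Router(
    model_list=[
        {
            "model_name": "claude-2",
            "litellm_params": {
                "model": "claude-2",
                "api_key": "",
                # simulate Anthropic's content-filter rejection
                "mock_response": Exception("content filtering policy"),
            },
        },
        {
            "model_name": "my-fallback-model",
            "litellm_params": {
                "model": "claude-2",
                "api_key": "",
                "mock_response": "This works!",
            },
        },
    ],
    content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}],
)

response = router.completion(
    model="claude-2",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)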


@@ -240,6 +240,7 @@ num_retries: Optional[int] = None  # per model endpoint
 default_fallbacks: Optional[List] = None
 fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
+content_policy_fallbacks: Optional[List] = None
 allowed_fails: int = 0
 num_retries_per_request: Optional[int] = (
     None  # for the request overall (incl. fallbacks + model retries)
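The new module-level default mirrors the existing fallback settings; the Router falls back to it when no constructor argument is passed (see the Router.__init__ change below). A minimal sketch, with an illustrative model mapping:

import litellm

# Global default: consumed by Router(...) when content_policy_fallbacks is not
# passed to the constructor (the router reads litellm.content_policy_fallbacks).
litellm.content_policy_fallbacks = [{"claude-2": ["my-fallback-model"]}]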


@@ -324,7 +324,7 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
         message,
         model,
         llm_provider,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400
@@ -332,11 +332,13 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
+        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        self.response = response or httpx.Response(status_code=400, request=request)
         super().__init__(
             message=self.message,
             model=self.model,  # type: ignore
             llm_provider=self.llm_provider,  # type: ignore
-            response=response,
+            response=self.response,
             litellm_debug_info=self.litellm_debug_info,
         )  # Call the base class constructor with the parameters it needs
@@ -407,7 +409,7 @@ class ContentPolicyViolationError(BadRequestError):  # type: ignore
         message,
         model,
         llm_provider,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400
@@ -415,11 +417,13 @@ class ContentPolicyViolationError(BadRequestError):  # type: ignore
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
+        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        self.response = response or httpx.Response(status_code=500, request=request)
         super().__init__(
             message=self.message,
             model=self.model,  # type: ignore
             llm_provider=self.llm_provider,  # type: ignore
-            response=response,
+            response=self.response,
             litellm_debug_info=self.litellm_debug_info,
         )  # Call the base class constructor with the parameters it needs
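Making response optional means callers, like the new anthropic mapping in utils.py below, no longer have to thread an httpx.Response through. A minimal sketch of the resulting behavior, assuming the exception classes stay re-exported at the package top level as they are today:

import litellm

# No httpx.Response required anymore; a placeholder response is synthesized
# internally, per the defaults added above.
err = litellm.ContentPolicyViolationError(
    message="content filtering policy",
    model="claude-2",
    llm_provider="anthropic",
)
assert err.status_code == 400
assert err.response is not None  # synthesized placeholder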


@@ -401,6 +401,7 @@ def mock_completion(
     stream: Optional[bool] = False,
     mock_response: Union[str, Exception] = "This is a mock request",
     logging=None,
+    custom_llm_provider=None,
     **kwargs,
 ):
     """
@@ -438,7 +439,7 @@
             raise litellm.APIError(
                 status_code=getattr(mock_response, "status_code", 500),  # type: ignore
                 message=getattr(mock_response, "text", str(mock_response)),
-                llm_provider=getattr(mock_response, "llm_provider", "openai"),  # type: ignore
+                llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
                 model=model,  # type: ignore
                 request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
             )
@@ -907,6 +908,7 @@
                 logging=logging,
                 acompletion=acompletion,
                 mock_delay=kwargs.get("mock_delay", None),
+                custom_llm_provider=custom_llm_provider,
             )
         if custom_llm_provider == "azure":
             # azure configs
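Threading custom_llm_provider through mock_completion means a mocked failure is attributed to the real provider instead of defaulting to "openai", which is what lets the anthropic branch of exception_type (changed below) convert it. A sketch of the intended behavior outside the router, consistent with the test added in this commit (values illustrative):

import litellm

# The mocked error is attributed to "anthropic" and its text matches the new
# "content filtering policy" check, so it should surface as a content policy error.
try:
    litellm.completion(
        model="claude-2",
        messages=[{"role": "user", "content": "hi"}],
        api_key="",
        mock_response=Exception("content filtering policy"),
    )
except litellm.ContentPolicyViolationError as e:
    print("mapped to", type(e).__name__, "for", e.llm_provider)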


@@ -108,6 +108,7 @@ class Router:
         ] = None,  # generic fallbacks, works across all deployments
         fallbacks: List = [],
         context_window_fallbacks: List = [],
+        content_policy_fallbacks: List = [],
         model_group_alias: Optional[dict] = {},
         enable_pre_call_checks: bool = False,
         retry_after: int = 0,  # min time to wait before retrying a failed request
@@ -311,6 +312,12 @@ class Router:
         self.context_window_fallbacks = (
             context_window_fallbacks or litellm.context_window_fallbacks
         )
+        _content_policy_fallbacks = (
+            content_policy_fallbacks or litellm.content_policy_fallbacks
+        )
+        self.validate_fallbacks(fallback_param=_content_policy_fallbacks)
+        self.content_policy_fallbacks = _content_policy_fallbacks
         self.total_calls: defaultdict = defaultdict(
             int
         )  # dict to store total calls made to each model
@@ -1998,6 +2005,9 @@ class Router:
        context_window_fallbacks = kwargs.get(
            "context_window_fallbacks", self.context_window_fallbacks
        )
+       content_policy_fallbacks = kwargs.get(
+           "content_policy_fallbacks", self.content_policy_fallbacks
+       )
        try:
            if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
                raise Exception(
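Since the per-call kwargs are consulted before the router-level value, the fallback list can also be overridden on a single request (a sketch that reuses the router from the example near the top; values illustrative):

# Per-request override: this kwarg takes precedence over
# Router(content_policy_fallbacks=...) for this call only.
response = router.completion(
    model="claude-2",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}],
)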
@@ -2016,7 +2026,10 @@
            if (
                hasattr(e, "status_code")
                and e.status_code == 400  # type: ignore
-               and not isinstance(e, litellm.ContextWindowExceededError)
+               and not (
+                   isinstance(e, litellm.ContextWindowExceededError)
+                   or isinstance(e, litellm.ContentPolicyViolationError)
+               )
            ):  # don't retry a malformed request
                raise e
            if (
@@ -2034,6 +2047,39 @@
                if fallback_model_group is None:
                    raise original_exception
+               for mg in fallback_model_group:
+                   """
+                   Iterate through the model groups and try calling that deployment
+                   """
+                   try:
+                       kwargs["model"] = mg
+                       kwargs.setdefault("metadata", {}).update(
+                           {"model_group": mg}
+                       )  # update model_group used, if fallbacks are done
+                       response = await self.async_function_with_retries(
+                           *args, **kwargs
+                       )
+                       verbose_router_logger.info(
+                           "Successful fallback b/w models."
+                       )
+                       return response
+                   except Exception as e:
+                       pass
+           elif (
+               isinstance(e, litellm.ContentPolicyViolationError)
+               and content_policy_fallbacks is not None
+           ):
+               fallback_model_group = None
+               for (
+                   item
+               ) in content_policy_fallbacks:  # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                   if list(item.keys())[0] == model_group:
+                       fallback_model_group = item[model_group]
+                       break
+               if fallback_model_group is None:
+                   raise original_exception
                for mg in fallback_model_group:
                    """
                    Iterate through the model groups and try calling that deployment
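The lookup above is worth spelling out: each fallback config is a list of single-key dicts, keyed by the model group that failed. A standalone sketch of the same pattern (model names illustrative):

# Fallback configs are lists of single-key dicts keyed by model group.
content_policy_fallbacks = [
    {"gpt-3.5-turbo": ["gpt-4"]},
    {"claude-2": ["my-fallback-model"]},
]

model_group = "claude-2"
fallback_model_group = None
for item in content_policy_fallbacks:
    if list(item.keys())[0] == model_group:
        fallback_model_group = item[model_group]
        break

assert fallback_model_group == ["my-fallback-model"]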
@@ -2114,6 +2160,9 @@
        context_window_fallbacks = kwargs.pop(
            "context_window_fallbacks", self.context_window_fallbacks
        )
+       content_policy_fallbacks = kwargs.pop(
+           "content_policy_fallbacks", self.content_policy_fallbacks
+       )

        num_retries = kwargs.pop("num_retries")
@@ -2141,6 +2190,7 @@
                        healthy_deployments=_healthy_deployments,
                        context_window_fallbacks=context_window_fallbacks,
                        regular_fallbacks=fallbacks,
+                       content_policy_fallbacks=content_policy_fallbacks,
                    )

                    # decides how long to sleep before retry
@@ -2206,10 +2256,12 @@
        error: Exception,
        healthy_deployments: Optional[List] = None,
        context_window_fallbacks: Optional[List] = None,
+       content_policy_fallbacks: Optional[List] = None,
        regular_fallbacks: Optional[List] = None,
    ):
        """
        1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
+       2. raise an exception for ContentPolicyViolationError if content_policy_fallbacks is not None
        2. raise an exception for RateLimitError if
        - there are no fallbacks
@@ -2219,13 +2271,19 @@
        if healthy_deployments is not None and isinstance(healthy_deployments, list):
            _num_healthy_deployments = len(healthy_deployments)

-       ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
+       ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR / CONTENT POLICY VIOLATION ERROR w/ fallbacks available / Bad Request Error
        if (
            isinstance(error, litellm.ContextWindowExceededError)
            and context_window_fallbacks is not None
        ):
            raise error
+       if (
+           isinstance(error, litellm.ContentPolicyViolationError)
+           and content_policy_fallbacks is not None
+       ):
+           raise error

        # Error we should only retry if there are other deployments
        if isinstance(error, openai.RateLimitError):
            if (
@@ -2256,6 +2314,9 @@
        context_window_fallbacks = kwargs.get(
            "context_window_fallbacks", self.context_window_fallbacks
        )
+       content_policy_fallbacks = kwargs.get(
+           "content_policy_fallbacks", self.content_policy_fallbacks
+       )
        try:
            if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
                raise Exception(
@@ -2271,7 +2332,10 @@
            if (
                hasattr(e, "status_code")
                and e.status_code == 400  # type: ignore
-               and not isinstance(e, litellm.ContextWindowExceededError)
+               and not (
+                   isinstance(e, litellm.ContextWindowExceededError)
+                   or isinstance(e, litellm.ContentPolicyViolationError)
+               )
            ):  # don't retry a malformed request
                raise e
@@ -2294,6 +2358,37 @@
                if fallback_model_group is None:
                    raise original_exception
+               for mg in fallback_model_group:
+                   """
+                   Iterate through the model groups and try calling that deployment
+                   """
+                   try:
+                       ## LOGGING
+                       kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
+                       kwargs["model"] = mg
+                       kwargs.setdefault("metadata", {}).update(
+                           {"model_group": mg}
+                       )  # update model_group used, if fallbacks are done
+                       response = self.function_with_fallbacks(*args, **kwargs)
+                       return response
+                   except Exception as e:
+                       pass
+           elif (
+               isinstance(e, litellm.ContentPolicyViolationError)
+               and content_policy_fallbacks is not None
+           ):
+               fallback_model_group = None
+               for (
+                   item
+               ) in content_policy_fallbacks:  # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                   if list(item.keys())[0] == model_group:
+                       fallback_model_group = item[model_group]
+                       break
+               if fallback_model_group is None:
+                   raise original_exception
                for mg in fallback_model_group:
                    """
                    Iterate through the model groups and try calling that deployment
@@ -2400,6 +2495,9 @@
        context_window_fallbacks = kwargs.pop(
            "context_window_fallbacks", self.context_window_fallbacks
        )
+       content_policy_fallbacks = kwargs.pop(
+           "content_policy_fallbacks", self.content_policy_fallbacks
+       )

        try:
            # if the function call is successful, no exception will be raised and we'll break out of the loop
@@ -2419,6 +2517,7 @@
                        healthy_deployments=_healthy_deployments,
                        context_window_fallbacks=context_window_fallbacks,
                        regular_fallbacks=fallbacks,
+                       content_policy_fallbacks=content_policy_fallbacks,
                    )

                    # decides how long to sleep before retry


@@ -1109,3 +1109,59 @@ async def test_client_side_fallbacks_list(sync_mode):
         assert isinstance(response, litellm.ModelResponse)
         assert response.model is not None and response.model == "gpt-4o"
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_router_content_policy_fallbacks(sync_mode):
+    os.environ["LITELLM_LOG"] = "DEBUG"
+    router = Router(
+        model_list=[
+            {
+                "model_name": "claude-2",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": Exception("content filtering policy"),
+                },
+            },
+            {
+                "model_name": "my-fallback-model",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": "This works!",
+                },
+            },
+            {
+                "model_name": "my-general-model",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": Exception("Should not have called this."),
+                },
+            },
+            {
+                "model_name": "my-context-window-model",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": Exception("Should not have called this."),
+                },
+            },
+        ],
+        content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}],
+        fallbacks=[{"claude-2": ["my-general-model"]}],
+        context_window_fallbacks=[{"claude-2": ["my-context-window-model"]}],
+    )
+
+    if sync_mode is True:
+        response = router.completion(
+            model="claude-2",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
+    else:
+        response = await router.acompletion(
+            model="claude-2",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
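The test's implicit assertion: only the content-policy fallback returns a real answer, while the general and context-window fallback deployments would raise if they were ever selected. An explicit check along these lines (not part of the diff) would be:

assert response.choices[0].message.content == "This works!"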


@@ -3436,6 +3436,7 @@ def client(original_function):
                    isinstance(e, litellm.exceptions.ContextWindowExceededError)
                    and context_window_fallback_dict
                    and model in context_window_fallback_dict
+                   and not _is_litellm_router_call
                ):
                    if len(args) > 0:
                        args[0] = context_window_fallback_dict[model]
@@ -8637,32 +8638,33 @@ def exception_type(
                    ),
                )
        elif custom_llm_provider == "anthropic":  # one of the anthropics
-           if hasattr(original_exception, "message"):
-               if (
-                   "prompt is too long" in original_exception.message
-                   or "prompt: length" in original_exception.message
-               ):
-                   exception_mapping_worked = True
-                   raise ContextWindowExceededError(
-                       message=original_exception.message,
-                       model=model,
-                       llm_provider="anthropic",
-                       response=original_exception.response,
-                   )
-               if "Invalid API Key" in original_exception.message:
-                   exception_mapping_worked = True
-                   raise AuthenticationError(
-                       message=original_exception.message,
-                       model=model,
-                       llm_provider="anthropic",
-                       response=original_exception.response,
-                   )
+           if "prompt is too long" in error_str or "prompt: length" in error_str:
+               exception_mapping_worked = True
+               raise ContextWindowExceededError(
+                   message=error_str,
+                   model=model,
+                   llm_provider="anthropic",
+               )
+           if "Invalid API Key" in error_str:
+               exception_mapping_worked = True
+               raise AuthenticationError(
+                   message=error_str,
+                   model=model,
+                   llm_provider="anthropic",
+               )
+           if "content filtering policy" in error_str:
+               exception_mapping_worked = True
+               raise ContentPolicyViolationError(
+                   message=error_str,
+                   model=model,
+                   llm_provider="anthropic",
+               )
            if hasattr(original_exception, "status_code"):
                print_verbose(f"status_code: {original_exception.status_code}")
                if original_exception.status_code == 401:
                    exception_mapping_worked = True
                    raise AuthenticationError(
-                       message=f"AnthropicException - {original_exception.message}",
+                       message=f"AnthropicException - {error_str}",
                        llm_provider="anthropic",
                        model=model,
                        response=original_exception.response,
@@ -8673,7 +8675,7 @@
                ):
                    exception_mapping_worked = True
                    raise BadRequestError(
-                       message=f"AnthropicException - {original_exception.message}",
+                       message=f"AnthropicException - {error_str}",
                        model=model,
                        llm_provider="anthropic",
                        response=original_exception.response,
@@ -8681,14 +8683,14 @@
                elif original_exception.status_code == 408:
                    exception_mapping_worked = True
                    raise Timeout(
-                       message=f"AnthropicException - {original_exception.message}",
+                       message=f"AnthropicException - {error_str}",
                        model=model,
                        llm_provider="anthropic",
                    )
                elif original_exception.status_code == 429:
                    exception_mapping_worked = True
                    raise RateLimitError(
-                       message=f"AnthropicException - {original_exception.message}",
+                       message=f"AnthropicException - {error_str}",
                        llm_provider="anthropic",
                        model=model,
                        response=original_exception.response,
@@ -8697,7 +8699,7 @@
                    exception_mapping_worked = True
                    raise APIError(
                        status_code=500,
-                       message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.",
+                       message=f"AnthropicException - {error_str}. Handle with `litellm.APIError`.",
                        llm_provider="anthropic",
                        model=model,
                        request=original_exception.request,
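Taken together with the mock_completion change, the anthropic branch can now be exercised without a real httpx response. A sketch of the mapping on its own, assuming exception_type remains importable from litellm.utils and that its remaining parameters can be left at their defaults (the call shape here is illustrative, not taken from the diff):

import litellm
from litellm.utils import exception_type

# Illustrative only: a bare Exception whose text matches the new
# "content filtering policy" check should be remapped to
# ContentPolicyViolationError for the anthropic provider.
try:
    exception_type(
        model="claude-2",
        original_exception=Exception("content filtering policy"),
        custom_llm_provider="anthropic",
    )
except litellm.ContentPolicyViolationError as e:
    print(type(e).__name__, e.llm_provider)  # ContentPolicyViolationError anthropic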