diff --git a/litellm/__init__.py b/litellm/__init__.py
index 15f562d159..6ecf70d0d7 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -240,6 +240,7 @@ num_retries: Optional[int] = None  # per model endpoint
 default_fallbacks: Optional[List] = None
 fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
+content_policy_fallbacks: Optional[List] = None
 allowed_fails: int = 0
 num_retries_per_request: Optional[int] = (
     None  # for the request overall (incl. fallbacks + model retries)
diff --git a/litellm/exceptions.py b/litellm/exceptions.py
index 8b102d791b..9674d48b12 100644
--- a/litellm/exceptions.py
+++ b/litellm/exceptions.py
@@ -324,7 +324,7 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
         message,
         model,
         llm_provider,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400
@@ -332,11 +332,13 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
+        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        self.response = response or httpx.Response(status_code=400, request=request)
         super().__init__(
             message=self.message,
             model=self.model,  # type: ignore
             llm_provider=self.llm_provider,  # type: ignore
-            response=response,
+            response=self.response,
             litellm_debug_info=self.litellm_debug_info,
         )  # Call the base class constructor with the parameters it needs
@@ -407,7 +409,7 @@ class ContentPolicyViolationError(BadRequestError):  # type: ignore
         message,
         model,
         llm_provider,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400
@@ -415,11 +417,13 @@ class ContentPolicyViolationError(BadRequestError):  # type: ignore
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
+        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        self.response = response or httpx.Response(status_code=500, request=request)
         super().__init__(
             message=self.message,
             model=self.model,  # type: ignore
             llm_provider=self.llm_provider,  # type: ignore
-            response=response,
+            response=self.response,
             litellm_debug_info=self.litellm_debug_info,
         )  # Call the base class constructor with the parameters it needs
diff --git a/litellm/main.py b/litellm/main.py
index 9d369a4ce9..77fe38fd2d 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -401,6 +401,7 @@ def mock_completion(
     stream: Optional[bool] = False,
     mock_response: Union[str, Exception] = "This is a mock request",
     logging=None,
+    custom_llm_provider=None,
     **kwargs,
 ):
     """
@@ -438,7 +439,7 @@
             raise litellm.APIError(
                 status_code=getattr(mock_response, "status_code", 500),  # type: ignore
                 message=getattr(mock_response, "text", str(mock_response)),
-                llm_provider=getattr(mock_response, "llm_provider", "openai"),  # type: ignore
+                llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
                 model=model,  # type: ignore
                 request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
             )
@@ -907,6 +908,7 @@ def completion(
                 logging=logging,
                 acompletion=acompletion,
                 mock_delay=kwargs.get("mock_delay", None),
+                custom_llm_provider=custom_llm_provider,
             )
         if custom_llm_provider == "azure":
             # azure configs
diff --git a/litellm/router.py b/litellm/router.py
index 4d7a36a386..b8844fd336 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -108,6 +108,7 @@ class Router:
         ] = None,  # generic fallbacks, works across all deployments
         fallbacks: List = [],
         context_window_fallbacks: List = [],
+        content_policy_fallbacks: List = [],
         model_group_alias: Optional[dict] = {},
         enable_pre_call_checks: bool = False,
         retry_after: int = 0,  # min time to wait before retrying a failed request
@@ -311,6 +312,12 @@ class Router:
         self.context_window_fallbacks = (
             context_window_fallbacks or litellm.context_window_fallbacks
         )
+
+        _content_policy_fallbacks = (
+            content_policy_fallbacks or litellm.content_policy_fallbacks
+        )
+        self.validate_fallbacks(fallback_param=_content_policy_fallbacks)
+        self.content_policy_fallbacks = _content_policy_fallbacks
         self.total_calls: defaultdict = defaultdict(
             int
         )  # dict to store total calls made to each model
@@ -1998,6 +2005,9 @@ class Router:
         context_window_fallbacks = kwargs.get(
             "context_window_fallbacks", self.context_window_fallbacks
         )
+        content_policy_fallbacks = kwargs.get(
+            "content_policy_fallbacks", self.content_policy_fallbacks
+        )
         try:
             if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
                 raise Exception(
@@ -2016,7 +2026,10 @@ class Router:
                 if (
                     hasattr(e, "status_code")
                     and e.status_code == 400  # type: ignore
-                    and not isinstance(e, litellm.ContextWindowExceededError)
+                    and not (
+                        isinstance(e, litellm.ContextWindowExceededError)
+                        or isinstance(e, litellm.ContentPolicyViolationError)
+                    )
                 ):  # don't retry a malformed request
                     raise e
                 if (
@@ -2034,6 +2047,39 @@ class Router:
                     if fallback_model_group is None:
                         raise original_exception
+                    for mg in fallback_model_group:
+                        """
+                        Iterate through the model groups and try calling that deployment
+                        """
+                        try:
+                            kwargs["model"] = mg
+                            kwargs.setdefault("metadata", {}).update(
+                                {"model_group": mg}
+                            )  # update model_group used, if fallbacks are done
+                            response = await self.async_function_with_retries(
+                                *args, **kwargs
+                            )
+                            verbose_router_logger.info(
+                                "Successful fallback b/w models."
+                            )
+                            return response
+                        except Exception as e:
+                            pass
+                elif (
+                    isinstance(e, litellm.ContentPolicyViolationError)
+                    and content_policy_fallbacks is not None
+                ):
+                    fallback_model_group = None
+                    for (
+                        item
+                    ) in content_policy_fallbacks:  # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                        if list(item.keys())[0] == model_group:
+                            fallback_model_group = item[model_group]
+                            break
+
+                    if fallback_model_group is None:
+                        raise original_exception
+
                     for mg in fallback_model_group:
                         """
                         Iterate through the model groups and try calling that deployment
                         """
@@ -2114,6 +2160,9 @@ class Router:
         context_window_fallbacks = kwargs.pop(
             "context_window_fallbacks", self.context_window_fallbacks
         )
+        content_policy_fallbacks = kwargs.pop(
+            "content_policy_fallbacks", self.content_policy_fallbacks
+        )

         num_retries = kwargs.pop("num_retries")

@@ -2141,6 +2190,7 @@ class Router:
                 healthy_deployments=_healthy_deployments,
                 context_window_fallbacks=context_window_fallbacks,
                 regular_fallbacks=fallbacks,
+                content_policy_fallbacks=content_policy_fallbacks,
             )

             # decides how long to sleep before retry
@@ -2206,10 +2256,12 @@ class Router:
         error: Exception,
         healthy_deployments: Optional[List] = None,
         context_window_fallbacks: Optional[List] = None,
+        content_policy_fallbacks: Optional[List] = None,
         regular_fallbacks: Optional[List] = None,
     ):
         """
         1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
+        2. raise an exception for ContentPolicyViolationError if content_policy_fallbacks is not None
         2. raise an exception for RateLimitError if
            - there are no fallbacks
@@ -2219,13 +2271,19 @@ class Router:
         if healthy_deployments is not None and isinstance(healthy_deployments, list):
             _num_healthy_deployments = len(healthy_deployments)

-        ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
+        ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR / CONTENT POLICY VIOLATION ERROR w/ fallbacks available / Bad Request Error
         if (
             isinstance(error, litellm.ContextWindowExceededError)
             and context_window_fallbacks is not None
         ):
             raise error

+        if (
+            isinstance(error, litellm.ContentPolicyViolationError)
+            and content_policy_fallbacks is not None
+        ):
+            raise error
+
         # Error we should only retry if there are other deployments
         if isinstance(error, openai.RateLimitError):
             if (
@@ -2256,6 +2314,9 @@ class Router:
         context_window_fallbacks = kwargs.get(
             "context_window_fallbacks", self.context_window_fallbacks
         )
+        content_policy_fallbacks = kwargs.get(
+            "content_policy_fallbacks", self.content_policy_fallbacks
+        )
         try:
             if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
                 raise Exception(
@@ -2271,7 +2332,10 @@ class Router:
                 if (
                     hasattr(e, "status_code")
                     and e.status_code == 400  # type: ignore
-                    and not isinstance(e, litellm.ContextWindowExceededError)
+                    and not (
+                        isinstance(e, litellm.ContextWindowExceededError)
+                        or isinstance(e, litellm.ContentPolicyViolationError)
+                    )
                 ):  # don't retry a malformed request
                     raise e

@@ -2294,6 +2358,37 @@ class Router:
                     if fallback_model_group is None:
                         raise original_exception
+                    for mg in fallback_model_group:
+                        """
+                        Iterate through the model groups and try calling that deployment
+                        """
+                        try:
+                            ## LOGGING
+                            kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
+                            kwargs["model"] = mg
+                            kwargs.setdefault("metadata", {}).update(
+                                {"model_group": mg}
+                            )  # update model_group used, if fallbacks are done
+                            response = self.function_with_fallbacks(*args, **kwargs)
+                            return response
+                        except Exception as e:
+                            pass
+                elif (
+                    isinstance(e, litellm.ContentPolicyViolationError)
+                    and content_policy_fallbacks is not None
+                ):
+                    fallback_model_group = None
+
+                    for (
+                        item
+                    ) in content_policy_fallbacks:  # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                        if list(item.keys())[0] == model_group:
+                            fallback_model_group = item[model_group]
+                            break
+
+                    if fallback_model_group is None:
+                        raise original_exception
+
                     for mg in fallback_model_group:
                         """
                         Iterate through the model groups and try calling that deployment
                         """
@@ -2400,6 +2495,9 @@ class Router:
         context_window_fallbacks = kwargs.pop(
             "context_window_fallbacks", self.context_window_fallbacks
         )
+        content_policy_fallbacks = kwargs.pop(
+            "content_policy_fallbacks", self.content_policy_fallbacks
+        )

         try:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
@@ -2419,6 +2517,7 @@ class Router:
                 healthy_deployments=_healthy_deployments,
                 context_window_fallbacks=context_window_fallbacks,
                 regular_fallbacks=fallbacks,
+                content_policy_fallbacks=content_policy_fallbacks,
             )

             # decides how long to sleep before retry
diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
index c6e0e54111..545eb23db3 100644
--- a/litellm/tests/test_router_fallbacks.py
+++ b/litellm/tests/test_router_fallbacks.py
@@ -1109,3 +1109,59 @@ async def test_client_side_fallbacks_list(sync_mode):

         assert isinstance(response, litellm.ModelResponse)
         assert response.model is not None and response.model == "gpt-4o"
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_router_content_policy_fallbacks(sync_mode):
+    os.environ["LITELLM_LOG"] = "DEBUG"
+    router = Router(
+        model_list=[
+            {
+                "model_name": "claude-2",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": Exception("content filtering policy"),
+                },
+            },
+            {
+                "model_name": "my-fallback-model",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": "This works!",
+                },
+            },
+            {
+                "model_name": "my-general-model",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": Exception("Should not have called this."),
+                },
+            },
+            {
+                "model_name": "my-context-window-model",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": Exception("Should not have called this."),
+                },
+            },
+        ],
+        content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}],
+        fallbacks=[{"claude-2": ["my-general-model"]}],
+        context_window_fallbacks=[{"claude-2": ["my-context-window-model"]}],
+    )
+
+    if sync_mode is True:
+        response = router.completion(
+            model="claude-2",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
+    else:
+        response = await router.acompletion(
+            model="claude-2",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
diff --git a/litellm/utils.py b/litellm/utils.py
index 830c45610b..7f37bcf7c5 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3436,6 +3436,7 @@ def client(original_function):
                     isinstance(e, litellm.exceptions.ContextWindowExceededError)
                     and context_window_fallback_dict
                     and model in context_window_fallback_dict
+                    and not _is_litellm_router_call
                 ):
                     if len(args) > 0:
                         args[0] = context_window_fallback_dict[model]
@@ -8637,32 +8638,33 @@ def exception_type(
                     ),
                 )
             elif custom_llm_provider == "anthropic":  # one of the anthropics
-                if hasattr(original_exception, "message"):
-                    if (
-                        "prompt is too long" in original_exception.message
-                        or "prompt: length" in original_exception.message
-                    ):
-                        exception_mapping_worked = True
-                        raise ContextWindowExceededError(
-                            message=original_exception.message,
-                            model=model,
-                            llm_provider="anthropic",
-                            response=original_exception.response,
-                        )
-                    if "Invalid API Key" in original_exception.message:
-                        exception_mapping_worked = True
-                        raise AuthenticationError(
-                            message=original_exception.message,
-                            model=model,
-                            llm_provider="anthropic",
-                            response=original_exception.response,
-                        )
+                if "prompt is too long" in error_str or "prompt: length" in error_str:
+                    exception_mapping_worked = True
+                    raise ContextWindowExceededError(
+                        message=error_str,
+                        model=model,
+                        llm_provider="anthropic",
+                    )
+                if "Invalid API Key" in error_str:
+                    exception_mapping_worked = True
+                    raise AuthenticationError(
+                        message=error_str,
+                        model=model,
+                        llm_provider="anthropic",
+                    )
+                if "content filtering policy" in error_str:
+                    exception_mapping_worked = True
+                    raise ContentPolicyViolationError(
+                        message=error_str,
+                        model=model,
+                        llm_provider="anthropic",
+                    )
                 if hasattr(original_exception, "status_code"):
                     print_verbose(f"status_code: {original_exception.status_code}")
                     if original_exception.status_code == 401:
                         exception_mapping_worked = True
                         raise AuthenticationError(
-                            message=f"AnthropicException - {original_exception.message}",
+                            message=f"AnthropicException - {error_str}",
                             llm_provider="anthropic",
                             model=model,
                             response=original_exception.response,
@@ -8673,7 +8675,7 @@
                     ):
                         exception_mapping_worked = True
                         raise BadRequestError(
-                            message=f"AnthropicException - {original_exception.message}",
+                            message=f"AnthropicException - {error_str}",
                             model=model,
                             llm_provider="anthropic",
                             response=original_exception.response,
@@ -8681,14 +8683,14 @@
                     elif original_exception.status_code == 408:
                         exception_mapping_worked = True
                         raise Timeout(
-                            message=f"AnthropicException - {original_exception.message}",
+                            message=f"AnthropicException - {error_str}",
                             model=model,
                             llm_provider="anthropic",
                         )
                     elif original_exception.status_code == 429:
                         exception_mapping_worked = True
                         raise RateLimitError(
-                            message=f"AnthropicException - {original_exception.message}",
+                            message=f"AnthropicException - {error_str}",
                             llm_provider="anthropic",
                             model=model,
                             response=original_exception.response,
@@ -8697,7 +8699,7 @@
                         exception_mapping_worked = True
                         raise APIError(
                             status_code=500,
-                            message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.",
+                            message=f"AnthropicException - {error_str}. Handle with `litellm.APIError`.",
                             llm_provider="anthropic",
                             model=model,
                             request=original_exception.request,