Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
feat(router.py): support content policy fallbacks
Closes https://github.com/BerriAI/litellm/issues/2632
parent 0404d30a9c
commit 6f715b4782
6 changed files with 197 additions and 33 deletions
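
In short: Router gains a content_policy_fallbacks parameter that works like context_window_fallbacks but triggers on litellm.ContentPolicyViolationError. A minimal usage sketch (model names and mock responses mirror the test added at the bottom of this diff):

from litellm import Router

# Sketch: route content-policy failures on "claude-2" to "my-fallback-model".
# The mocked Exception("content filtering policy") is mapped to
# litellm.ContentPolicyViolationError by the anthropic branch of
# exception_type() (see the final hunks below).
router = Router(
    model_list=[
        {
            "model_name": "claude-2",
            "litellm_params": {
                "model": "claude-2",
                "api_key": "",
                "mock_response": Exception("content filtering policy"),
            },
        },
        {
            "model_name": "my-fallback-model",
            "litellm_params": {
                "model": "claude-2",
                "api_key": "",
                "mock_response": "This works!",
            },
        },
    ],
    content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}],
)

response = router.completion(
    model="claude-2",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response.choices[0].message.content)  # expected: "This works!"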
@@ -240,6 +240,7 @@ num_retries: Optional[int] = None  # per model endpoint
 default_fallbacks: Optional[List] = None
 fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
+content_policy_fallbacks: Optional[List] = None
 allowed_fails: int = 0
 num_retries_per_request: Optional[int] = (
     None  # for the request overall (incl. fallbacks + model retries)

@@ -324,7 +324,7 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
         message,
         model,
         llm_provider,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400

@@ -332,11 +332,13 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
+        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        self.response = response or httpx.Response(status_code=400, request=request)
         super().__init__(
             message=self.message,
             model=self.model,  # type: ignore
             llm_provider=self.llm_provider,  # type: ignore
-            response=response,
+            response=self.response,
             litellm_debug_info=self.litellm_debug_info,
         )  # Call the base class constructor with the parameters it needs

@@ -407,7 +409,7 @@ class ContentPolicyViolationError(BadRequestError):  # type: ignore
         message,
         model,
         llm_provider,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400

@@ -415,11 +417,13 @@ class ContentPolicyViolationError(BadRequestError):  # type: ignore
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
+        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        self.response = response or httpx.Response(status_code=500, request=request)
         super().__init__(
             message=self.message,
             model=self.model,  # type: ignore
             llm_provider=self.llm_provider,  # type: ignore
-            response=response,
+            response=self.response,
             litellm_debug_info=self.litellm_debug_info,
         )  # Call the base class constructor with the parameters it needs

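Both error classes now default `response` to None and synthesize a placeholder httpx.Response when none is given (status 400 for the context-window error, 500 for the content-policy error). A small sketch of what this enables for call sites, assuming the patched constructors above:

import litellm

# Sketch: with `response` optional, callers such as the new anthropic mapping
# in exception_type() (last hunks below) can raise without building a Response
try:
    raise litellm.ContentPolicyViolationError(
        message="content filtering policy",
        model="claude-2",
        llm_provider="anthropic",
    )
except litellm.ContentPolicyViolationError as e:
    print(e.response.status_code)  # 500 (the synthesized placeholder)
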
@@ -401,6 +401,7 @@ def mock_completion(
     stream: Optional[bool] = False,
     mock_response: Union[str, Exception] = "This is a mock request",
     logging=None,
+    custom_llm_provider=None,
     **kwargs,
 ):
     """

@@ -438,7 +439,7 @@ def mock_completion(
         raise litellm.APIError(
             status_code=getattr(mock_response, "status_code", 500),  # type: ignore
             message=getattr(mock_response, "text", str(mock_response)),
-            llm_provider=getattr(mock_response, "llm_provider", "openai"),  # type: ignore
+            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
             model=model,  # type: ignore
             request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
         )

@@ -907,6 +908,7 @@ def completion(
                 logging=logging,
                 acompletion=acompletion,
                 mock_delay=kwargs.get("mock_delay", None),
+                custom_llm_provider=custom_llm_provider,
             )
         if custom_llm_provider == "azure":
             # azure configs

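Passing custom_llm_provider through to mock_completion means a mocked exception is raised with the caller's resolved provider instead of a hard-coded "openai", which in turn lets exception_type() apply provider-specific mappings. A hedged sketch of the resulting flow (this is what the new router test relies on):

import litellm

# Sketch: "claude-2" resolves to provider "anthropic", so the mocked error
# string "content filtering policy" is mapped (in the utils.py hunks below)
# to litellm.ContentPolicyViolationError rather than a generic APIError
try:
    litellm.completion(
        model="claude-2",
        messages=[{"role": "user", "content": "hi"}],
        mock_response=Exception("content filtering policy"),
    )
except litellm.ContentPolicyViolationError as e:
    print(e.llm_provider)  # expected: "anthropic"
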
@@ -108,6 +108,7 @@ class Router:
         ] = None,  # generic fallbacks, works across all deployments
         fallbacks: List = [],
         context_window_fallbacks: List = [],
+        content_policy_fallbacks: List = [],
         model_group_alias: Optional[dict] = {},
         enable_pre_call_checks: bool = False,
         retry_after: int = 0,  # min time to wait before retrying a failed request

@@ -311,6 +312,12 @@ class Router:
         self.context_window_fallbacks = (
             context_window_fallbacks or litellm.context_window_fallbacks
         )
+
+        _content_policy_fallbacks = (
+            content_policy_fallbacks or litellm.content_policy_fallbacks
+        )
+        self.validate_fallbacks(fallback_param=_content_policy_fallbacks)
+        self.content_policy_fallbacks = _content_policy_fallbacks
         self.total_calls: defaultdict = defaultdict(
             int
         )  # dict to store total calls made to each model

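validate_fallbacks expects the same shape the other fallback parameters use: a list of single-key dicts mapping a model group to its ordered fallback groups, for example:

# Shape shared by fallbacks / context_window_fallbacks / content_policy_fallbacks:
# [{model_group: [fallback_model_group, ...]}]
content_policy_fallbacks = [
    {"claude-2": ["my-fallback-model"]},
    {"gpt-3.5-turbo": ["gpt-4"]},
]
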
@@ -1998,6 +2005,9 @@ class Router:
         context_window_fallbacks = kwargs.get(
             "context_window_fallbacks", self.context_window_fallbacks
         )
+        content_policy_fallbacks = kwargs.get(
+            "content_policy_fallbacks", self.content_policy_fallbacks
+        )
         try:
             if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
                 raise Exception(

@@ -2016,7 +2026,10 @@ class Router:
                 if (
                     hasattr(e, "status_code")
                     and e.status_code == 400  # type: ignore
-                    and not isinstance(e, litellm.ContextWindowExceededError)
+                    and not (
+                        isinstance(e, litellm.ContextWindowExceededError)
+                        or isinstance(e, litellm.ContentPolicyViolationError)
+                    )
                 ):  # don't retry a malformed request
                     raise e
                 if (

@@ -2034,6 +2047,39 @@ class Router:
                     if fallback_model_group is None:
                         raise original_exception

+                    for mg in fallback_model_group:
+                        """
+                        Iterate through the model groups and try calling that deployment
+                        """
+                        try:
+                            kwargs["model"] = mg
+                            kwargs.setdefault("metadata", {}).update(
+                                {"model_group": mg}
+                            )  # update model_group used, if fallbacks are done
+                            response = await self.async_function_with_retries(
+                                *args, **kwargs
+                            )
+                            verbose_router_logger.info(
+                                "Successful fallback b/w models."
+                            )
+                            return response
+                        except Exception as e:
+                            pass
+                elif (
+                    isinstance(e, litellm.ContentPolicyViolationError)
+                    and content_policy_fallbacks is not None
+                ):
+                    fallback_model_group = None
+                    for (
+                        item
+                    ) in content_policy_fallbacks:  # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                        if list(item.keys())[0] == model_group:
+                            fallback_model_group = item[model_group]
+                            break
+
+                    if fallback_model_group is None:
+                        raise original_exception
+
                     for mg in fallback_model_group:
                         """
                         Iterate through the model groups and try calling that deployment

@@ -2114,6 +2160,9 @@ class Router:
         context_window_fallbacks = kwargs.pop(
             "context_window_fallbacks", self.context_window_fallbacks
         )
+        content_policy_fallbacks = kwargs.pop(
+            "content_policy_fallbacks", self.content_policy_fallbacks
+        )

         num_retries = kwargs.pop("num_retries")

@@ -2141,6 +2190,7 @@ class Router:
                 healthy_deployments=_healthy_deployments,
                 context_window_fallbacks=context_window_fallbacks,
                 regular_fallbacks=fallbacks,
+                content_policy_fallbacks=content_policy_fallbacks,
             )

             # decides how long to sleep before retry

@@ -2206,10 +2256,12 @@ class Router:
         error: Exception,
         healthy_deployments: Optional[List] = None,
         context_window_fallbacks: Optional[List] = None,
+        content_policy_fallbacks: Optional[List] = None,
         regular_fallbacks: Optional[List] = None,
     ):
         """
         1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
+        2. raise an exception for ContentPolicyViolationError if content_policy_fallbacks is not None

         2. raise an exception for RateLimitError if
             - there are no fallbacks

@@ -2219,13 +2271,19 @@ class Router:
         if healthy_deployments is not None and isinstance(healthy_deployments, list):
             _num_healthy_deployments = len(healthy_deployments)

-        ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
+        ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR / CONTENT POLICY VIOLATION ERROR w/ fallbacks available / Bad Request Error
         if (
             isinstance(error, litellm.ContextWindowExceededError)
             and context_window_fallbacks is not None
         ):
             raise error
+
+        if (
+            isinstance(error, litellm.ContentPolicyViolationError)
+            and content_policy_fallbacks is not None
+        ):
+            raise error

         # Error we should only retry if there are other deployments
         if isinstance(error, openai.RateLimitError):
             if (

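Why `raise error` when a fallback map exists: should_retry_this_error runs inside the retry wrapper, so re-raising a ContextWindowExceededError or ContentPolicyViolationError hands the exception straight back to the fallback logic rather than burning retries on a deployment that will keep failing the same way. A condensed sketch of the decision (not the exact source):

import litellm

# Condensed sketch: errors with a matching fallback map skip the retry loop
def is_fallback_eligible(error, context_window_fallbacks, content_policy_fallbacks):
    if isinstance(error, litellm.ContextWindowExceededError):
        return context_window_fallbacks is not None
    if isinstance(error, litellm.ContentPolicyViolationError):
        return content_policy_fallbacks is not None
    return False
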
@@ -2256,6 +2314,9 @@ class Router:
         context_window_fallbacks = kwargs.get(
             "context_window_fallbacks", self.context_window_fallbacks
         )
+        content_policy_fallbacks = kwargs.get(
+            "content_policy_fallbacks", self.content_policy_fallbacks
+        )
         try:
             if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
                 raise Exception(

@@ -2271,7 +2332,10 @@ class Router:
                 if (
                     hasattr(e, "status_code")
                     and e.status_code == 400  # type: ignore
-                    and not isinstance(e, litellm.ContextWindowExceededError)
+                    and not (
+                        isinstance(e, litellm.ContextWindowExceededError)
+                        or isinstance(e, litellm.ContentPolicyViolationError)
+                    )
                 ):  # don't retry a malformed request
                     raise e

@@ -2294,6 +2358,37 @@ class Router:
                     if fallback_model_group is None:
                         raise original_exception

+                    for mg in fallback_model_group:
+                        """
+                        Iterate through the model groups and try calling that deployment
+                        """
+                        try:
+                            ## LOGGING
+                            kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
+                            kwargs["model"] = mg
+                            kwargs.setdefault("metadata", {}).update(
+                                {"model_group": mg}
+                            )  # update model_group used, if fallbacks are done
+                            response = self.function_with_fallbacks(*args, **kwargs)
+                            return response
+                        except Exception as e:
+                            pass
+                elif (
+                    isinstance(e, litellm.ContentPolicyViolationError)
+                    and content_policy_fallbacks is not None
+                ):
+                    fallback_model_group = None
+
+                    for (
+                        item
+                    ) in content_policy_fallbacks:  # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                        if list(item.keys())[0] == model_group:
+                            fallback_model_group = item[model_group]
+                            break
+
+                    if fallback_model_group is None:
+                        raise original_exception
+
                     for mg in fallback_model_group:
                         """
                         Iterate through the model groups and try calling that deployment

@@ -2400,6 +2495,9 @@ class Router:
         context_window_fallbacks = kwargs.pop(
             "context_window_fallbacks", self.context_window_fallbacks
         )
+        content_policy_fallbacks = kwargs.pop(
+            "content_policy_fallbacks", self.content_policy_fallbacks
+        )

         try:
             # if the function call is successful, no exception will be raised and we'll break out of the loop

@@ -2419,6 +2517,7 @@ class Router:
                 healthy_deployments=_healthy_deployments,
                 context_window_fallbacks=context_window_fallbacks,
                 regular_fallbacks=fallbacks,
+                content_policy_fallbacks=content_policy_fallbacks,
             )

             # decides how long to sleep before retry

@@ -1109,3 +1109,59 @@ async def test_client_side_fallbacks_list(sync_mode):

     assert isinstance(response, litellm.ModelResponse)
     assert response.model is not None and response.model == "gpt-4o"
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_router_content_policy_fallbacks(sync_mode):
+    os.environ["LITELLM_LOG"] = "DEBUG"
+    router = Router(
+        model_list=[
+            {
+                "model_name": "claude-2",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": Exception("content filtering policy"),
+                },
+            },
+            {
+                "model_name": "my-fallback-model",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": "This works!",
+                },
+            },
+            {
+                "model_name": "my-general-model",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": Exception("Should not have called this."),
+                },
+            },
+            {
+                "model_name": "my-context-window-model",
+                "litellm_params": {
+                    "model": "claude-2",
+                    "api_key": "",
+                    "mock_response": Exception("Should not have called this."),
+                },
+            },
+        ],
+        content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}],
+        fallbacks=[{"claude-2": ["my-general-model"]}],
+        context_window_fallbacks=[{"claude-2": ["my-context-window-model"]}],
+    )
+
+    if sync_mode is True:
+        response = router.completion(
+            model="claude-2",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
+    else:
+        response = await router.acompletion(
+            model="claude-2",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )

@@ -3436,6 +3436,7 @@ def client(original_function):
                     isinstance(e, litellm.exceptions.ContextWindowExceededError)
                     and context_window_fallback_dict
                     and model in context_window_fallback_dict
+                    and not _is_litellm_router_call
                 ):
                     if len(args) > 0:
                         args[0] = context_window_fallback_dict[model]

@@ -8637,32 +8638,33 @@ def exception_type(
                 ),
             )
         elif custom_llm_provider == "anthropic":  # one of the anthropics
-            if hasattr(original_exception, "message"):
-                if (
-                    "prompt is too long" in original_exception.message
-                    or "prompt: length" in original_exception.message
-                ):
-                    exception_mapping_worked = True
-                    raise ContextWindowExceededError(
-                        message=original_exception.message,
-                        model=model,
-                        llm_provider="anthropic",
-                        response=original_exception.response,
-                    )
-                if "Invalid API Key" in original_exception.message:
-                    exception_mapping_worked = True
-                    raise AuthenticationError(
-                        message=original_exception.message,
-                        model=model,
-                        llm_provider="anthropic",
-                        response=original_exception.response,
-                    )
+            if "prompt is too long" in error_str or "prompt: length" in error_str:
+                exception_mapping_worked = True
+                raise ContextWindowExceededError(
+                    message=error_str,
+                    model=model,
+                    llm_provider="anthropic",
+                )
+            if "Invalid API Key" in error_str:
+                exception_mapping_worked = True
+                raise AuthenticationError(
+                    message=error_str,
+                    model=model,
+                    llm_provider="anthropic",
+                )
+            if "content filtering policy" in error_str:
+                exception_mapping_worked = True
+                raise ContentPolicyViolationError(
+                    message=error_str,
+                    model=model,
+                    llm_provider="anthropic",
+                )
             if hasattr(original_exception, "status_code"):
                 print_verbose(f"status_code: {original_exception.status_code}")
                 if original_exception.status_code == 401:
                     exception_mapping_worked = True
                     raise AuthenticationError(
-                        message=f"AnthropicException - {original_exception.message}",
+                        message=f"AnthropicException - {error_str}",
                         llm_provider="anthropic",
                         model=model,
                         response=original_exception.response,

@@ -8673,7 +8675,7 @@ def exception_type(
                 ):
                     exception_mapping_worked = True
                     raise BadRequestError(
-                        message=f"AnthropicException - {original_exception.message}",
+                        message=f"AnthropicException - {error_str}",
                         model=model,
                         llm_provider="anthropic",
                         response=original_exception.response,

@@ -8681,14 +8683,14 @@ def exception_type(
                 elif original_exception.status_code == 408:
                     exception_mapping_worked = True
                     raise Timeout(
-                        message=f"AnthropicException - {original_exception.message}",
+                        message=f"AnthropicException - {error_str}",
                         model=model,
                         llm_provider="anthropic",
                     )
                 elif original_exception.status_code == 429:
                     exception_mapping_worked = True
                     raise RateLimitError(
-                        message=f"AnthropicException - {original_exception.message}",
+                        message=f"AnthropicException - {error_str}",
                         llm_provider="anthropic",
                         model=model,
                         response=original_exception.response,

@@ -8697,7 +8699,7 @@ def exception_type(
                     exception_mapping_worked = True
                     raise APIError(
                         status_code=500,
-                        message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.",
+                        message=f"AnthropicException - {error_str}. Handle with `litellm.APIError`.",
                         llm_provider="anthropic",
                         model=model,
                         request=original_exception.request,