fix(router.py): check if azure returns 'content_filter' response + fallback available -> fallback

Adds exception mapping for Azure content filter response exceptions
Krrish Dholakia 2024-06-22 19:10:15 -07:00 committed by Ishaan Jaff
parent e899359427
commit f479cd549f
7 changed files with 85 additions and 11 deletions
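For context, a minimal sketch of the intended usage, mirroring the test added in this commit; the content_policy_fallbacks mapping shown here is an assumed configuration and is not itself part of the diff:

# Sketch only: primary deployment mocks a content-filtered response, the
# fallback deployment mocks a normal completion (values taken from the test
# changes below; the fallback mapping is an assumed configuration).
import litellm
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-2",
            "litellm_params": {
                "model": "claude-2",
                "api_key": "",
                "mock_response": litellm.ModelResponse(
                    choices=[litellm.Choices(finish_reason="content_filter")]
                ),
            },
        },
        {
            "model_name": "my-fallback-model",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "",
                "mock_response": "This works!",
            },
        },
    ],
    content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}],  # assumed config
)

response = router.completion(
    model="claude-2",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)

With this change, the mocked "content_filter" finish reason is converted into a ContentPolicyViolationError, so the call is rerouted to my-fallback-model; without a matching fallback group the original filtered response is returned unchanged.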

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@@ -572,6 +572,18 @@ class Router:
            verbose_router_logger.info(
                f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
            )

            ## CHECK CONTENT FILTER ERROR ##
            if isinstance(response, ModelResponse):
                _should_raise = self._should_raise_content_policy_error(
                    model=model, response=response, kwargs=kwargs
                )
                if _should_raise:
                    raise litellm.ContentPolicyViolationError(
                        message="Response output was blocked.",
                        model=model,
                        llm_provider="",
                    )
            return response
        except Exception as e:
            verbose_router_logger.info(
@@ -731,6 +743,18 @@ class Router:
            await self.async_routing_strategy_pre_call_checks(deployment=deployment)
            response = await _response

            ## CHECK CONTENT FILTER ERROR ##
            if isinstance(response, ModelResponse):
                _should_raise = self._should_raise_content_policy_error(
                    model=model, response=response, kwargs=kwargs
                )
                if _should_raise:
                    raise litellm.ContentPolicyViolationError(
                        message="Response output was blocked.",
                        model=model,
                        llm_provider="",
                    )

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
            )
@@ -2867,6 +2891,40 @@ class Router:
            # Catch all - if any exceptions default to cooling down
            return True

    def _should_raise_content_policy_error(
        self, model: str, response: ModelResponse, kwargs: dict
    ) -> bool:
        """
        Determines if a content policy error should be raised.

        Only raised if a fallback is available.

        Else, original response is returned.
        """
        if response.choices[0].finish_reason != "content_filter":
            return False

        content_policy_fallbacks = kwargs.get(
            "content_policy_fallbacks", self.content_policy_fallbacks
        )
        ### ONLY RAISE ERROR IF CP FALLBACK AVAILABLE ###
        if content_policy_fallbacks is not None:
            fallback_model_group = None
            for item in content_policy_fallbacks:  # [{"gpt-3.5-turbo": ["gpt-4"]}]
                if list(item.keys())[0] == model:
                    fallback_model_group = item[model]
                    break

            if fallback_model_group is not None:
                return True

        verbose_router_logger.info(
            "Content Policy Error occurred. No available fallbacks. Returning original response. model={}, content_policy_fallbacks={}".format(
                model, content_policy_fallbacks
            )
        )
        return False

    def _set_cooldown_deployments(
        self,
        original_exception: Any,

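The fallback lookup above walks a list of single-key dicts. A standalone sketch of that lookup follows, with the helper name find_content_policy_fallback_group invented purely for illustration:

# Illustration of the lookup in _should_raise_content_policy_error:
# content_policy_fallbacks is a list of single-key dicts mapping a model group
# to its fallback groups, e.g. [{"gpt-3.5-turbo": ["gpt-4"]}].
from typing import List, Optional


def find_content_policy_fallback_group(
    model: str, content_policy_fallbacks: Optional[list]
) -> Optional[List[str]]:
    if content_policy_fallbacks is None:
        return None
    for item in content_policy_fallbacks:
        if list(item.keys())[0] == model:
            return item[model]
    return None


# A content policy error (and therefore a fallback) is only triggered when a
# group is found; otherwise the filtered response is returned to the caller.
assert find_content_policy_fallback_group(
    "gpt-3.5-turbo", [{"gpt-3.5-turbo": ["gpt-4"]}]
) == ["gpt-4"]
assert find_content_policy_fallback_group(
    "claude-2", [{"gpt-3.5-turbo": ["gpt-4"]}]
) is None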
View file

@@ -1,8 +1,12 @@
#### What this tests ####
# This tests calling router with fallback models
import sys, os, time
import traceback, asyncio
import asyncio
import os
import sys
import time
import traceback
import pytest
sys.path.insert(
@@ -762,9 +766,11 @@ def test_ausage_based_routing_fallbacks():
    # The Request should fail azure/gpt-4-fast. Then fallback -> "azure/gpt-4-basic" -> "openai-gpt-4"
    # It should work with "openai-gpt-4"
    import os
    from dotenv import load_dotenv
    import litellm
    from litellm import Router
    from dotenv import load_dotenv

    load_dotenv()
@@ -1112,9 +1118,19 @@ async def test_client_side_fallbacks_list(sync_mode):
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize("content_filter_response_exception", [True, False])
@pytest.mark.asyncio
async def test_router_content_policy_fallbacks(sync_mode):
async def test_router_content_policy_fallbacks(
    sync_mode, content_filter_response_exception
):
    os.environ["LITELLM_LOG"] = "DEBUG"

    if content_filter_response_exception:
        mock_response = Exception("content filtering policy")
    else:
        mock_response = litellm.ModelResponse(
            choices=[litellm.Choices(finish_reason="content_filter")]
        )
    router = Router(
        model_list=[
            {
@@ -1122,13 +1138,13 @@ async def test_router_content_policy_fallbacks(sync_mode):
                "litellm_params": {
                    "model": "claude-2",
                    "api_key": "",
                    "mock_response": Exception("content filtering policy"),
                    "mock_response": mock_response,
                },
            },
            {
                "model_name": "my-fallback-model",
                "litellm_params": {
                    "model": "claude-2",
                    "model": "openai/my-fake-model",
                    "api_key": "",
                    "mock_response": "This works!",
                },
@@ -1165,3 +1181,5 @@ async def test_router_content_policy_fallbacks(sync_mode):
            model="claude-2",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )

    assert response.model == "my-fake-model"

View file

@@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict, Field
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
from .utils import ModelResponse
class ModelConfig(BaseModel):
@@ -315,7 +316,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
    input_cost_per_second: Optional[float]
    output_cost_per_second: Optional[float]
    ## MOCK RESPONSES ##
    mock_response: Optional[str]
    mock_response: Optional[Union[str, ModelResponse, Exception]]

class DeploymentTypedDict(TypedDict):
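
Since mock_response is widened to Optional[Union[str, ModelResponse, Exception]], a deployment entry can now mock a structured response or a raised exception instead of only a string; a brief sketch reusing the values from the test above:

# Sketch: two deployment entries showing the widened mock_response type
# (values mirror the test change in this commit).
import litellm

mock_filtered_deployment = {
    "model_name": "claude-2",
    "litellm_params": {
        "model": "claude-2",
        "api_key": "",
        # ModelResponse mock: simulates a content-filtered completion
        "mock_response": litellm.ModelResponse(
            choices=[litellm.Choices(finish_reason="content_filter")]
        ),
    },
}

mock_error_deployment = {
    "model_name": "claude-2",
    "litellm_params": {
        "model": "claude-2",
        "api_key": "",
        # Exception mock: simulates the provider raising a content filter error
        "mock_response": Exception("content filtering policy"),
    },
}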