fix(router.py): fallback on 400-status code requests

Krrish Dholakia 2024-08-09 12:15:45 -07:00
parent f8b0118ecd
commit 7b6db63d30
5 changed files with 31 additions and 30 deletions

@@ -1,8 +1,24 @@
-model_list:
-- model_name: "*"
-  litellm_params:
-    model: "*"
+general_settings:
+  store_model_in_db: true
+  database_connection_pool_limit: 20
+
+model_list:
+- model_name: fake-openai-endpoint
+  litellm_params:
+    model: openai/my-fake-model
+    api_key: my-fake-key
+    api_base: https://exampleopenaiendpoint-production.up.railway.app/
 
 litellm_settings:
-  max_internal_user_budget: 0.001
-  internal_user_budget_duration: "5m"
+  drop_params: True
+  success_callback: ["prometheus"]
+  failure_callback: ["prometheus"]
+  service_callback: ["prometheus_system"]
+  _langfuse_default_tags: ["user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"]
+
+router_settings:
+  routing_strategy: "latency-based-routing"
+  routing_strategy_args: {"ttl": 86400} # Average the last 10 calls to compute avg latency per model
+  allowed_fails: 1
+  num_retries: 3
+  retry_after: 5 # seconds to wait before retrying a failed request
+  cooldown_time: 30 # seconds to cooldown a deployment after failure
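
These router_settings correspond to the keyword arguments the proxy forwards to litellm.Router. A minimal sketch of the equivalent Python-side configuration, reusing the fake endpoint from this config (treat the exact kwarg names as an assumption against your litellm version):

import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "fake-openai-endpoint",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "my-fake-key",
                "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
            },
        }
    ],
    routing_strategy="latency-based-routing",
    routing_strategy_args={"ttl": 86400},  # keep latency stats for a day
    allowed_fails=1,    # failures tolerated before a deployment is cooled down
    num_retries=3,
    retry_after=5,      # seconds to wait before retrying a failed request
    cooldown_time=30,   # seconds a failed deployment stays out of rotation
)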

@@ -420,6 +420,7 @@ async def update_team(
 @management_endpoint_wrapper
 async def team_member_add(
     data: TeamMemberAddRequest,
+    http_request: Request,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     """

@@ -2364,18 +2364,6 @@ class Router:
         fallback_failure_exception_str = ""
         try:
             verbose_router_logger.debug("Trying to fallback b/w models")
-            if (
-                hasattr(e, "status_code")
-                and e.status_code == 400  # type: ignore
-                and not (
-                    isinstance(e, litellm.ContextWindowExceededError)
-                    or isinstance(e, litellm.ContentPolicyViolationError)
-                )
-            ):  # don't retry a malformed request
-                verbose_router_logger.debug(
-                    "Not retrying request as it's malformed. Status code=400."
-                )
-                raise e
             if isinstance(e, litellm.ContextWindowExceededError):
                 if context_window_fallbacks is not None:
                     fallback_model_group = None
@@ -2730,16 +2718,6 @@ class Router:
             original_exception = e
             verbose_router_logger.debug(f"An exception occurs {original_exception}")
             try:
-                if (
-                    hasattr(e, "status_code")
-                    and e.status_code == 400  # type: ignore
-                    and not (
-                        isinstance(e, litellm.ContextWindowExceededError)
-                        or isinstance(e, litellm.ContentPolicyViolationError)
-                    )
-                ):  # don't retry a malformed request
-                    raise e
                 verbose_router_logger.debug(
                     f"Trying to fallback b/w models. Initial model group: {model_group}"
                 )
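
Both hunks delete the early raise for generic 400 responses, so a bad-request error from one deployment can now trigger the configured fallbacks instead of surfacing immediately (context-window and content-policy errors keep their dedicated fallback paths, as before). A rough sketch of the behavior this enables; the model names are hypothetical and the failing deployment reuses the fake endpoint from this commit's config:

import asyncio
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "primary",
            "litellm_params": {
                "model": "openai/my-fake-model",  # hypothetical deployment that answers with a 400
                "api_key": "my-fake-key",
                "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
            },
        },
        {
            "model_name": "backup",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        },
    ],
    fallbacks=[{"primary": ["backup"]}],  # consulted on failure, now including 400s
)

async def main():
    # Before this commit the router re-raised a 400 from "primary" without
    # consulting fallbacks; after it, "backup" gets a chance first.
    resp = await router.acompletion(
        model="primary",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(resp)

asyncio.run(main())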

@@ -865,6 +865,8 @@ async def test_create_user_default_budget(prisma_client, user_role):
 async def test_create_team_member_add(prisma_client, new_member_method):
     import time
 
+    from fastapi import Request
+
     from litellm.proxy._types import LiteLLM_TeamTableCachedObj
     from litellm.proxy.proxy_server import hash_token, user_api_key_cache
@@ -906,7 +908,11 @@ async def test_create_team_member_add(prisma_client, new_member_method):
         mock_litellm_usertable.find_many = AsyncMock(return_value=None)
 
         await team_member_add(
-            data=team_member_add_request, user_api_key_dict=UserAPIKeyAuth()
+            data=team_member_add_request,
+            user_api_key_dict=UserAPIKeyAuth(),
+            http_request=Request(
+                scope={"type": "http", "path": "/user/new"},
+            ),
         )
         mock_client.assert_called()

@@ -143,7 +143,7 @@ class GenericLiteLLMParams(BaseModel):
     ## VERTEX AI ##
     vertex_project: Optional[str] = None
     vertex_location: Optional[str] = None
-    vertex_credentials: Optional[str] = None
+    vertex_credentials: Optional[Union[str, dict]] = None
     ## AWS BEDROCK / SAGEMAKER ##
     aws_access_key_id: Optional[str] = None
     aws_secret_access_key: Optional[str] = None
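
Widening vertex_credentials from str to Union[str, dict] suggests callers may now pass already-parsed service-account JSON rather than only a file path or a raw JSON string. A hedged sketch (the file name and project values are made up):

import json

with open("service_account.json") as f:
    creds = json.load(f)  # a dict, now accepted directly

litellm_params = {
    "model": "vertex_ai/gemini-pro",
    "vertex_project": "my-gcp-project",
    "vertex_location": "us-central1",
    "vertex_credentials": creds,  # previously had to be a string
}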