diff --git a/litellm/router.py b/litellm/router.py
index e524937ae..fe8709294 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -9,7 +9,7 @@
 import copy, httpx
 from datetime import datetime
-from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
+from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
 from typing_extensions import overload
 import random, threading, time, traceback, uuid
 import litellm, openai, hashlib, json
@@ -47,6 +47,7 @@ from litellm.types.router import (
     updateLiteLLMParams,
     RetryPolicy,
     AlertingConfig,
+    DeploymentTypedDict,
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -62,7 +63,7 @@ class Router:
 
     def __init__(
         self,
-        model_list: Optional[list] = None,
+        model_list: Optional[List[DeploymentTypedDict]] = None,
         ## CACHING ##
         redis_url: Optional[str] = None,
         redis_host: Optional[str] = None,
@@ -83,6 +84,9 @@ class Router:
         default_max_parallel_requests: Optional[int] = None,
         set_verbose: bool = False,
         debug_level: Literal["DEBUG", "INFO"] = "INFO",
+        default_fallbacks: Optional[
+            List[str]
+        ] = None,  # generic fallbacks, works across all deployments
         fallbacks: List = [],
         context_window_fallbacks: List = [],
         model_group_alias: Optional[dict] = {},
@@ -259,6 +263,11 @@ class Router:
         self.retry_after = retry_after
         self.routing_strategy = routing_strategy
         self.fallbacks = fallbacks or litellm.fallbacks
+        if default_fallbacks is not None:
+            if self.fallbacks is not None:
+                self.fallbacks.append({"*": default_fallbacks})
+            else:
+                self.fallbacks = [{"*": default_fallbacks}]
         self.context_window_fallbacks = (
             context_window_fallbacks or litellm.context_window_fallbacks
         )
@@ -1471,13 +1480,21 @@ class Router:
                             pass
                 elif fallbacks is not None:
                     verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
-                    for item in fallbacks:
-                        key_list = list(item.keys())
-                        if len(key_list) == 0:
-                            continue
-                        if key_list[0] == model_group:
+                    generic_fallback_idx: Optional[int] = None
+                    ## check for specific model group-specific fallbacks
+                    for idx, item in enumerate(fallbacks):
+                        if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
+                        elif list(item.keys())[0] == "*":
+                            generic_fallback_idx = idx
+                    ## if none, check for generic fallback
+                    if (
+                        fallback_model_group is None
+                        and generic_fallback_idx is not None
+                    ):
+                        fallback_model_group = fallbacks[generic_fallback_idx]["*"]
+
                     if fallback_model_group is None:
                         verbose_router_logger.info(
                             f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
@@ -1537,7 +1554,7 @@ class Router:
 
             """
             _healthy_deployments = await self._async_get_healthy_deployments(
-                model=kwargs.get("model"),
+                model=kwargs.get("model") or "",
             )
 
             # raises an exception if this error should not be retries
@@ -1644,12 +1661,18 @@ class Router:
         Try calling the function_with_retries
         If it fails after num_retries, fall back to another model group
         """
+        mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
         model_group = kwargs.get("model")
         fallbacks = kwargs.get("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.get(
             "context_window_fallbacks", self.context_window_fallbacks
         )
         try:
+            if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
+                raise Exception(
+                    f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
+                )
+
             response = self.function_with_retries(*args, **kwargs)
             return response
         except Exception as e:
@@ -1658,7 +1681,7 @@ class Router:
             try:
                 if (
                     hasattr(e, "status_code")
-                    and e.status_code == 400
+                    and e.status_code == 400  # type: ignore
                     and not isinstance(e, litellm.ContextWindowExceededError)
                 ):  # don't retry a malformed request
                     raise e
@@ -1700,10 +1723,20 @@ class Router:
                 elif fallbacks is not None:
                     verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
                     fallback_model_group = None
-                    for item in fallbacks:
+                    generic_fallback_idx: Optional[int] = None
+                    ## check for specific model group-specific fallbacks
+                    for idx, item in enumerate(fallbacks):
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
+                        elif list(item.keys())[0] == "*":
+                            generic_fallback_idx = idx
+                    ## if none, check for generic fallback
+                    if (
+                        fallback_model_group is None
+                        and generic_fallback_idx is not None
+                    ):
+                        fallback_model_group = fallbacks[generic_fallback_idx]["*"]
 
                     if fallback_model_group is None:
                         raise original_exception
diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
index ce2b014e9..4ab97b274 100644
--- a/litellm/tests/test_router_fallbacks.py
+++ b/litellm/tests/test_router_fallbacks.py
@@ -1007,3 +1007,50 @@ async def test_service_unavailable_fallbacks(sync_mode):
         )
 
     assert response.model == "gpt-35-turbo"
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_default_model_fallbacks(sync_mode):
+    """
+    Related issue - https://github.com/BerriAI/litellm/issues/3623
+
+    If model misconfigured, setup a default model for generic fallback
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "bad-model",
+                "litellm_params": {
+                    "model": "openai/my-bad-model",
+                    "api_key": "my-bad-api-key",
+                },
+            },
+            {
+                "model_name": "my-good-model",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ],
+        default_fallbacks=["my-good-model"],
+    )
+
+    if sync_mode:
+        response = router.completion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+    else:
+        response = await router.acompletion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+
+    assert isinstance(response, litellm.ModelResponse)
+    assert response.model is not None and response.model == "gpt-4o"
diff --git a/litellm/types/router.py b/litellm/types/router.py
index e8f3ff641..68ee387fe 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union, Dict, Tuple, Literal
+from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
 import httpx
 from pydantic import BaseModel, validator, Field
 from .completion import CompletionRequest
@@ -277,6 +277,47 @@ class updateDeployment(BaseModel):
         protected_namespaces = ()
 
 
+class LiteLLMParamsTypedDict(TypedDict, total=False):
+    """
+    [TODO]
+    - allow additional params (not in list)
+    - set value to none if not set -> don't raise error if value not set
+    """
+
+    model: str
+    custom_llm_provider: Optional[str]
+    tpm: Optional[int]
+    rpm: Optional[int]
+    api_key: Optional[str]
+    api_base: Optional[str]
+    api_version: Optional[str]
+    timeout: Optional[Union[float, str, httpx.Timeout]]
+    stream_timeout: Optional[Union[float, str]]
+    max_retries: Optional[int]
+    organization: Optional[str]  # for openai orgs
+    ## UNIFIED PROJECT/REGION ##
+    region_name: Optional[str]
+    ## VERTEX AI ##
+    vertex_project: Optional[str]
+    vertex_location: Optional[str]
+    ## AWS BEDROCK / SAGEMAKER ##
+    aws_access_key_id: Optional[str]
+    aws_secret_access_key: Optional[str]
+    aws_region_name: Optional[str]
+    ## IBM WATSONX ##
+    watsonx_region_name: Optional[str]
+    ## CUSTOM PRICING ##
+    input_cost_per_token: Optional[float]
+    output_cost_per_token: Optional[float]
+    input_cost_per_second: Optional[float]
+    output_cost_per_second: Optional[float]
+
+
+class DeploymentTypedDict(TypedDict):
+    model_name: str
+    litellm_params: LiteLLMParamsTypedDict
+
+
 class Deployment(BaseModel):
     model_name: str
     litellm_params: LiteLLM_Params
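
For reference, a minimal usage sketch of the new default_fallbacks parameter, mirroring the test added above; the deployment names and the bad API key are placeholders, and mock_testing_fallbacks / mock_response (both exercised in the new test) force the fallback path without making real API calls.

import os
import litellm
from litellm import Router

# "bad-model" is deliberately misconfigured; "my-good-model" is the generic fallback.
router = Router(
    model_list=[
        {
            "model_name": "bad-model",
            "litellm_params": {
                "model": "openai/my-bad-model",
                "api_key": "my-bad-api-key",  # placeholder key, guaranteed to fail
            },
        },
        {
            "model_name": "my-good-model",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
    ],
    default_fallbacks=["my-good-model"],  # used when no model-group-specific fallback matches
)

response = router.completion(
    model="bad-model",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    mock_testing_fallbacks=True,  # raise a mock exception to trigger the fallback path
    mock_response="Hey! nice day",
)
assert isinstance(response, litellm.ModelResponse)
print(response.model)  # expected "gpt-4o", served by the generic fallback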