@staticmethod
def add_key_level_controls(
    key_metadata: dict, data: dict, _metadata_variable_name: str
) -> dict:
    """
    Apply the controls stored in an API key's metadata to a request payload.

    Four key-level settings are honoured:

    - "cache": `data["cache"]` is replaced by the entries whose names are in
      `SupportedCacheControls` (an empty dict when the value is not a dict).
    - "tags": appended after the request's own tag list when there is one,
      otherwise used as the request's tags.
    - "spend_logs_metadata": merged into the request's value; entries that were
      sent with the request are never overridden.
    - "disable_fallbacks": copied to `data["disable_fallbacks"]` when it is a bool.

    Args:
        key_metadata: Metadata dict attached to the API key.
        data: Request payload. Neither it nor its metadata entry is modified.
        _metadata_variable_name: Name of the entry in `data` that holds the
            request metadata dict. A missing (or None) entry is treated as empty.

    Returns:
        A new payload dict with the key-level controls applied.
    """
    data = data.copy()

    ## KEY-LEVEL CACHE CONTROLS
    if "cache" in key_metadata:
        cache_controls = key_metadata["cache"]
        data["cache"] = (
            {k: v for k, v in cache_controls.items() if k in SupportedCacheControls}
            if isinstance(cache_controls, dict)
            else {}
        )

    ## KEY-LEVEL SPEND LOGS / TAGS
    key_tags = key_metadata.get("tags")
    key_spend_logs = key_metadata.get("spend_logs_metadata")
    apply_spend_logs = isinstance(key_spend_logs, dict)
    if key_tags is not None or apply_spend_logs:
        # `dict.copy()` above is shallow: work on a private copy of the request
        # metadata so the caller's nested dict is left untouched, and start from
        # an empty dict when the payload carries no metadata at all.
        existing_metadata = data.get(_metadata_variable_name)
        request_metadata = (
            {} if existing_metadata is None else dict(existing_metadata)
        )

        if key_tags is not None:
            request_tags = request_metadata.get("tags")
            if isinstance(request_tags, list):
                request_metadata["tags"] = [*request_tags, *key_tags]
            else:
                # store a copy - the returned payload must not share the list
                # object that lives inside the key's metadata
                request_metadata["tags"] = (
                    list(key_tags) if isinstance(key_tags, list) else key_tags
                )

        if apply_spend_logs:
            request_spend_logs = request_metadata.get("spend_logs_metadata")
            if isinstance(request_spend_logs, dict):
                merged_spend_logs = dict(request_spend_logs)
                for key, value in key_spend_logs.items():
                    # don't override k-v pair sent by request (user request)
                    merged_spend_logs.setdefault(key, value)
            else:
                merged_spend_logs = dict(key_spend_logs)
            request_metadata["spend_logs_metadata"] = merged_spend_logs

        data[_metadata_variable_name] = request_metadata

    ## KEY-LEVEL DISABLE FALLBACKS
    disable_fallbacks = key_metadata.get("disable_fallbacks")
    if isinstance(disable_fallbacks, bool):
        data["disable_fallbacks"] = disable_fallbacks
    return data
data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[ - "spend_logs_metadata" - ] - + data = LiteLLMProxyRequestSetup.add_key_level_controls( + key_metadata=key_metadata, + data=data, + _metadata_variable_name=_metadata_variable_name, + ) ## TEAM-LEVEL SPEND LOGS/TAGS team_metadata = user_api_key_dict.team_metadata or {} if "tags" in team_metadata and team_metadata["tags"] is not None: diff --git a/litellm/router.py b/litellm/router.py index 400347ff2..97065bc85 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -2610,6 +2610,7 @@ class Router: If it fails after num_retries, fall back to another model group """ model_group: Optional[str] = kwargs.get("model") + disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False) fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks) context_window_fallbacks: Optional[List] = kwargs.get( "context_window_fallbacks", self.context_window_fallbacks @@ -2637,7 +2638,7 @@ class Router: original_model_group: Optional[str] = kwargs.get("model") # type: ignore fallback_failure_exception_str = "" - if original_model_group is None: + if disable_fallbacks is True or original_model_group is None: raise e input_kwargs = { diff --git a/pyproject.toml b/pyproject.toml index 17d37c0ce..aed832f24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.52.6" +version = "1.52.7" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" @@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.52.6" +version = "1.52.7" version_files = [ "pyproject.toml:^version" ] diff --git a/tests/local_testing/test_router_fallbacks.py b/tests/local_testing/test_router_fallbacks.py index cad640a54..1a745e716 100644 --- a/tests/local_testing/test_router_fallbacks.py +++ b/tests/local_testing/test_router_fallbacks.py @@ -1455,3 +1455,46 @@ async 
@pytest.mark.asyncio
async def test_router_disable_fallbacks_dynamically():
    """
    Passing `disable_fallbacks=True` on a call must stop the router from
    falling back, even though both `fallbacks` and `default_fallbacks` are set.
    """
    router = Router(
        model_list=[
            {
                "model_name": "bad-model",
                "litellm_params": {
                    "model": "openai/my-bad-model",
                    "api_key": "my-bad-api-key",
                },
            },
            {
                "model_name": "good-model",
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            },
        ],
        fallbacks=[{"bad-model": ["good-model"]}],
        default_fallbacks=["good-model"],
    )

    with patch.object(
        router,
        "log_retry",
        new=MagicMock(return_value=None),
    ) as mock_log_retry:
        try:
            resp = await router.acompletion(
                model="bad-model",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                disable_fallbacks=True,
            )
            print(resp)
        except Exception as e:
            # The outcome of the call itself is not asserted; only the absence
            # of a `log_retry` call is checked below.
            print(e)

        # NOTE(review): presumably `log_retry` is invoked on the fallback path -
        # confirm against the router implementation.
        mock_log_retry.assert_not_called()


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "disable_fallbacks_set",
    [
        True,
        False,
    ],
)
async def test_disable_fallbacks_by_key(disable_fallbacks_set):
    """
    A boolean `disable_fallbacks` stored in the key metadata is copied onto the
    request payload returned by `add_key_level_controls`.
    """
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    key_metadata = {"disable_fallbacks": disable_fallbacks_set}
    existing_data = {
        "model": "azure/chatgpt-v-2",
        "messages": [{"role": "user", "content": "write 1 sentence poem"}],
    }
    data = LiteLLMProxyRequestSetup.add_key_level_controls(
        key_metadata=key_metadata,
        data=existing_data,
        _metadata_variable_name="metadata",
    )

    assert data["disable_fallbacks"] == disable_fallbacks_set
    # the helper works on a copy, so the payload passed in stays untouched
    assert "disable_fallbacks" not in existing_data