From c70d4ffafb483323fb45a186bcba43d0bc270eb0 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 25 Jun 2024 16:51:55 -0700
Subject: [PATCH] fix(router.py): set `cooldown_time:` per model

---
 litellm/integrations/custom_logger.py         | 12 ++--
 litellm/litellm_core_utils/litellm_logging.py |  3 +-
 litellm/main.py                               |  6 ++
 litellm/router.py                             |  4 +-
 litellm/tests/test_router_cooldowns.py        | 56 ++++++++++++++++++-
 litellm/utils.py                              |  2 +
 6 files changed, 72 insertions(+), 11 deletions(-)

diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index 5a6282994c..da9826b9b5 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -1,11 +1,13 @@
 #### What this does ####
 # On success, logs events to Promptlayer
-import dotenv, os
-
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.caching import DualCache
-from typing import Literal, Union, Optional
+import os
 import traceback
+from typing import Literal, Optional, Union
+
+import dotenv
+
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
 
 
 class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index aa22b51534..add281e43f 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -19,8 +19,7 @@ from litellm import (
     turn_off_message_logging,
     verbose_logger,
 )
-
-from litellm.caching import InMemoryCache, S3Cache, DualCache
+from litellm.caching import DualCache, InMemoryCache, S3Cache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,
diff --git a/litellm/main.py b/litellm/main.py
index 573b2c19fe..b7aa47ab74 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -650,6 +650,7 @@ def completion(
     headers = kwargs.get("headers", None) or extra_headers
     num_retries = kwargs.get("num_retries", None)  ## deprecated
     max_retries = kwargs.get("max_retries", None)
+    cooldown_time = kwargs.get("cooldown_time", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
     organization = kwargs.get("organization", None)
     ### CUSTOM MODEL COST ###
@@ -763,6 +764,7 @@ def completion(
         "allowed_model_region",
         "model_config",
         "fastest_response",
+        "cooldown_time",
     ]
     default_params = openai_params + litellm_params
 
@@ -947,6 +949,7 @@ def completion(
             input_cost_per_token=input_cost_per_token,
             output_cost_per_second=output_cost_per_second,
             output_cost_per_token=output_cost_per_token,
+            cooldown_time=cooldown_time,
         )
         logging.update_environment_variables(
             model=model,
@@ -3030,6 +3033,7 @@ def embedding(
     client = kwargs.pop("client", None)
     rpm = kwargs.pop("rpm", None)
     tpm = kwargs.pop("tpm", None)
+    cooldown_time = kwargs.get("cooldown_time", None)
     max_parallel_requests = kwargs.pop("max_parallel_requests", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", None)
@@ -3105,6 +3109,7 @@ def embedding(
         "region_name",
         "allowed_model_region",
         "model_config",
+        "cooldown_time",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -3165,6 +3170,7 @@ def embedding(
             "aembedding": aembedding,
             "preset_cache_key": None,
             "stream_response": {},
+            "cooldown_time": cooldown_time,
         },
     )
     if azure == True or custom_llm_provider == "azure":
diff --git a/litellm/router.py b/litellm/router.py
index 840df5b54e..e2f7ce8b21 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -2816,7 +2816,9 @@ class Router:
 
             exception_response = getattr(exception, "response", {})
             exception_headers = getattr(exception_response, "headers", None)
-            _time_to_cooldown = self.cooldown_time
+            _time_to_cooldown = kwargs.get("litellm_params", {}).get(
+                "cooldown_time", self.cooldown_time
+            )
 
             if exception_headers is not None:
 
diff --git a/litellm/tests/test_router_cooldowns.py b/litellm/tests/test_router_cooldowns.py
index 35095bb2cf..3eef6e5423 100644
--- a/litellm/tests/test_router_cooldowns.py
+++ b/litellm/tests/test_router_cooldowns.py
@@ -1,18 +1,26 @@
 #### What this tests ####
 # This tests calling router with fallback models
 
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import openai
+
 import litellm
 from litellm import Router
 from litellm.integrations.custom_logger import CustomLogger
-import openai, httpx
 
 
 @pytest.mark.asyncio
@@ -62,3 +70,45 @@ async def test_cooldown_badrequest_error():
     assert response is not None
 
     print(response)
+
+
+@pytest.mark.asyncio
+async def test_dynamic_cooldowns():
+    """
+    Assert kwargs for completion/embedding have 'cooldown_time' as a litellm_param
+    """
+    # litellm.set_verbose = True
+    tmp_mock = MagicMock()
+
+    litellm.failure_callback = [tmp_mock]
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "my-fake-model",
+                "litellm_params": {
+                    "model": "openai/gpt-1",
+                    "api_key": "my-key",
+                    "mock_response": Exception("this is an error"),
+                },
+            }
+        ],
+        cooldown_time=60,
+    )
+
+    try:
+        _ = router.completion(
+            model="my-fake-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            cooldown_time=0,
+            num_retries=0,
+        )
+    except Exception:
+        pass
+
+    tmp_mock.assert_called_once()
+
+    print(tmp_mock.call_count)
+
+    assert "cooldown_time" in tmp_mock.call_args[0][0]["litellm_params"]
+    assert tmp_mock.call_args[0][0]["litellm_params"]["cooldown_time"] == 0
diff --git a/litellm/utils.py b/litellm/utils.py
index 4465c5b0a4..beae7ba4ab 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2017,6 +2017,7 @@ def get_litellm_params(
     input_cost_per_token=None,
     output_cost_per_token=None,
     output_cost_per_second=None,
+    cooldown_time=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2039,6 +2040,7 @@ def get_litellm_params(
         "input_cost_per_second": input_cost_per_second,
         "output_cost_per_token": output_cost_per_token,
         "output_cost_per_second": output_cost_per_second,
+        "cooldown_time": cooldown_time,
     }
     return litellm_params
 
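
A minimal usage sketch of what this patch enables (not part of the diff; the
model name and API key are placeholders lifted from the new test): a
per-request `cooldown_time` kwarg is forwarded into `litellm_params`, which
the router now checks before falling back to the router-wide
`self.cooldown_time`.

# Sketch only, assumes the patched litellm.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-fake-model",  # placeholder deployment name
            "litellm_params": {
                "model": "openai/gpt-1",  # placeholder model
                "api_key": "my-key",  # placeholder key
            },
        }
    ],
    cooldown_time=60,  # router-wide default: cool a failing deployment for 60s
)

# On failure, this call's deployment cools down for 0s rather than 60s:
# the kwarg lands in litellm_params, which the router reads first.
response = router.completion(
    model="my-fake-model",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    cooldown_time=0,
)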