fix(router.py): set cooldown_time: per model

commit c70d4ffafb
parent 91bbef4bcd
Author: Krrish Dholakia
Date: 2024-06-25 16:51:55 -07:00
6 changed files with 72 additions and 11 deletions
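
For context: before this change the Router applied a single cooldown_time to every deployment. The diff below threads a per-call/per-model value through completion(), embedding(), and the litellm_params payload so the router's failure handler can prefer it over the router-wide setting. A minimal usage sketch, assuming deployment-level litellm_params flow into the per-call kwargs (the model names are the placeholders used in the new test):

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "my-fake-model",
                "litellm_params": {
                    "model": "openai/gpt-1",
                    "api_key": "my-key",
                    # assumed: per-deployment override, in seconds
                    "cooldown_time": 10,
                },
            }
        ],
        cooldown_time=60,  # router-wide default, used when no override is set
    )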

View file

@@ -1,11 +1,13 @@
 #### What this does ####
 # On success, logs events to Promptlayer
-import dotenv, os
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.caching import DualCache
-from typing import Literal, Union, Optional
+import os
 import traceback
+from typing import Literal, Optional, Union
+
+import dotenv
+
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth


 class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class

View file

@@ -19,8 +19,7 @@ from litellm import (
     turn_off_message_logging,
     verbose_logger,
 )
-
-from litellm.caching import InMemoryCache, S3Cache, DualCache
+from litellm.caching import DualCache, InMemoryCache, S3Cache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,

View file

@@ -650,6 +650,7 @@ def completion(
     headers = kwargs.get("headers", None) or extra_headers
     num_retries = kwargs.get("num_retries", None)  ## deprecated
     max_retries = kwargs.get("max_retries", None)
+    cooldown_time = kwargs.get("cooldown_time", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
     organization = kwargs.get("organization", None)
     ### CUSTOM MODEL COST ###
@@ -763,6 +764,7 @@ def completion(
         "allowed_model_region",
         "model_config",
         "fastest_response",
+        "cooldown_time",
     ]
     default_params = openai_params + litellm_params
@@ -947,6 +949,7 @@ def completion(
         input_cost_per_token=input_cost_per_token,
         output_cost_per_second=output_cost_per_second,
         output_cost_per_token=output_cost_per_token,
+        cooldown_time=cooldown_time,
     )
     logging.update_environment_variables(
         model=model,
@@ -3030,6 +3033,7 @@ def embedding(
     client = kwargs.pop("client", None)
     rpm = kwargs.pop("rpm", None)
     tpm = kwargs.pop("tpm", None)
+    cooldown_time = kwargs.get("cooldown_time", None)
     max_parallel_requests = kwargs.pop("max_parallel_requests", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", None)
@@ -3105,6 +3109,7 @@ def embedding(
         "region_name",
         "allowed_model_region",
         "model_config",
+        "cooldown_time",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -3165,6 +3170,7 @@ def embedding(
             "aembedding": aembedding,
             "preset_cache_key": None,
             "stream_response": {},
+            "cooldown_time": cooldown_time,
         },
     )
     if azure == True or custom_llm_provider == "azure":
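
Both public entry points now accept the kwarg and forward it under litellm_params, per the hunks above. A hedged per-request sketch; the model name is a placeholder, and the value only takes effect when the call runs through a Router that reads it back on failure:

    import litellm

    # cooldown_time rides along in kwargs -> get_litellm_params() -> litellm_params,
    # where the router's failure handler (next file) can pick it up.
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        cooldown_time=30,
    )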

View file

@@ -2816,7 +2816,9 @@ class Router:
             exception_response = getattr(exception, "response", {})
             exception_headers = getattr(exception_response, "headers", None)
-            _time_to_cooldown = self.cooldown_time
+            _time_to_cooldown = kwargs.get("litellm_params", {}).get(
+                "cooldown_time", self.cooldown_time
+            )

             if exception_headers is not None:
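
A standalone restatement of the precedence the hunk above implements (the helper name is hypothetical). One subtlety: dict.get falls back only when the key is absent, so an explicit cooldown_time of 0 beats the router default, which is exactly what the new test below pins down:

    def resolve_cooldown_time(kwargs: dict, router_cooldown_time: float) -> float:
        # per-call/per-deployment value wins; the Router-level setting is the fallback
        return kwargs.get("litellm_params", {}).get("cooldown_time", router_cooldown_time)

    assert resolve_cooldown_time({"litellm_params": {"cooldown_time": 0}}, 60) == 0
    assert resolve_cooldown_time({}, 60) == 60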

View file

@@ -1,18 +1,26 @@
 #### What this tests ####
 # This tests calling router with fallback models
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
 import pytest
+
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import openai
+
 import litellm
 from litellm import Router
 from litellm.integrations.custom_logger import CustomLogger
-import openai, httpx


 @pytest.mark.asyncio
@@ -62,3 +70,45 @@ async def test_cooldown_badrequest_error():
     assert response is not None
     print(response)
+
+
+@pytest.mark.asyncio
+async def test_dynamic_cooldowns():
+    """
+    Assert kwargs for completion/embedding have 'cooldown_time' as a litellm_param
+    """
+    # litellm.set_verbose = True
+    tmp_mock = MagicMock()
+
+    litellm.failure_callback = [tmp_mock]
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "my-fake-model",
+                "litellm_params": {
+                    "model": "openai/gpt-1",
+                    "api_key": "my-key",
+                    "mock_response": Exception("this is an error"),
+                },
+            }
+        ],
+        cooldown_time=60,
+    )
+
+    try:
+        _ = router.completion(
+            model="my-fake-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            cooldown_time=0,
+            num_retries=0,
+        )
+    except Exception:
+        pass
+
+    tmp_mock.assert_called_once()
+
+    print(tmp_mock.call_count)
+
+    assert "cooldown_time" in tmp_mock.call_args[0][0]["litellm_params"]
+    assert tmp_mock.call_args[0][0]["litellm_params"]["cooldown_time"] == 0

View file

@@ -2017,6 +2017,7 @@ def get_litellm_params(
     input_cost_per_token=None,
     output_cost_per_token=None,
     output_cost_per_second=None,
+    cooldown_time=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2039,6 +2040,7 @@ def get_litellm_params(
         "input_cost_per_second": input_cost_per_second,
         "output_cost_per_token": output_cost_per_token,
         "output_cost_per_second": output_cost_per_second,
+        "cooldown_time": cooldown_time,
     }
     return litellm_params
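
Net effect of this last change: the resolved cooldown_time is always present in the litellm_params payload, so failure callbacks can read it, which is what the MagicMock in the test inspects. A sketch, assuming litellm's custom-callback convention of passing the call's kwargs dict as the first positional argument:

    import litellm

    def log_cooldown(kwargs, completion_response, start_time, end_time):
        # "cooldown_time" is now carried under litellm_params (None if unset)
        print("cooldown_time:", kwargs["litellm_params"].get("cooldown_time"))

    litellm.failure_callback = [log_cooldown]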