fix(router.py): set cooldown_time per model

Krrish Dholakia 2024-06-25 16:51:55 -07:00
parent e813e984f7
commit d98e00d1e0
6 changed files with 72 additions and 11 deletions
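
The net effect: `cooldown_time` can now be set per deployment (via a model's `litellm_params`) or per request, overriding the Router-wide default. A minimal usage sketch follows; the deployment name and key mirror the new test further down, while the per-model `"cooldown_time"` override is an assumption based on the commit title, not something shown verbatim in this diff:

```python
from litellm import Router

# Sketch only: model name and api_key are placeholders taken from the new test;
# the per-model "cooldown_time" override is an assumed usage of this change.
router = Router(
    model_list=[
        {
            "model_name": "my-fake-model",
            "litellm_params": {
                "model": "openai/gpt-1",
                "api_key": "my-key",
                "cooldown_time": 5,  # assumed per-deployment override, in seconds
            },
        }
    ],
    cooldown_time=60,  # router-wide default cooldown, in seconds
)

# The value can also be passed per request, exactly as the new test does:
response = router.completion(
    model="my-fake-model",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    cooldown_time=0,
)
```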

View file

@@ -1,11 +1,13 @@
 #### What this does ####
 # On success, logs events to Promptlayer
-import dotenv, os
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.caching import DualCache
-from typing import Literal, Union, Optional
+import os
+import traceback
+from typing import Literal, Optional, Union
+import dotenv
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
 class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class

View file

@@ -19,8 +19,7 @@ from litellm import (
     turn_off_message_logging,
     verbose_logger,
 )
-from litellm.caching import InMemoryCache, S3Cache, DualCache
+from litellm.caching import DualCache, InMemoryCache, S3Cache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,

View file

@@ -650,6 +650,7 @@ def completion(
     headers = kwargs.get("headers", None) or extra_headers
     num_retries = kwargs.get("num_retries", None)  ## deprecated
     max_retries = kwargs.get("max_retries", None)
+    cooldown_time = kwargs.get("cooldown_time", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
     organization = kwargs.get("organization", None)
     ### CUSTOM MODEL COST ###
@@ -763,6 +764,7 @@ def completion(
         "allowed_model_region",
         "model_config",
         "fastest_response",
+        "cooldown_time",
     ]
     default_params = openai_params + litellm_params
@@ -947,6 +949,7 @@ def completion(
         input_cost_per_token=input_cost_per_token,
         output_cost_per_second=output_cost_per_second,
         output_cost_per_token=output_cost_per_token,
+        cooldown_time=cooldown_time,
     )
     logging.update_environment_variables(
         model=model,
@@ -3030,6 +3033,7 @@ def embedding(
     client = kwargs.pop("client", None)
     rpm = kwargs.pop("rpm", None)
     tpm = kwargs.pop("tpm", None)
+    cooldown_time = kwargs.get("cooldown_time", None)
     max_parallel_requests = kwargs.pop("max_parallel_requests", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", None)
@@ -3105,6 +3109,7 @@ def embedding(
         "region_name",
         "allowed_model_region",
         "model_config",
+        "cooldown_time",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -3165,6 +3170,7 @@ def embedding(
                 "aembedding": aembedding,
                 "preset_cache_key": None,
                 "stream_response": {},
+                "cooldown_time": cooldown_time,
             },
         )
         if azure == True or custom_llm_provider == "azure":
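
The two entry points above (`completion()` and `embedding()`) now accept `cooldown_time` as a plain kwarg and thread it into `litellm_params`. A hedged sketch of a direct call; the model name and message are placeholders, and on its own `litellm.completion()` does not cool anything down, it only carries the value where the Router and callbacks can read it:

```python
import litellm

# Placeholders: model and message are illustrative; a real call needs a valid key.
# `cooldown_time` is read from kwargs above and stored in litellm_params, where the
# Router's failure handling (and any callbacks) can pick it up later.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    cooldown_time=30,
)
```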

View file

@@ -2816,7 +2816,9 @@ class Router:
             exception_response = getattr(exception, "response", {})
             exception_headers = getattr(exception_response, "headers", None)
-            _time_to_cooldown = self.cooldown_time
+            _time_to_cooldown = kwargs.get("litellm_params", {}).get(
+                "cooldown_time", self.cooldown_time
+            )
             if exception_headers is not None:
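
This is the core router change: the cooldown applied after a failure now prefers the `cooldown_time` carried in the call's `litellm_params`, and falls back to the Router-level `self.cooldown_time` only when no such key is present. A standalone sketch of that precedence; `resolve_cooldown_time` is a hypothetical helper name, not part of the codebase:

```python
from typing import Optional


def resolve_cooldown_time(kwargs: dict, router_default: Optional[float]) -> Optional[float]:
    """Mirror the lookup above: a cooldown_time inside the call's litellm_params wins;
    the Router-level default applies only when the key is absent."""
    return kwargs.get("litellm_params", {}).get("cooldown_time", router_default)


# A per-call value (even 0) beats the router default of 60:
assert resolve_cooldown_time({"litellm_params": {"cooldown_time": 0}}, 60) == 0
# Without litellm_params, the router default applies:
assert resolve_cooldown_time({}, 60) == 60
```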

View file

@@ -1,18 +1,26 @@
 #### What this tests ####
 # This tests calling router with fallback models
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
 import pytest
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch
+import httpx
+import openai
 import litellm
 from litellm import Router
 from litellm.integrations.custom_logger import CustomLogger
-import openai, httpx
 @pytest.mark.asyncio
@@ -62,3 +70,45 @@ async def test_cooldown_badrequest_error():
     assert response is not None
     print(response)
+@pytest.mark.asyncio
+async def test_dynamic_cooldowns():
+    """
+    Assert kwargs for completion/embedding have 'cooldown_time' as a litellm_param
+    """
+    # litellm.set_verbose = True
+    tmp_mock = MagicMock()
+    litellm.failure_callback = [tmp_mock]
+    router = Router(
+        model_list=[
+            {
+                "model_name": "my-fake-model",
+                "litellm_params": {
+                    "model": "openai/gpt-1",
+                    "api_key": "my-key",
+                    "mock_response": Exception("this is an error"),
+                },
+            }
+        ],
+        cooldown_time=60,
+    )
+    try:
+        _ = router.completion(
+            model="my-fake-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            cooldown_time=0,
+            num_retries=0,
+        )
+    except Exception:
+        pass
+    tmp_mock.assert_called_once()
+    print(tmp_mock.call_count)
+    assert "cooldown_time" in tmp_mock.call_args[0][0]["litellm_params"]
+    assert tmp_mock.call_args[0][0]["litellm_params"]["cooldown_time"] == 0
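
The test registers a `MagicMock` as a failure callback and asserts that the per-call `cooldown_time` shows up in the kwargs the callback receives. A real callback can read the same field; a minimal sketch, where `log_cooldown_on_failure` is a hypothetical name and the `(kwargs, completion_response, start_time, end_time)` signature is the standard litellm custom-callback form:

```python
import litellm


def log_cooldown_on_failure(kwargs, completion_response, start_time, end_time):
    # The first argument is the call's kwargs dict (what the mock above captures);
    # a per-call cooldown_time sits under kwargs["litellm_params"].
    cooldown = kwargs.get("litellm_params", {}).get("cooldown_time")
    print(f"call failed; requested cooldown_time={cooldown}")


litellm.failure_callback = [log_cooldown_on_failure]
```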

View file

@@ -2017,6 +2017,7 @@ def get_litellm_params(
     input_cost_per_token=None,
     output_cost_per_token=None,
     output_cost_per_second=None,
+    cooldown_time=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2039,6 +2040,7 @@
         "input_cost_per_second": input_cost_per_second,
         "output_cost_per_token": output_cost_per_token,
         "output_cost_per_second": output_cost_per_second,
+        "cooldown_time": cooldown_time,
     }
     return litellm_params