forked from phoenix/litellm-mirror
fix(router.py): set cooldown_time: per model
parent e813e984f7
commit d98e00d1e0
6 changed files with 72 additions and 11 deletions
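
The diff below threads a cooldown_time value from the completion()/embedding() kwargs into litellm_params, so the Router can prefer a per-model / per-call cooldown over its router-wide default when cooling down a deployment. A minimal usage sketch, modeled on the new test_dynamic_cooldowns test at the bottom of this diff (the model name, api_key, and mock_response values are placeholders, not real credentials):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-fake-model",
            "litellm_params": {
                "model": "openai/gpt-1",
                "api_key": "my-key",  # placeholder key
                "mock_response": "Hello there!",  # keeps the sketch self-contained, no real API call
            },
        }
    ],
    cooldown_time=60,  # router-wide default cooldown, in seconds
)

# Per-call override: forwarded as litellm_params["cooldown_time"] and used
# instead of the 60s default if this deployment has to be cooled down.
response = router.completion(
    model="my-fake-model",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    cooldown_time=0,
    num_retries=0,
)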
@@ -1,11 +1,13 @@
 #### What this does ####
 # On success, logs events to Promptlayer
-import dotenv, os
-
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.caching import DualCache
-from typing import Literal, Union, Optional
+import os
 import traceback
+from typing import Literal, Optional, Union
+
+import dotenv
+
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
 
 
 class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
@@ -19,8 +19,7 @@ from litellm import (
     turn_off_message_logging,
     verbose_logger,
 )
-
-from litellm.caching import InMemoryCache, S3Cache, DualCache
+from litellm.caching import DualCache, InMemoryCache, S3Cache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,
@@ -650,6 +650,7 @@ def completion(
     headers = kwargs.get("headers", None) or extra_headers
     num_retries = kwargs.get("num_retries", None) ## deprecated
     max_retries = kwargs.get("max_retries", None)
+    cooldown_time = kwargs.get("cooldown_time", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
     organization = kwargs.get("organization", None)
     ### CUSTOM MODEL COST ###
@@ -763,6 +764,7 @@ def completion(
         "allowed_model_region",
         "model_config",
         "fastest_response",
+        "cooldown_time",
     ]
 
     default_params = openai_params + litellm_params
@@ -947,6 +949,7 @@ def completion(
             input_cost_per_token=input_cost_per_token,
             output_cost_per_second=output_cost_per_second,
             output_cost_per_token=output_cost_per_token,
+            cooldown_time=cooldown_time,
         )
         logging.update_environment_variables(
             model=model,
@@ -3030,6 +3033,7 @@ def embedding(
     client = kwargs.pop("client", None)
     rpm = kwargs.pop("rpm", None)
     tpm = kwargs.pop("tpm", None)
+    cooldown_time = kwargs.get("cooldown_time", None)
     max_parallel_requests = kwargs.pop("max_parallel_requests", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", None)
@@ -3105,6 +3109,7 @@ def embedding(
         "region_name",
         "allowed_model_region",
         "model_config",
+        "cooldown_time",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -3165,6 +3170,7 @@ def embedding(
             "aembedding": aembedding,
             "preset_cache_key": None,
             "stream_response": {},
+            "cooldown_time": cooldown_time,
         },
     )
     if azure == True or custom_llm_provider == "azure":
@@ -2816,7 +2816,9 @@ class Router:

         exception_response = getattr(exception, "response", {})
         exception_headers = getattr(exception_response, "headers", None)
-        _time_to_cooldown = self.cooldown_time
+        _time_to_cooldown = kwargs.get("litellm_params", {}).get(
+            "cooldown_time", self.cooldown_time
+        )

         if exception_headers is not None:
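
The router-side change above reduces to a simple precedence rule: a cooldown_time carried in the call's litellm_params wins, otherwise the router-wide default applies. A standalone sketch of that lookup (resolve_cooldown is an illustrative helper, not part of the codebase; only the dict access mirrors the hunk):

def resolve_cooldown(kwargs: dict, router_default: float) -> float:
    # Per-call / per-model value takes precedence over the router-wide default.
    return kwargs.get("litellm_params", {}).get("cooldown_time", router_default)


assert resolve_cooldown({"litellm_params": {"cooldown_time": 0}}, 60) == 0
assert resolve_cooldown({}, 60) == 60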
@@ -1,18 +1,26 @@
 #### What this tests ####
 # This tests calling router with fallback models
 
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 ) # Adds the parent directory to the system path
 
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import openai
+
 import litellm
 from litellm import Router
 from litellm.integrations.custom_logger import CustomLogger
-import openai, httpx
 
 
 @pytest.mark.asyncio
@@ -62,3 +70,45 @@ async def test_cooldown_badrequest_error():
     assert response is not None
 
     print(response)
+
+
+@pytest.mark.asyncio
+async def test_dynamic_cooldowns():
+    """
+    Assert kwargs for completion/embedding have 'cooldown_time' as a litellm_param
+    """
+    # litellm.set_verbose = True
+    tmp_mock = MagicMock()
+
+    litellm.failure_callback = [tmp_mock]
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "my-fake-model",
+                "litellm_params": {
+                    "model": "openai/gpt-1",
+                    "api_key": "my-key",
+                    "mock_response": Exception("this is an error"),
+                },
+            }
+        ],
+        cooldown_time=60,
+    )
+
+    try:
+        _ = router.completion(
+            model="my-fake-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            cooldown_time=0,
+            num_retries=0,
+        )
+    except Exception:
+        pass
+
+    tmp_mock.assert_called_once()
+
+    print(tmp_mock.call_count)
+
+    assert "cooldown_time" in tmp_mock.call_args[0][0]["litellm_params"]
+    assert tmp_mock.call_args[0][0]["litellm_params"]["cooldown_time"] == 0
@@ -2017,6 +2017,7 @@ def get_litellm_params(
     input_cost_per_token=None,
     output_cost_per_token=None,
     output_cost_per_second=None,
+    cooldown_time=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2039,6 +2040,7 @@ def get_litellm_params(
         "input_cost_per_second": input_cost_per_second,
         "output_cost_per_token": output_cost_per_token,
         "output_cost_per_second": output_cost_per_second,
+        "cooldown_time": cooldown_time,
     }
 
     return litellm_params