forked from phoenix/litellm-mirror
fix(router.py): set cooldown_time: per model

parent e813e984f7
commit d98e00d1e0
6 changed files with 72 additions and 11 deletions
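This change threads a `cooldown_time` value from `completion()`/`embedding()` kwargs through `get_litellm_params()` and into the router's failure handling, so cooldowns can be set per model or per request instead of only router-wide. A minimal usage sketch, assuming the `Router` surface exercised by the new test below (model name, key, and timings are placeholders taken from that test):

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-fake-model",
            "litellm_params": {
                "model": "openai/gpt-1",  # placeholder deployment
                "api_key": "my-key",  # placeholder key
            },
        }
    ],
    cooldown_time=60,  # router-wide default, in seconds
)

# Per-request override: failures from this call cool the deployment
# down for 0 seconds instead of the router-wide 60.
response = router.completion(
    model="my-fake-model",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    cooldown_time=0,
)
```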
litellm/integrations/custom_logger.py

@@ -1,11 +1,13 @@
 #### What this does ####
 # On success, logs events to Promptlayer
-import dotenv, os
-
-from litellm.proxy._types import UserAPIKeyAuth
-
-from litellm.caching import DualCache
-
-from typing import Literal, Union, Optional
-
+import os
 import traceback
+from typing import Literal, Optional, Union
+
+import dotenv
+
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+
 
 class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
litellm/litellm_core_utils/litellm_logging.py

@@ -19,8 +19,7 @@ from litellm import (
     turn_off_message_logging,
     verbose_logger,
 )
-
-from litellm.caching import InMemoryCache, S3Cache, DualCache
+from litellm.caching import DualCache, InMemoryCache, S3Cache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,
litellm/main.py

@@ -650,6 +650,7 @@ def completion(
     headers = kwargs.get("headers", None) or extra_headers
     num_retries = kwargs.get("num_retries", None)  ## deprecated
     max_retries = kwargs.get("max_retries", None)
+    cooldown_time = kwargs.get("cooldown_time", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
     organization = kwargs.get("organization", None)
     ### CUSTOM MODEL COST ###
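`completion()` picks the new value out of kwargs, next to the existing `num_retries`/`max_retries` handling.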
@@ -763,6 +764,7 @@ def completion(
         "allowed_model_region",
         "model_config",
         "fastest_response",
+        "cooldown_time",
     ]
 
     default_params = openai_params + litellm_params
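Registering `"cooldown_time"` in the `litellm_params` list marks it as an internal litellm setting, so it is filtered out of the non-default params forwarded to the provider API.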
@@ -947,6 +949,7 @@ def completion(
         input_cost_per_token=input_cost_per_token,
         output_cost_per_second=output_cost_per_second,
         output_cost_per_token=output_cost_per_token,
+        cooldown_time=cooldown_time,
     )
     logging.update_environment_variables(
         model=model,
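The value is then handed to `get_litellm_params()` (patched in the last two hunks below), which is how it ends up in the `litellm_params` dict that callbacks and the router's failure handler see.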
@@ -3030,6 +3033,7 @@ def embedding(
     client = kwargs.pop("client", None)
     rpm = kwargs.pop("rpm", None)
     tpm = kwargs.pop("tpm", None)
+    cooldown_time = kwargs.get("cooldown_time", None)
     max_parallel_requests = kwargs.pop("max_parallel_requests", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", None)
@@ -3105,6 +3109,7 @@ def embedding(
         "region_name",
         "allowed_model_region",
         "model_config",
+        "cooldown_time",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -3165,6 +3170,7 @@ def embedding(
             "aembedding": aembedding,
             "preset_cache_key": None,
             "stream_response": {},
+            "cooldown_time": cooldown_time,
         },
     )
     if azure == True or custom_llm_provider == "azure":
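`embedding()` gets the same three-step treatment as `completion()`: read `cooldown_time` from kwargs, register it as an internal param, and carry it along in `litellm_params`.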
litellm/router.py

@@ -2816,7 +2816,9 @@ class Router:
 
         exception_response = getattr(exception, "response", {})
         exception_headers = getattr(exception_response, "headers", None)
-        _time_to_cooldown = self.cooldown_time
+        _time_to_cooldown = kwargs.get("litellm_params", {}).get(
+            "cooldown_time", self.cooldown_time
+        )
 
         if exception_headers is not None:
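This is the heart of the fix: a `cooldown_time` carried in the failing request's `litellm_params` now takes precedence over the router-wide `self.cooldown_time`. A minimal, self-contained sketch of that lookup, using a hypothetical stand-in function rather than the actual `Router` method:

```python
from typing import Optional


def resolve_cooldown(kwargs: dict, router_default: Optional[float]) -> Optional[float]:
    # Mirrors the hunk above: a cooldown_time found in the request's
    # litellm_params wins over the router-wide default.
    return kwargs.get("litellm_params", {}).get("cooldown_time", router_default)


assert resolve_cooldown({"litellm_params": {"cooldown_time": 0}}, 60) == 0
assert resolve_cooldown({}, 60) == 60
```

One subtlety: `dict.get` only falls back when the key is absent, and `get_litellm_params()` always emits a `cooldown_time` key (defaulting to `None`), so when no per-request value is set this lookup yields `None` rather than `self.cooldown_time`; downstream code presumably treats `None` as "use the default".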
litellm/tests/test_router_cooldowns.py

@@ -1,18 +1,26 @@
 #### What this tests ####
 # This tests calling router with fallback models
 
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import openai
+
 import litellm
 from litellm import Router
 from litellm.integrations.custom_logger import CustomLogger
-import openai, httpx
@@ -62,3 +70,45 @@ async def test_cooldown_badrequest_error():
         assert response is not None
 
         print(response)
+
+
+@pytest.mark.asyncio
+async def test_dynamic_cooldowns():
+    """
+    Assert kwargs for completion/embedding have 'cooldown_time' as a litellm_param
+    """
+    # litellm.set_verbose = True
+    tmp_mock = MagicMock()
+
+    litellm.failure_callback = [tmp_mock]
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "my-fake-model",
+                "litellm_params": {
+                    "model": "openai/gpt-1",
+                    "api_key": "my-key",
+                    "mock_response": Exception("this is an error"),
+                },
+            }
+        ],
+        cooldown_time=60,
+    )
+
+    try:
+        _ = router.completion(
+            model="my-fake-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            cooldown_time=0,
+            num_retries=0,
+        )
+    except Exception:
+        pass
+
+    tmp_mock.assert_called_once()
+
+    print(tmp_mock.call_count)
+
+    assert "cooldown_time" in tmp_mock.call_args[0][0]["litellm_params"]
+    assert tmp_mock.call_args[0][0]["litellm_params"]["cooldown_time"] == 0
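The new test drives a deployment that always fails (`mock_response` is an `Exception`), passes `cooldown_time=0` on the call, and asserts that the failure callback received `cooldown_time == 0` inside `litellm_params`, i.e. the per-request value rather than the router-wide 60.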
litellm/utils.py

@@ -2017,6 +2017,7 @@ def get_litellm_params(
     input_cost_per_token=None,
     output_cost_per_token=None,
     output_cost_per_second=None,
+    cooldown_time=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2039,6 +2040,7 @@ def get_litellm_params(
         "input_cost_per_second": input_cost_per_second,
         "output_cost_per_token": output_cost_per_token,
         "output_cost_per_second": output_cost_per_second,
+        "cooldown_time": cooldown_time,
     }
 
     return litellm_params
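`get_litellm_params()` is the shared plumbing both `completion()` and `embedding()` call into; adding `cooldown_time` here is what makes the value visible in the `litellm_params` dict inspected by the router's failure handler and by `test_dynamic_cooldowns` above.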