Merge pull request #3412 from sumanth13131/usage-based-routing-ttl-on-cache

usage-based-routing-ttl-on-cache
Krish Dholakia, 2024-05-21 07:58:41 -07:00, committed by GitHub
commit 2cda5a2bc3
4 changed files with 98 additions and 14 deletions
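
In short: both usage-based routing strategies now accept a routing_strategy_args dict on their constructors, and the RPM/TPM usage counters they keep in the router cache expire after a configurable TTL (default: 60 seconds). A new router-caching test exercises the TTL end to end; a usage sketch follows the first hunk below.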

@@ -374,13 +374,17 @@ class Router:
                 litellm.callbacks.append(self.leastbusy_logger)  # type: ignore
         elif routing_strategy == "usage-based-routing":
             self.lowesttpm_logger = LowestTPMLoggingHandler(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache,
+                model_list=self.model_list,
+                routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
         elif routing_strategy == "usage-based-routing-v2":
             self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache,
+                model_list=self.model_list,
+                routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger_v2)  # type: ignore
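
Both branches now thread routing_strategy_args through to the TPM loggers. For illustration, a minimal caller-side sketch based on the new test added at the bottom of this PR (the model name and env var are placeholders):

import os
from litellm import Router

# "ttl" is forwarded to the strategy's RoutingArgs; the RPM/TPM counter keys
# written by usage-based routing now expire after this many seconds (default: 60).
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",  # placeholder deployment
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        }
    ],
    routing_strategy="usage-based-routing",  # or "usage-based-routing-v2"
    routing_strategy_args={"ttl": 10},  # overrides the 60s default
)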

@@ -1,6 +1,6 @@
 #### What this does ####
 # identifies lowest tpm deployment
+from pydantic import BaseModel
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 from datetime import datetime
@@ -11,16 +11,31 @@ from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose
 
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)
 
 class LowestTPMLoggingHandler(CustomLogger):
     test_flag: bool = False
     logged_success: int = 0
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
 
-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
@@ -57,13 +72,13 @@ class LowestTPMLoggingHandler(CustomLogger):
             request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
 
-            self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ## RPM
             request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + 1
 
-            self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ### TESTING ###
             if self.test_flag:
@@ -108,13 +123,13 @@ class LowestTPMLoggingHandler(CustomLogger):
             request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
 
-            self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ## RPM
             request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + 1
 
-            self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ### TESTING ###
             if self.test_flag:
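
The LiteLLMBase.json shim keeps serialization working on both pydantic v2 (model_dump) and v1 (dict), and RoutingArgs validates whatever dict the Router forwards, so omitted keys fall back to their defaults. A standalone sketch of that defaulting behavior (a trimmed stand-in for the classes above, not the library code itself):

from pydantic import BaseModel

class RoutingArgs(BaseModel):  # simplified copy of the class added above
    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)

assert RoutingArgs(**{}).ttl == 60           # no args passed -> default TTL
assert RoutingArgs(**{"ttl": 10}).ttl == 10  # e.g. routing_strategy_args={"ttl": 10}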

@@ -1,6 +1,6 @@
 #### What this does ####
 # identifies lowest tpm deployment
+from pydantic import BaseModel
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 import datetime as datetime_og
@@ -14,6 +14,20 @@ from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
 from litellm.types.router import RouterErrors
 
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)
 
 class LowestTPMLoggingHandler_v2(CustomLogger):
     """
@@ -33,9 +47,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
 
-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)
 
     def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
         """
@@ -89,7 +104,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                 )
             else:
                 # if local result below limit, check redis ## prevent unnecessary redis checks
-                result = self.router_cache.increment_cache(key=rpm_key, value=1)
+                result = self.router_cache.increment_cache(key=rpm_key, value=1, ttl=self.routing_args.ttl)
                 if result is not None and result > deployment_rpm:
                     raise litellm.RateLimitError(
                         message="Deployment over defined rpm limit={}. current usage={}".format(
@@ -168,7 +183,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             else:
                 # if local result below limit, check redis ## prevent unnecessary redis checks
                 result = await self.router_cache.async_increment_cache(
-                    key=rpm_key, value=1
+                    key=rpm_key, value=1, ttl=self.routing_args.ttl
                 )
                 if result is not None and result > deployment_rpm:
                     raise litellm.RateLimitError(
@@ -229,7 +244,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             # update cache
             ## TPM
-            self.router_cache.increment_cache(key=tpm_key, value=total_tokens)
+            self.router_cache.increment_cache(key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl)
 
             ### TESTING ###
             if self.test_flag:
                 self.logged_success += 1
@@ -273,7 +288,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             ## TPM
             await self.router_cache.async_increment_cache(
-                key=tpm_key, value=total_tokens
+                key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
            )
 
             ### TESTING ###
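
Unlike the v1 handler, which rewrites a whole dict with set_cache, the v2 handler bumps integer counters via increment_cache / async_increment_cache, so the TTL has to ride along on every increment. A rough sketch of those semantics, assuming a Redis-style backend (redis-py shown; the key name is illustrative, not litellm's actual key format):

import redis

r = redis.Redis()

def increment_with_ttl(key: str, value: int, ttl: int) -> int:
    # Bump the counter, then (re)arm expiry: once a minute's key expires,
    # that minute's usage no longer counts against a deployment's rpm/tpm limit.
    new_val = r.incrby(key, value)
    r.expire(key, ttl)
    return new_val

current_rpm = increment_with_ttl("deployment-123:rpm", 1, ttl=60)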

@@ -134,6 +134,56 @@ async def test_acompletion_caching_on_router():
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
 
+
+@pytest.mark.asyncio
+async def test_completion_caching_on_router():
+    # tests completion + caching on router
+    try:
+        litellm.set_verbose = True
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+                "tpm": 1000,
+                "rpm": 1,
+            },
+        ]
+
+        messages = [
+            {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
+        ]
+        start_time = time.time()  # referenced by the Timeout handler below
+        router = Router(
+            model_list=model_list,
+            redis_host=os.environ["REDIS_HOST"],
+            redis_password=os.environ["REDIS_PASSWORD"],
+            redis_port=os.environ["REDIS_PORT"],
+            cache_responses=True,
+            timeout=30,
+            routing_strategy_args={"ttl": 10},
+            routing_strategy="usage-based-routing",
+        )
+        response1 = await router.completion(
+            model="gpt-3.5-turbo", messages=messages, temperature=1
+        )
+        print(f"response1: {response1}")
+
+        await asyncio.sleep(10)
+
+        response2 = await router.completion(
+            model="gpt-3.5-turbo", messages=messages, temperature=1
+        )
+        print(f"response2: {response2}")
+
+        assert len(response1.choices[0].message.content) > 0
+        assert len(response2.choices[0].message.content) > 0
+        router.reset()
+    except litellm.Timeout as e:
+        end_time = time.time()
+        print(f"timeout error occurred: {end_time - start_time}")
+        pass
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")
+
 
 @pytest.mark.asyncio
 async def test_acompletion_caching_with_ttl_on_router():