diff --git a/litellm/router.py b/litellm/router.py
index d678e5912f..3d86bccfd6 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -374,13 +374,17 @@ class Router:
                 litellm.callbacks.append(self.leastbusy_logger)  # type: ignore
         elif routing_strategy == "usage-based-routing":
             self.lowesttpm_logger = LowestTPMLoggingHandler(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache,
+                model_list=self.model_list,
+                routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
         elif routing_strategy == "usage-based-routing-v2":
             self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache,
+                model_list=self.model_list,
+                routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger_v2)  # type: ignore
diff --git a/litellm/router_strategy/lowest_tpm_rpm.py b/litellm/router_strategy/lowest_tpm_rpm.py
index 625db70482..15460051b8 100644
--- a/litellm/router_strategy/lowest_tpm_rpm.py
+++ b/litellm/router_strategy/lowest_tpm_rpm.py
@@ -1,6 +1,6 @@
 #### What this does ####
 # identifies lowest tpm deployment
-
+from pydantic import BaseModel
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 from datetime import datetime
@@ -11,16 +11,31 @@ from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose
 
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)
+
 class LowestTPMLoggingHandler(CustomLogger):
     test_flag: bool = False
     logged_success: int = 0
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
 
-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
@@ -57,13 +72,13 @@ class LowestTPMLoggingHandler(CustomLogger):
             request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
 
-            self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ## RPM
             request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + 1
 
-            self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ### TESTING ###
             if self.test_flag:
@@ -108,13 +123,13 @@ class LowestTPMLoggingHandler(CustomLogger):
             request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
 
-            self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ## RPM
             request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + 1
 
-            self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ### TESTING ###
             if self.test_flag:
diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py
index 23e55f4a3c..40e75031ad 100644
--- a/litellm/router_strategy/lowest_tpm_rpm_v2.py
+++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py
@@ -1,6 +1,6 @@
 #### What this does ####
 # identifies lowest tpm deployment
-
+from pydantic import BaseModel
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 import datetime as datetime_og
@@ -14,6 +14,20 @@ from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
 from litellm.types.router import RouterErrors
 
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)
 
 class LowestTPMLoggingHandler_v2(CustomLogger):
     """
@@ -33,9 +47,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
 
-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)
 
     def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
         """
@@ -89,7 +104,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                 )
             else:  # if local result below limit, check redis ## prevent unnecessary redis checks
-                result = self.router_cache.increment_cache(key=rpm_key, value=1)
+                result = self.router_cache.increment_cache(key=rpm_key, value=1, ttl=self.routing_args.ttl)
                 if result is not None and result > deployment_rpm:
                     raise litellm.RateLimitError(
                         message="Deployment over defined rpm limit={}. current usage={}".format(
                             deployment_rpm, result
current usage={}".format( @@ -168,7 +183,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): else: # if local result below limit, check redis ## prevent unnecessary redis checks result = await self.router_cache.async_increment_cache( - key=rpm_key, value=1 + key=rpm_key, value=1, ttl=self.routing_args.ttl ) if result is not None and result > deployment_rpm: raise litellm.RateLimitError( @@ -229,7 +244,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): # update cache ## TPM - self.router_cache.increment_cache(key=tpm_key, value=total_tokens) + self.router_cache.increment_cache(key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl) ### TESTING ### if self.test_flag: self.logged_success += 1 @@ -273,7 +288,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): ## TPM await self.router_cache.async_increment_cache( - key=tpm_key, value=total_tokens + key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl ) ### TESTING ### diff --git a/litellm/tests/test_router_caching.py b/litellm/tests/test_router_caching.py index ebace161c9..a7ea322b52 100644 --- a/litellm/tests/test_router_caching.py +++ b/litellm/tests/test_router_caching.py @@ -134,6 +134,56 @@ async def test_acompletion_caching_on_router(): traceback.print_exc() pytest.fail(f"Error occurred: {e}") +@pytest.mark.asyncio +async def test_completion_caching_on_router(): + # tests completion + caching on router + try: + litellm.set_verbose = True + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 1000, + "rpm": 1, + }, + ] + + messages = [ + {"role": "user", "content": f"write a one sentence poem {time.time()}?"} + ] + router = Router( + model_list=model_list, + redis_host=os.environ["REDIS_HOST"], + redis_password=os.environ["REDIS_PASSWORD"], + redis_port=os.environ["REDIS_PORT"], + cache_responses=True, + timeout=30, + routing_strategy_args={"ttl": 10}, + routing_strategy="usage-based-routing", + ) + response1 = await router.completion( + model="gpt-3.5-turbo", messages=messages, temperature=1 + ) + print(f"response1: {response1}") + await asyncio.sleep(10) + response2 = await router.completion( + model="gpt-3.5-turbo", messages=messages, temperature=1 + ) + print(f"response2: {response2}") + assert len(response1.choices[0].message.content) > 0 + assert len(response2.choices[0].message.content) > 0 + + router.reset() + except litellm.Timeout as e: + end_time = time.time() + print(f"timeout error occurred: {end_time - start_time}") + pass + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") @pytest.mark.asyncio async def test_acompletion_caching_with_ttl_on_router():