addressed comments

sumanth 2024-05-14 10:05:19 +05:30
parent 0db58c2fac
commit 4bbd9c866c
4 changed files with 50 additions and 15 deletions

View file

@@ -340,13 +340,17 @@ class Router:
                 litellm.callbacks.append(self.leastbusy_logger)  # type: ignore
         elif routing_strategy == "usage-based-routing":
             self.lowesttpm_logger = LowestTPMLoggingHandler(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache,
+                model_list=self.model_list,
+                routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
         elif routing_strategy == "usage-based-routing-v2":
             self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache,
+                model_list=self.model_list,
+                routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger_v2)  # type: ignore
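With this change the TPM/RPM key TTL is no longer hard-coded to 60 seconds: whatever dict the caller passes to the Router as routing_strategy_args is forwarded to the usage-based routing handlers. A minimal usage sketch, assuming a single hypothetical OpenAI deployment (the model list and env var are placeholders):

    import os
    from litellm import Router

    # Hypothetical single-deployment model list, for illustration only.
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.environ["OPENAI_API_KEY"],
            },
        }
    ]

    router = Router(
        model_list=model_list,
        routing_strategy="usage-based-routing",
        # Forwarded to LowestTPMLoggingHandler as routing_args:
        # usage counters now expire after 10s instead of a fixed 60s.
        routing_strategy_args={"ttl": 10},
    )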

View file

@@ -1,6 +1,6 @@
 #### What this does ####
 # identifies lowest tpm deployment
+from pydantic import BaseModel
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 from datetime import datetime
@@ -11,16 +11,31 @@ from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose
 
 
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)
+
+
 class LowestTPMLoggingHandler(CustomLogger):
     test_flag: bool = False
     logged_success: int = 0
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
 
-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
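The added RoutingArgs model is what makes the TTL configurable with a safe default, and LiteLLMBase.json() papers over the pydantic v1/v2 split (model_dump() is the v2 API, dict() the v1 spelling). A quick sketch of how the handler consumes it, assuming the two classes added above:

    # A ttl supplied by the caller overrides the 60s default ...
    args = RoutingArgs(**{"ttl": 10})
    print(args.ttl)     # 10
    print(args.json())  # {'ttl': 10} -- model_dump() on pydantic v2, dict() on v1

    # ... and an empty dict (the new __init__ default) keeps the 1-minute default.
    print(RoutingArgs(**{}).ttl)  # 60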
@@ -57,13 +72,13 @@ class LowestTPMLoggingHandler(CustomLogger):
             request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
 
-            self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl= 60)
+            self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ## RPM
             request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + 1
 
-            self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl= 60)
+            self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ### TESTING ###
             if self.test_flag:
@@ -108,13 +123,13 @@ class LowestTPMLoggingHandler(CustomLogger):
             request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
 
-            self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl= 60)
+            self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ## RPM
             request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + 1
 
-            self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl= 60)
+            self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
             ### TESTING ###
             if self.test_flag:
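Worth noting: the v1 handler tracks usage with a get-then-set over a per-minute dict, so every successful call rewrites the whole key and refreshes its TTL; the TTL mainly ensures stale minute buckets are evicted rather than accumulating. A toy illustration of that pattern, with a dict-backed stand-in for DualCache (all names here are hypothetical):

    import time


    class ToyCache:
        """Dict-backed stand-in for DualCache with per-key expiry."""

        def __init__(self):
            self._store = {}  # key -> (value, expires_at)

        def get_cache(self, key):
            value, expires_at = self._store.get(key, (None, 0.0))
            return value if time.time() < expires_at else None

        def set_cache(self, key, value, ttl):
            self._store[key] = (value, time.time() + ttl)


    cache = ToyCache()
    tpm_key = "my-group:tpm:10-05"  # hypothetical per-minute bucket key
    for total_tokens in (120, 80):
        counts = cache.get_cache(key=tpm_key) or {}
        counts["deployment-1"] = counts.get("deployment-1", 0) + total_tokens
        cache.set_cache(key=tpm_key, value=counts, ttl=10)  # TTL refreshed on every write
    print(cache.get_cache(key=tpm_key))  # {'deployment-1': 200}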

View file

@@ -1,6 +1,6 @@
 #### What this does ####
 # identifies lowest tpm deployment
+from pydantic import BaseModel
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 import datetime as datetime_og
@@ -14,6 +14,20 @@ from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
 from litellm.types.router import RouterErrors
 
 
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)
+
+
 class LowestTPMLoggingHandler_v2(CustomLogger):
     """
@@ -33,9 +47,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
 
-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)
 
     def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
         """
@@ -89,7 +104,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                 )
             else:
                 # if local result below limit, check redis ## prevent unnecessary redis checks
-                result = self.router_cache.increment_cache(key=rpm_key, value=1, ttl = 60)
+                result = self.router_cache.increment_cache(key=rpm_key, value=1, ttl=self.routing_args.ttl)
                 if result is not None and result > deployment_rpm:
                     raise litellm.RateLimitError(
                         message="Deployment over defined rpm limit={}. current usage={}".format(
@@ -168,7 +183,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             else:
                 # if local result below limit, check redis ## prevent unnecessary redis checks
                 result = await self.router_cache.async_increment_cache(
-                    key=rpm_key, value=1, ttl = 60
+                    key=rpm_key, value=1, ttl=self.routing_args.ttl
                 )
                 if result is not None and result > deployment_rpm:
                     raise litellm.RateLimitError(
@@ -229,7 +244,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             # update cache
             ## TPM
-            self.router_cache.increment_cache(key=tpm_key, value=total_tokens, ttl = 60)
+            self.router_cache.increment_cache(key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl)
 
             ### TESTING ###
             if self.test_flag:
                 self.logged_success += 1
@@ -273,7 +288,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             ## TPM
             await self.router_cache.async_increment_cache(
-                key=tpm_key, value=total_tokens, ttl = 60
+                key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
             )
 
             ### TESTING ###
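In contrast to v1, the v2 handler uses increment_cache / async_increment_cache, a single atomic bump per key instead of a read-modify-write, and enforces limits before the call: increment the RPM counter, then fail fast if the post-increment total exceeds the deployment's limit. A minimal sketch of that check, assuming the same cache interface shown in the diff above (a plain exception stands in for litellm.RateLimitError):

    def check_rpm(cache, rpm_key: str, deployment_rpm: int, ttl: int) -> None:
        # Atomically bump the per-minute counter; per the diff above,
        # increment_cache returns the post-increment total for this window.
        result = cache.increment_cache(key=rpm_key, value=1, ttl=ttl)
        if result is not None and result > deployment_rpm:
            # The real handler raises litellm.RateLimitError here.
            raise RuntimeError(
                f"Deployment over defined rpm limit={deployment_rpm}. current usage={result}"
            )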

View file

@@ -161,13 +161,14 @@ async def test_completion_caching_on_router():
         redis_port=os.environ["REDIS_PORT"],
         cache_responses=True,
         timeout=30,
+        routing_strategy_args={"ttl": 10},
         routing_strategy="usage-based-routing",
     )
     response1 = await router.completion(
         model="gpt-3.5-turbo", messages=messages, temperature=1
     )
     print(f"response1: {response1}")
-    await asyncio.sleep(60)
+    await asyncio.sleep(10)
     response2 = await router.completion(
         model="gpt-3.5-turbo", messages=messages, temperature=1
     )
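The test change mirrors the new knob: with routing_strategy_args={"ttl": 10} the usage keys expire after 10 seconds, so the sleep between the two completions drops from 60s to 10s while still verifying that the counters reset once the TTL elapses.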