addressed comments

sumanth 2024-05-14 10:05:19 +05:30
parent 0db58c2fac
commit 4bbd9c866c
4 changed files with 50 additions and 15 deletions

View file

@@ -340,13 +340,17 @@ class Router:
                 litellm.callbacks.append(self.leastbusy_logger) # type: ignore
         elif routing_strategy == "usage-based-routing":
             self.lowesttpm_logger = LowestTPMLoggingHandler(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache,
+                model_list=self.model_list,
+                routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
         elif routing_strategy == "usage-based-routing-v2":
             self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache,
+                model_list=self.model_list,
+                routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
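
With this change, `routing_strategy_args` flows from the Router constructor into whichever usage-based handler is selected, where it is unpacked into `RoutingArgs`. A minimal usage sketch (the single-deployment model list is a placeholder, not part of this commit):

from litellm import Router

router = Router(
    # placeholder deployment; any model entry from your own config works here
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    routing_strategy="usage-based-routing",
    # forwarded to LowestTPMLoggingHandler as RoutingArgs(**routing_strategy_args);
    # ttl sets how long the per-window TPM/RPM counter keys live in the cache
    routing_strategy_args={"ttl": 120},
)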

View file

@@ -1,6 +1,6 @@
 #### What this does ####
 # identifies lowest tpm deployment
+from pydantic import BaseModel
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 from datetime import datetime
@@ -11,6 +11,20 @@ from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
 
 
 class LowestTPMLoggingHandler(CustomLogger):
     test_flag: bool = False
@@ -18,9 +32,10 @@ class LowestTPMLoggingHandler(CustomLogger):
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
 
-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
@@ -57,13 +72,13 @@ class LowestTPMLoggingHandler(CustomLogger):
                 request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
                 request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
 
-                self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl= 60)
+                self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
                 ## RPM
                 request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
                 request_count_dict[id] = request_count_dict.get(id, 0) + 1
 
-                self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl= 60)
+                self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
                 ### TESTING ###
                 if self.test_flag:
@@ -108,13 +123,13 @@ class LowestTPMLoggingHandler(CustomLogger):
                 request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
                 request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
 
-                self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl= 60)
+                self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
                 ## RPM
                 request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
                 request_count_dict[id] = request_count_dict.get(id, 0) + 1
 
-                self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl= 60)
+                self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)
 
                 ### TESTING ###
                 if self.test_flag:
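
Because `RoutingArgs` is a pydantic model, the dict that reaches the handler is validated and type-coerced, and the empty-dict default falls back to the 60-second ttl. A standalone sketch of that behavior, mirroring the class added above:

from pydantic import BaseModel

class RoutingArgs(BaseModel):
    ttl: int = 1 * 60 # 1min (RPM/TPM expire key)

assert RoutingArgs(**{}).ttl == 60 # default when no routing_args are supplied
assert RoutingArgs(**{"ttl": 10}).ttl == 10 # override, as the updated test does
assert RoutingArgs(**{"ttl": "10"}).ttl == 10 # pydantic (v2 lax mode) coerces the string

Since `__init__` only reads the dict to build `RoutingArgs` and never mutates it, the mutable `routing_args: dict = {}` default is harmless here.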

View file

@@ -1,6 +1,6 @@
 #### What this does ####
 # identifies lowest tpm deployment
+from pydantic import BaseModel
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 import datetime as datetime_og
@@ -14,6 +14,20 @@ from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
 from litellm.types.router import RouterErrors
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
 
 
 class LowestTPMLoggingHandler_v2(CustomLogger):
     """
@@ -33,9 +47,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
 
-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)
 
     def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
         """
@@ -89,7 +104,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                 )
             else:
                 # if local result below limit, check redis ## prevent unnecessary redis checks
-                result = self.router_cache.increment_cache(key=rpm_key, value=1, ttl = 60)
+                result = self.router_cache.increment_cache(key=rpm_key, value=1, ttl=self.routing_args.ttl)
                 if result is not None and result > deployment_rpm:
                     raise litellm.RateLimitError(
                         message="Deployment over defined rpm limit={}. current usage={}".format(
@@ -168,7 +183,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             else:
                 # if local result below limit, check redis ## prevent unnecessary redis checks
                 result = await self.router_cache.async_increment_cache(
-                    key=rpm_key, value=1, ttl = 60
+                    key=rpm_key, value=1, ttl=self.routing_args.ttl
                 )
                 if result is not None and result > deployment_rpm:
                     raise litellm.RateLimitError(
@@ -229,7 +244,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             # update cache
             ## TPM
-            self.router_cache.increment_cache(key=tpm_key, value=total_tokens, ttl = 60)
+            self.router_cache.increment_cache(key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl)
 
             ### TESTING ###
             if self.test_flag:
                 self.logged_success += 1
@@ -273,7 +288,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             ## TPM
             await self.router_cache.async_increment_cache(
-                key=tpm_key, value=total_tokens, ttl = 60
+                key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
             )
 
             ### TESTING ###
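
Where the v1 handler reads, updates, and rewrites a per-window dict via `get_cache`/`set_cache`, the v2 handler increments counters via `increment_cache`/`async_increment_cache`; in both, the configurable ttl makes each counter key expire with its usage window. A toy expiring counter illustrating what the ttl argument buys (a simplified stand-in, not litellm's DualCache):

import time

class TTLCounter:
    """Minimal expiring counter; sketches the ttl semantics assumed above."""

    def __init__(self):
        self._store = {} # key -> (count, expires_at)

    def increment(self, key: str, value: int, ttl: int) -> int:
        now = time.time()
        count, expires_at = self._store.get(key, (0, now + ttl))
        if now >= expires_at: # window elapsed: the key has expired
            count, expires_at = 0, now + ttl
        count += value
        self._store[key] = (count, expires_at)
        return count

counter = TTLCounter()
counter.increment("gpt-3.5-turbo:tpm", value=500, ttl=1)
time.sleep(1.1) # wait out the ttl
assert counter.increment("gpt-3.5-turbo:tpm", value=1, ttl=1) == 1 # counter reset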

View file

@@ -161,13 +161,14 @@ async def test_completion_caching_on_router():
         redis_port=os.environ["REDIS_PORT"],
         cache_responses=True,
         timeout=30,
+        routing_strategy_args={"ttl": 10},
         routing_strategy="usage-based-routing",
     )
     response1 = await router.completion(
         model="gpt-3.5-turbo", messages=messages, temperature=1
     )
     print(f"response1: {response1}")
-    await asyncio.sleep(60)
+    await asyncio.sleep(10)
     response2 = await router.completion(
         model="gpt-3.5-turbo", messages=messages, temperature=1
     )
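
The test change follows from the new knob: with routing_strategy_args={"ttl": 10}, the TPM/RPM keys written for response1 expire after 10 seconds instead of the previously hardcoded 60, so the wait between the two completions drops from 60s to 10s without changing what the test verifies.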