Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
Merge pull request #3412 from sumanth13131/usage-based-routing-ttl-on-cache

usage-based-routing-ttl-on-cache

Commit c0e43a7296
4 changed files with 98 additions and 14 deletions
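In short: the usage-based routing strategies now accept a routing_strategy_args dict whose ttl (seconds, default 60) bounds how long the per-minute TPM/RPM usage counters live in the cache. A minimal usage sketch, mirroring the new test at the bottom of this diff (the model entry is a placeholder):

    from litellm import Router

    # hypothetical single-deployment list; any valid entry works
    model_list = [{"model_name": "gpt-3.5-turbo",
                   "litellm_params": {"model": "gpt-3.5-turbo"}}]

    router = Router(
        model_list=model_list,
        routing_strategy="usage-based-routing",
        routing_strategy_args={"ttl": 10},  # usage counters expire after 10s
    )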
Router setup — wire routing_strategy_args through to both usage-based strategies:

@@ -374,13 +374,17 @@ class Router:
             litellm.callbacks.append(self.leastbusy_logger)  # type: ignore
         elif routing_strategy == "usage-based-routing":
             self.lowesttpm_logger = LowestTPMLoggingHandler(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache,
+                model_list=self.model_list,
+                routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
         elif routing_strategy == "usage-based-routing-v2":
             self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache,
+                model_list=self.model_list,
+                routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger_v2)  # type: ignore
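The Router change is pure plumbing: both the v1 and v2 lowest-TPM handlers now receive the user-supplied routing_strategy_args dict and, as the next file shows, parse it into a pydantic RoutingArgs model, so an unset dict falls back to defaults.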
LowestTPMLoggingHandler (usage-based-routing) — add the RoutingArgs model and a TTL on every set_cache call:

@@ -1,6 +1,6 @@
 #### What this does ####
 # identifies lowest tpm deployment
-
+from pydantic import BaseModel
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 from datetime import datetime
@@ -11,16 +11,31 @@ from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose
+
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)
+
+
 class LowestTPMLoggingHandler(CustomLogger):
     test_flag: bool = False
     logged_success: int = 0
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour

-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)

     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
@@ -57,13 +72,13 @@ class LowestTPMLoggingHandler(CustomLogger):
             request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens

-            self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)

             ## RPM
             request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + 1

-            self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)

             ### TESTING ###
             if self.test_flag:
@@ -108,13 +123,13 @@ class LowestTPMLoggingHandler(CustomLogger):
             request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens

-            self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl)

             ## RPM
             request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
             request_count_dict[id] = request_count_dict.get(id, 0) + 1

-            self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
+            self.router_cache.set_cache(key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl)

             ### TESTING ###
             if self.test_flag:
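Because RoutingArgs is a pydantic model, a missing or partial routing_strategy_args dict falls back to the declared default. A standalone sketch of that behavior (re-declaring the model directly on BaseModel, a simplification of the LiteLLMBase inheritance in the diff above):

    from pydantic import BaseModel

    class RoutingArgs(BaseModel):
        ttl: int = 1 * 60  # 1min (RPM/TPM expire key)

    print(RoutingArgs().ttl)               # 60 - default when no args are passed
    print(RoutingArgs(**{"ttl": 10}).ttl)  # 10 - overridden via routing_strategy_args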
LowestTPMLoggingHandler_v2 (usage-based-routing-v2) — same RoutingArgs model, with the TTL applied on the increment paths:

@@ -1,6 +1,6 @@
 #### What this does ####
 # identifies lowest tpm deployment
-
+from pydantic import BaseModel
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 import datetime as datetime_og
@@ -14,6 +14,20 @@ from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
 from litellm.types.router import RouterErrors
+
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)
+
+
 class LowestTPMLoggingHandler_v2(CustomLogger):
     """
@@ -33,9 +47,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour

-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)

     def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
         """
@@ -89,7 +104,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                 )
             else:
                 # if local result below limit, check redis ## prevent unnecessary redis checks
-                result = self.router_cache.increment_cache(key=rpm_key, value=1)
+                result = self.router_cache.increment_cache(key=rpm_key, value=1, ttl=self.routing_args.ttl)
                 if result is not None and result > deployment_rpm:
                     raise litellm.RateLimitError(
                         message="Deployment over defined rpm limit={}. current usage={}".format(
@@ -168,7 +183,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             else:
                 # if local result below limit, check redis ## prevent unnecessary redis checks
                 result = await self.router_cache.async_increment_cache(
-                    key=rpm_key, value=1
+                    key=rpm_key, value=1, ttl=self.routing_args.ttl
                 )
                 if result is not None and result > deployment_rpm:
                     raise litellm.RateLimitError(
@@ -229,7 +244,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         # update cache

         ## TPM
-        self.router_cache.increment_cache(key=tpm_key, value=total_tokens)
+        self.router_cache.increment_cache(key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl)
         ### TESTING ###
         if self.test_flag:
             self.logged_success += 1
@@ -273,7 +288,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):

         ## TPM
         await self.router_cache.async_increment_cache(
-            key=tpm_key, value=total_tokens
+            key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
        )

         ### TESTING ###
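Unlike v1, which read-modify-writes a dict via set_cache, the v2 handler increments counters in place, so the TTL has to ride along on every increment_cache / async_increment_cache call. To picture the semantics, here is a sketch against plain redis-py — this is an illustration of increment-plus-expiry, not litellm's DualCache, and it assumes a local Redis:

    import redis

    r = redis.Redis()  # placeholder connection

    def increment_cache(key: str, value: int, ttl: int) -> int:
        """Increment a usage counter and (re)arm its expiry."""
        pipe = r.pipeline()
        pipe.incrby(key, value)  # add tokens/requests to the counter
        pipe.expire(key, ttl)    # stale minute-buckets drop out on their own
        count, _ = pipe.execute()
        return int(count)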
Router caching tests — new test exercising routing_strategy_args={"ttl": 10}:

@@ -134,6 +134,56 @@ async def test_acompletion_caching_on_router():
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio
+async def test_completion_caching_on_router():
+    # tests completion + caching on router
+    try:
+        litellm.set_verbose = True
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+                "tpm": 1000,
+                "rpm": 1,
+            },
+        ]
+
+        messages = [
+            {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
+        ]
+        router = Router(
+            model_list=model_list,
+            redis_host=os.environ["REDIS_HOST"],
+            redis_password=os.environ["REDIS_PASSWORD"],
+            redis_port=os.environ["REDIS_PORT"],
+            cache_responses=True,
+            timeout=30,
+            routing_strategy_args={"ttl": 10},
+            routing_strategy="usage-based-routing",
+        )
+        response1 = await router.completion(
+            model="gpt-3.5-turbo", messages=messages, temperature=1
+        )
+        print(f"response1: {response1}")
+        await asyncio.sleep(10)
+        response2 = await router.completion(
+            model="gpt-3.5-turbo", messages=messages, temperature=1
+        )
+        print(f"response2: {response2}")
+        assert len(response1.choices[0].message.content) > 0
+        assert len(response2.choices[0].message.content) > 0
+
+        router.reset()
+    except litellm.Timeout as e:
+        end_time = time.time()
+        print(f"timeout error occurred: {end_time - start_time}")
+        pass
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.asyncio
 async def test_acompletion_caching_with_ttl_on_router():
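The new test is the end-to-end check for the handler changes: the single deployment is pinned to rpm: 1 with ttl: 10, and the 10-second sleep between the two identical requests waits out the TTL, so the second call succeeding shows the usage counters actually expire rather than permanently saturating the rpm limit. (One wrinkle carried over from the commit as merged: start_time is never assigned in this test, so the litellm.Timeout branch would itself raise a NameError if it ever ran.)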