fix(router.py): support pre_call_rpm_check for lowest_tpm_rpm_v2 routing
Have routing strategies expose an ‘update rpm’ function, used to check and update RPM before the call is made.
parent 2267aeb803
commit c03b0bbb24
3 changed files with 75 additions and 25 deletions
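In effect, the router now delegates its pre-call rate-limit check to the routing strategy: inside the deployment's RPM semaphore it calls the strategy's `pre_call_rpm_check`, which increments the shared per-minute counter and raises if the deployment is over its limit. A minimal sketch of the pattern, assuming a duck-typed async cache (`MiniCache` and the helper names are illustrative, not the litellm API):

```python
import asyncio

class RateLimitError(Exception):
    pass

class MiniCache:
    """Toy stand-in for litellm's DualCache; illustration only."""
    def __init__(self):
        self.counts = {}

    async def async_increment_cache(self, key: str, value: int) -> int:
        self.counts[key] = self.counts.get(key, 0) + value
        return self.counts[key]

async def pre_call_rpm_check(cache: MiniCache, rpm_key: str, limit: float) -> int:
    # increment first, then compare: the post-increment value is this caller's
    # reservation, so two callers can never both observe "under the limit"
    current = await cache.async_increment_cache(key=rpm_key, value=1)
    if current > limit:
        raise RateLimitError(f"over rpm limit={limit}, current usage={current}")
    return current

async def main():
    cache, sem = MiniCache(), asyncio.Semaphore(1)
    for i in range(3):
        try:
            async with sem:  # the router holds a per-deployment semaphore here
                await pre_call_rpm_check(cache, "gpt-4:rpm:10-42", limit=2)
            print(f"call {i}: allowed")
        except RateLimitError as e:
            print(f"call {i}: {e}")

asyncio.run(main())
```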
litellm/caching.py

@@ -98,11 +98,12 @@ class InMemoryCache(BaseCache):
             return_val.append(val)
         return return_val
 
-    async def async_increment(self, key, value: int, **kwargs):
+    async def async_increment(self, key, value: int, **kwargs) -> int:
         # get the value
         init_value = await self.async_get_cache(key=key) or 0
         value = init_value + value
         await self.async_set_cache(key, value, **kwargs)
+        return value
 
     def flush_cache(self):
         self.cache_dict.clear()
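`InMemoryCache.async_increment` is a plain read-modify-write: fetch the current count (a missing key comes back as `None`, which `or 0` turns into a zero start), add, store, and now return the new total. A quick standalone check of those semantics (toy dict-backed cache mirroring the hunk above, not litellm code):

```python
import asyncio

class TinyCache:
    """Dict-backed stand-in mirroring InMemoryCache's increment logic."""
    def __init__(self):
        self.cache_dict = {}

    async def async_get_cache(self, key):
        return self.cache_dict.get(key)  # None when the key is absent

    async def async_set_cache(self, key, value):
        self.cache_dict[key] = value

    async def async_increment(self, key, value: int) -> int:
        init_value = await self.async_get_cache(key=key) or 0
        value = init_value + value
        await self.async_set_cache(key, value)
        return value  # the new total, as in the patch

async def main():
    c = TinyCache()
    assert await c.async_increment("k", 1) == 1  # None -> starts from 0
    assert await c.async_increment("k", 2) == 3  # returns the running total

asyncio.run(main())
```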
@@ -266,11 +267,12 @@ class RedisCache(BaseCache):
             if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
                 await self.flush_cache_buffer()
 
-    async def async_increment(self, key, value: int, **kwargs):
+    async def async_increment(self, key, value: int, **kwargs) -> int:
         _redis_client = self.init_async_client()
         try:
             async with _redis_client as redis_client:
-                await redis_client.incr(name=key, amount=value)
+                result = await redis_client.incr(name=key, amount=value)
+                return result
         except Exception as e:
             verbose_logger.error(
                 "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",

@@ -278,6 +280,7 @@ class RedisCache(BaseCache):
                 value,
             )
             traceback.print_exc()
+            raise e
 
     async def flush_cache_buffer(self):
         print_verbose(
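Unlike the in-memory read-modify-write, Redis `INCRBY` is atomic on the server and already returns the post-increment value, so the patch only has to capture and return the result of `redis_client.incr(...)`. A hedged sketch against a real Redis using `redis.asyncio` (connection details are illustrative, and a running server is assumed):

```python
import asyncio
import redis.asyncio as redis  # pip install redis

async def main():
    client = redis.Redis(host="localhost", port=6379)  # illustrative endpoint
    # INCRBY is atomic server-side: concurrent increments can't be lost,
    # and each caller gets back its own post-increment value
    new_val = await client.incr("deployment-1:rpm:10-42", amount=1)
    print(new_val)
    await client.aclose()  # redis-py >= 5; older versions use close()

asyncio.run(main())
```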
@@ -1076,21 +1079,29 @@ class DualCache(BaseCache):
 
     async def async_increment_cache(
         self, key, value: int, local_only: bool = False, **kwargs
-    ):
+    ) -> int:
         """
         Key - the key in cache
 
         Value - int - the value you want to increment by
 
+        Returns - int - the incremented value
         """
         try:
+            result: int = value
             if self.in_memory_cache is not None:
-                await self.in_memory_cache.async_increment(key, value, **kwargs)
+                result = await self.in_memory_cache.async_increment(
+                    key, value, **kwargs
+                )
 
             if self.redis_cache is not None and local_only == False:
-                await self.redis_cache.async_increment(key, value, **kwargs)
+                result = await self.redis_cache.async_increment(key, value, **kwargs)
 
+            return result
         except Exception as e:
             print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
             traceback.print_exc()
             raise e
 
     def flush_cache(self):
         if self.in_memory_cache is not None:
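`DualCache.async_increment_cache` now threads the incremented count back to the caller, and because the Redis layer runs last its result overwrites the in-memory one, so the shared cross-instance count is what gets compared against the limit. A toy sketch of that precedence (stand-in layers, not the litellm classes):

```python
import asyncio

class Layer:
    """Stand-in for either cache layer; just counts per key."""
    def __init__(self):
        self.d = {}

    async def async_increment(self, key, value):
        self.d[key] = self.d.get(key, 0) + value
        return self.d[key]

async def dual_increment(in_memory, redis_layer, key, value):
    result = value
    if in_memory is not None:
        result = await in_memory.async_increment(key, value)
    if redis_layer is not None:
        # runs last, so the shared (cross-instance) count wins
        result = await redis_layer.async_increment(key, value)
    return result

async def main():
    local, shared = Layer(), Layer()
    shared.d["k:rpm:10-42"] = 5  # pretend other instances already counted 5
    print(await dual_increment(local, shared, "k:rpm:10-42", 1))  # -> 6, not 1

asyncio.run(main())
```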
litellm/router.py

@@ -491,23 +491,16 @@ class Router:
                 deployment=deployment, kwargs=kwargs, client_type="rpm_client"
             )
 
-            if rpm_semaphore is not None and isinstance(
-                rpm_semaphore, asyncio.Semaphore
+            if (
+                rpm_semaphore is not None
+                and isinstance(rpm_semaphore, asyncio.Semaphore)
+                and self.routing_strategy == "usage-based-routing-v2"
             ):
                 async with rpm_semaphore:
                     """
-                    - Check against in-memory tpm/rpm limits before making the call
+                    - Check rpm limits before making the call
                     """
-                    dt = get_utc_datetime()
-                    current_minute = dt.strftime("%H-%M")
-                    id = kwargs["model_info"]["id"]
-                    rpm_key = "{}:rpm:{}".format(id, current_minute)
-                    curr_rpm = await self.cache.async_get_cache(key=rpm_key)
-                    if (
-                        curr_rpm is not None and curr_rpm >= data["rpm"]
-                    ):  # >= b/c the initial count is 0
-                        raise Exception("Rate Limit error")
-                    await self.cache.async_increment_cache(key=rpm_key, value=1)
+                    await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
                     response = await _response
             else:
                 response = await _response
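The inline logic being deleted here read the counter, compared, and incremented as three separate awaited steps; without serialization, two concurrent requests can both read the same value and both pass, which is the concurrency issue (BerriAI/litellm#2994) this commit cites. A self-contained demonstration of that window:

```python
import asyncio

counter = 0
LIMIT = 1

async def check_then_increment(name: str):
    global counter
    curr = counter          # 1) read
    await asyncio.sleep(0)  # yield, as any awaited cache call would
    if curr >= LIMIT:       # 2) compare
        print(f"{name}: rate limited")
        return
    counter = curr + 1      # 3) increment -- too late, both already passed
    print(f"{name}: allowed (saw {curr})")

async def main():
    # both tasks read counter == 0 before either increments
    await asyncio.gather(check_then_increment("a"), check_then_increment("b"))
    print("counter =", counter)  # 1, even though two calls were allowed

asyncio.run(main())
```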
litellm/router_strategy/lowest_tpm_rpm_v2.py

@@ -7,7 +7,8 @@ import datetime as datetime_og
 from datetime import datetime
 
 dotenv.load_dotenv()  # Loading env variables using dotenv
-import traceback, asyncio
+import traceback, asyncio, httpx
 import litellm
 from litellm import token_counter
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
@@ -37,6 +38,55 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         self.router_cache = router_cache
         self.model_list = model_list
 
+    async def pre_call_rpm_check(self, deployment: dict) -> dict:
+        """
+        Pre-call check + update model rpm
+        - Used inside semaphore
+        - raise rate limit error if deployment over limit
+
+        Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994
+
+        Returns - deployment
+
+        Raises - RateLimitError if deployment over defined RPM limit
+        """
+
+        # ------------
+        # Setup values
+        # ------------
+        dt = get_utc_datetime()
+        current_minute = dt.strftime("%H-%M")
+        model_group = deployment.get("model_name", "")
+        rpm_key = f"{model_group}:rpm:{current_minute}"
+        result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
+
+        deployment_rpm = None
+        if deployment_rpm is None:
+            deployment_rpm = deployment.get("rpm")
+        if deployment_rpm is None:
+            deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
+        if deployment_rpm is None:
+            deployment_rpm = deployment.get("model_info", {}).get("rpm")
+        if deployment_rpm is None:
+            deployment_rpm = float("inf")
+
+        if result is not None and result > deployment_rpm:
+            raise litellm.RateLimitError(
+                message="Deployment over defined rpm limit={}. current usage={}".format(
+                    deployment_rpm, result
+                ),
+                llm_provider="",
+                model=deployment.get("litellm_params", {}).get("model"),
+                response=httpx.Response(
+                    status_code=429,
+                    content="Deployment over defined rpm limit={}. current usage={}".format(
+                        deployment_rpm, result
+                    ),
+                    request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                ),
+            )
+        return deployment
+
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
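The limit lookup walks from the deployment's top level down through `litellm_params` and `model_info`, defaulting to `float("inf")` when nothing is configured. The chained `if deployment_rpm is None` blocks above implement this first-match-wins fallback; a compact standalone equivalent for illustration:

```python
def resolve_rpm_limit(deployment: dict) -> float:
    # first match wins: top-level rpm, then litellm_params.rpm, then model_info.rpm
    for candidate in (
        deployment.get("rpm"),
        deployment.get("litellm_params", {}).get("rpm"),
        deployment.get("model_info", {}).get("rpm"),
    ):
        if candidate is not None:
            return candidate
    return float("inf")  # no limit configured -> never rate limited

assert resolve_rpm_limit({"litellm_params": {"rpm": 100}}) == 100
assert resolve_rpm_limit({}) == float("inf")
```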
@@ -91,7 +141,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
-            Update TPM/RPM usage on success
+            Update TPM usage on success
             """
             if kwargs["litellm_params"].get("metadata") is None:
                 pass

@@ -117,8 +167,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             ) # use the same timezone regardless of system clock
 
             tpm_key = f"{id}:tpm:{current_minute}"
-            rpm_key = f"{id}:rpm:{current_minute}"
-
             # ------------
             # Update usage
             # ------------

@@ -128,8 +176,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             await self.router_cache.async_increment_cache(
                 key=tpm_key, value=total_tokens
             )
-            ## RPM
-            await self.router_cache.async_increment_cache(key=rpm_key, value=1)
 
             ### TESTING ###
             if self.test_flag:
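Both the pre-call check and the success-time logger bucket counters by the current UTC minute via `strftime("%H-%M")`, so every minute starts from a fresh key and limits reset without explicit bookkeeping. A small illustration (assuming `get_utc_datetime` returns a timezone-aware UTC datetime, as its name suggests):

```python
from datetime import datetime, timedelta, timezone

def rpm_key(model_group: str, now: datetime) -> str:
    # one counter per (model group, minute of day): a new minute, a new key
    return f"{model_group}:rpm:{now.strftime('%H-%M')}"

now = datetime(2024, 4, 13, 10, 42, 59, tzinfo=timezone.utc)
print(rpm_key("gpt-4", now))                         # gpt-4:rpm:10-42
print(rpm_key("gpt-4", now + timedelta(seconds=1)))  # gpt-4:rpm:10-43
```

Since "%H-%M" repeats every 24 hours, stale per-minute keys presumably age out via cache TTL or eviction before they could collide with the next day's traffic.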