fix: use one async_batch_set_cache call (#5956)

Ishaan Jaff 2024-09-28 09:59:38 -07:00 committed by GitHub
parent 1f51159ed2
commit 088d906276

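For context on the change below: the diff removes the individually awaited async_set_cache calls from each rate-limit check and instead collects every pending (key, value) update in a list, which is then written with a single async_batch_set_cache call at the end of the pre-call hook. A minimal sketch of the before/after pattern, using a hypothetical FakeCache rather than the proxy's real internal usage cache:

import asyncio
from typing import Any, List, Optional, Tuple


class FakeCache:
    """Hypothetical stand-in for the proxy's internal usage cache."""

    def __init__(self) -> None:
        self.store: dict = {}

    async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
        # old pattern: one awaited write per key
        self.store[key] = value

    async def async_batch_set_cache(
        self, cache_list: List[Tuple[Any, Any]], ttl: Optional[int] = None, **kwargs
    ) -> None:
        # new pattern: one awaited call that writes every pending (key, value) pair
        for key, value in cache_list:
            self.store[key] = value


async def old_pattern(cache: FakeCache) -> None:
    # each rate-limit check awaited its own cache write
    await cache.async_set_cache("api_key::request_count", {"current_requests": 1})
    await cache.async_set_cache("user_id::request_count", {"current_requests": 1})
    await cache.async_set_cache("team_id::request_count", {"current_requests": 1})


async def new_pattern(cache: FakeCache) -> None:
    # checks only append to a list; the hook flushes it once at the end
    values_to_update_in_cache: List[Tuple[Any, Any]] = []
    values_to_update_in_cache.append(("api_key::request_count", {"current_requests": 1}))
    values_to_update_in_cache.append(("user_id::request_count", {"current_requests": 1}))
    values_to_update_in_cache.append(("team_id::request_count", {"current_requests": 1}))
    await cache.async_batch_set_cache(cache_list=values_to_update_in_cache, ttl=60)


asyncio.run(old_pattern(FakeCache()))
asyncio.run(new_pattern(FakeCache()))

The keys and FakeCache above are illustrative only; the real hook accumulates counter updates for the api key, user, team, end-user, and per-model limits shown in the diff.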

@@ -1,7 +1,7 @@
import sys
import traceback
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
from fastapi import HTTPException
@@ -53,6 +53,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
rpm_limit: int,
request_count_api_key: str,
rate_limit_type: Literal["user", "customer", "team"],
values_to_update_in_cache: List[Tuple[Any, Any]],
):
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
@@ -69,11 +70,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
"current_tpm": 0,
"current_rpm": 0,
}
await self.internal_usage_cache.async_set_cache(
key=request_count_api_key,
value=new_val,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
values_to_update_in_cache.append((request_count_api_key, new_val))
elif (
int(current["current_requests"]) < max_parallel_requests
and current["current_tpm"] < tpm_limit
@@ -85,11 +82,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
"current_tpm": current["current_tpm"],
"current_rpm": current["current_rpm"],
}
await self.internal_usage_cache.async_set_cache(
key=request_count_api_key,
value=new_val,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
values_to_update_in_cache.append((request_count_api_key, new_val))
else:
raise HTTPException(
status_code=429,
@@ -148,6 +141,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
if rpm_limit is None:
rpm_limit = sys.maxsize
values_to_update_in_cache: List[Tuple[Any, Any]] = (
[]
) # values that need to get updated in cache, will run a batch_set_cache after this function
# ------------
# Setup values
# ------------
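The hunk above introduces the accumulator used by the rest of the hook: each limit check now appends its updated counters to values_to_update_in_cache instead of awaiting its own cache write. A trimmed, hypothetical sketch of what a check now looks like (most parameters and the TPM/RPM comparisons from the real check_key_in_limits are omitted):

from typing import Any, List, Optional, Tuple

from fastapi import HTTPException


async def check_key_in_limits_sketch(
    current: Optional[dict],
    max_parallel_requests: int,
    request_count_api_key: str,
    values_to_update_in_cache: List[Tuple[Any, Any]],
) -> None:
    if current is None:
        # first request in this window: start fresh counters
        new_val = {"current_requests": 1, "current_tpm": 0, "current_rpm": 0}
    elif int(current["current_requests"]) < max_parallel_requests:
        # still under the limit: bump the parallel-request counter
        new_val = {
            "current_requests": current["current_requests"] + 1,
            "current_tpm": current["current_tpm"],
            "current_rpm": current["current_rpm"],
        }
    else:
        raise HTTPException(status_code=429, detail="Rate limit crossed")
    # record the pending write instead of awaiting async_set_cache here;
    # the caller batches everything into one async_batch_set_cache call
    values_to_update_in_cache.append((request_count_api_key, new_val))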
@@ -208,11 +205,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
"current_tpm": 0,
"current_rpm": 0,
}
await self.internal_usage_cache.async_set_cache(
key=request_count_api_key,
value=new_val,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
values_to_update_in_cache.append((request_count_api_key, new_val))
elif (
int(current["current_requests"]) < max_parallel_requests
and current["current_tpm"] < tpm_limit
@@ -224,11 +217,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
"current_tpm": current["current_tpm"],
"current_rpm": current["current_rpm"],
}
await self.internal_usage_cache.async_set_cache(
key=request_count_api_key,
value=new_val,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
values_to_update_in_cache.append((request_count_api_key, new_val))
else:
return self.raise_rate_limit_error(
additional_details=f"Hit limit for api_key: {api_key}. tpm_limit: {tpm_limit}, current_tpm {current['current_tpm']} , rpm_limit: {rpm_limit} current rpm {current['current_rpm']} "
@@ -268,11 +257,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
"current_tpm": 0,
"current_rpm": 0,
}
await self.internal_usage_cache.async_set_cache(
key=request_count_api_key,
value=new_val,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
values_to_update_in_cache.append((request_count_api_key, new_val))
elif tpm_limit_for_model is not None or rpm_limit_for_model is not None:
# Increase count for this token
new_val = {
@@ -295,11 +280,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
additional_details=f"Hit RPM limit for model: {_model} on api_key: {api_key}. rpm_limit: {rpm_limit_for_model}, current_rpm {current['current_rpm']} "
)
else:
await self.internal_usage_cache.async_set_cache(
key=request_count_api_key,
value=new_val,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
values_to_update_in_cache.append((request_count_api_key, new_val))
_remaining_tokens = None
_remaining_requests = None
@@ -356,6 +337,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
tpm_limit=user_tpm_limit,
rpm_limit=user_rpm_limit,
rate_limit_type="user",
values_to_update_in_cache=values_to_update_in_cache,
)
# TEAM RATE LIMITS
@@ -384,6 +366,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
tpm_limit=team_tpm_limit,
rpm_limit=team_rpm_limit,
rate_limit_type="team",
values_to_update_in_cache=values_to_update_in_cache,
)
# End-User Rate Limits
@@ -417,8 +400,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
tpm_limit=end_user_tpm_limit,
rpm_limit=end_user_rpm_limit,
rate_limit_type="customer",
values_to_update_in_cache=values_to_update_in_cache,
)
await self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
return
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
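The last hunk above adds the single flush: every accumulated (key, value) pair is written in one async_batch_set_cache call with ttl=60, matching the one-minute window the TPM/RPM counters cover. As a rough illustration of why this helps (an assumption about behavior, not DualCache's actual implementation), a batch setter can issue all pending writes together instead of paying one awaited write per key, as the old code did:

import asyncio
from typing import Any, Awaitable, Callable, List, Optional, Tuple


async def batch_set_sketch(
    set_one: Callable[[str, Any, Optional[int]], Awaitable[None]],
    cache_list: List[Tuple[Any, Any]],
    ttl: Optional[int] = None,
) -> None:
    # Hypothetical only: fire every pending write concurrently and await them
    # together, instead of awaiting each write sequentially inside the hook.
    await asyncio.gather(*(set_one(key, value, ttl) for key, value in cache_list))


async def _demo() -> None:
    store: dict = {}

    async def set_one(key: str, value: Any, ttl: Optional[int]) -> None:
        store[key] = value  # a real backend would also apply the TTL

    await batch_set_sketch(
        set_one,
        [("k1", {"current_requests": 1}), ("k2", {"current_requests": 2})],
        ttl=60,
    )
    assert list(store) == ["k1", "k2"]


asyncio.run(_demo())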