fix(parallel_request_limiter.py): fix user+team tpm/rpm limit check

Closes https://github.com/BerriAI/litellm/issues/3788
This commit is contained in:
Krrish Dholakia 2024-05-27 08:48:23 -07:00
parent fa064c91fb
commit 4408b717f0
7 changed files with 157 additions and 532 deletions

View file

@ -397,6 +397,7 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
def get_custom_headers(
*,
user_api_key_dict: UserAPIKeyAuth,
model_id: Optional[str] = None,
cache_key: Optional[str] = None,
api_base: Optional[str] = None,
@ -410,6 +411,8 @@ def get_custom_headers(
"x-litellm-model-api-base": api_base,
"x-litellm-version": version,
"x-litellm-model-region": model_region,
"x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
"x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
}
try:
return {
@ -4059,6 +4062,7 @@ async def chat_completion(
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses
custom_headers = get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
@ -4078,6 +4082,7 @@ async def chat_completion(
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
@ -4298,6 +4303,7 @@ async def completion(
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses
custom_headers = get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
@ -4316,6 +4322,7 @@ async def completion(
)
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
@ -4565,6 +4572,7 @@ async def embeddings(
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
@ -4748,6 +4756,7 @@ async def image_generation(
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
@ -4949,6 +4958,7 @@ async def audio_transcriptions(
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
@ -5132,6 +5142,7 @@ async def moderations(
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,