Merge pull request #5259 from BerriAI/litellm_return_remaining_tokens_in_header

[Feat] return `x-litellm-key-remaining-requests-{model}: 1`, `x-litellm-key-remaining-tokens-{model}: None` in response headers
commit db8f789318
Author: Ishaan Jaff (committed by GitHub)
Date: 2024-08-17 12:41:16 -07:00
9 changed files with 518 additions and 11 deletions
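In practice, these headers let a client see how much of its per-model rate-limit budget remains after each call. Below is a minimal sketch of reading them from a running proxy; the base URL, API key, and model name are placeholders, and the printed values depend on the key's configured limits:

```python
import requests

# Call a LiteLLM proxy (assumed to be running locally on :4000) and inspect
# the per-model rate-limit headers added by this PR. The URL, key, and model
# below are placeholders, not values from the commit itself.
response = requests.post(
    "http://localhost:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hi"}],
    },
)

# Header names follow the pattern from the PR title; the model suffix matches
# the model group the request was routed to.
print(response.headers.get("x-litellm-key-remaining-requests-gpt-3.5-turbo"))
print(response.headers.get("x-litellm-key-remaining-tokens-gpt-3.5-turbo"))
```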

litellm/proxy/proxy_server.py

@@ -148,6 +148,10 @@ from litellm.proxy.common_utils.admin_ui_utils import (
     html_form,
     show_missing_vars_in_env,
 )
+from litellm.proxy.common_utils.callback_utils import (
+    get_remaining_tokens_and_requests_from_request_data,
+    initialize_callbacks_on_proxy,
+)
 from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
 from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
 from litellm.proxy.common_utils.encrypt_decrypt_utils import (
@@ -158,7 +162,6 @@ from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     check_file_size_under_limit,
 )
-from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
 from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
 from litellm.proxy.common_utils.openai_endpoint_utils import (
     remove_sensitive_info_from_deployment,
@@ -503,6 +506,7 @@ def get_custom_headers(
     model_region: Optional[str] = None,
     response_cost: Optional[Union[float, str]] = None,
     fastest_response_batch_completion: Optional[bool] = None,
+    request_data: Optional[dict] = {},
     **kwargs,
 ) -> dict:
     exclude_values = {"", None}
@@ -523,6 +527,12 @@ def get_custom_headers(
         ),
         **{k: str(v) for k, v in kwargs.items()},
     }
+
+    if request_data:
+        remaining_tokens_header = get_remaining_tokens_and_requests_from_request_data(
+            request_data
+        )
+        headers.update(remaining_tokens_header)
     try:
         return {
             key: value for key, value in headers.items() if value not in exclude_values
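The `get_remaining_tokens_and_requests_from_request_data` helper is added in `litellm/proxy/common_utils/callback_utils.py` elsewhere in this commit (not shown in this file's diff). As a rough sketch of the contract this call site relies on, assuming the proxy's rate limiter has already stashed per-model counters in the request's metadata; the metadata key names below are hypothetical, not the commit's actual implementation:

```python
from typing import Optional


def get_remaining_tokens_and_requests_from_request_data(data: dict) -> dict:
    """Hypothetical sketch: map per-model rate-limit counters stored on the
    request to the response-header entries named in the PR title."""
    headers: dict = {}
    # These metadata keys are assumptions; the real helper reads whatever the
    # parallel-request limiter recorded during its pre-call checks.
    metadata = data.get("metadata") or {}
    model_group: Optional[str] = metadata.get("model_group")
    if not model_group:
        return headers
    remaining_requests = metadata.get("remaining_requests")
    remaining_tokens = metadata.get("remaining_tokens")
    if remaining_requests is not None:
        headers[f"x-litellm-key-remaining-requests-{model_group}"] = str(remaining_requests)
    if remaining_tokens is not None:
        headers[f"x-litellm-key-remaining-tokens-{model_group}"] = str(remaining_tokens)
    return headers
```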
@@ -3107,6 +3117,7 @@ async def chat_completion(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=fastest_response_batch_completion,
+                request_data=data,
                 **additional_headers,
             )
             selected_data_generator = select_data_generator(
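This first call site appears to be the streaming path of `/chat/completions`: the headers are computed before the `StreamingResponse` is returned, so a client can inspect its remaining budget without consuming the stream. A hedged example using the OpenAI SDK's raw-response interface against the proxy (base URL, key, and model are placeholders):

```python
from openai import OpenAI

# Point the OpenAI SDK at a LiteLLM proxy (placeholder base_url and key).
client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

# with_raw_response exposes the HTTP response headers, which are sent
# before the first streamed chunk arrives.
raw = client.chat.completions.with_raw_response.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)
print(raw.headers.get("x-litellm-key-remaining-requests-gpt-3.5-turbo"))
print(raw.headers.get("x-litellm-key-remaining-tokens-gpt-3.5-turbo"))

for chunk in raw.parse():  # iterate the underlying stream as usual
    pass
```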
@@ -3141,6 +3152,7 @@ async def chat_completion(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=fastest_response_batch_completion,
+                request_data=data,
                 **additional_headers,
             )
         )
@@ -3322,6 +3334,7 @@ async def completion(
                 api_base=api_base,
                 version=version,
                 response_cost=response_cost,
+                request_data=data,
             )
             selected_data_generator = select_data_generator(
                 response=response,
@@ -3343,6 +3356,7 @@ async def completion(
                 api_base=api_base,
                 version=version,
                 response_cost=response_cost,
+                request_data=data,
             )
         )
         await check_response_size_is_safe(response=response)
@@ -3550,6 +3564,7 @@ async def embeddings(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 call_id=litellm_call_id,
+                request_data=data,
             )
         )
         await check_response_size_is_safe(response=response)
@@ -3676,6 +3691,7 @@ async def image_generation(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 call_id=litellm_call_id,
+                request_data=data,
             )
         )
@@ -3797,6 +3813,7 @@ async def audio_speech(
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=None,
                 call_id=litellm_call_id,
+                request_data=data,
             )
             selected_data_generator = select_data_generator(
@@ -3934,6 +3951,7 @@ async def audio_transcriptions(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 call_id=litellm_call_id,
+                request_data=data,
             )
         )
@@ -4037,6 +4055,7 @@ async def get_assistants(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
@@ -4132,6 +4151,7 @@ async def create_assistant(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
@@ -4227,6 +4247,7 @@ async def delete_assistant(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
@@ -4322,6 +4343,7 @@ async def create_threads(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
@@ -4416,6 +4438,7 @@ async def get_thread(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
@@ -4513,6 +4536,7 @@ async def add_messages(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
@@ -4606,6 +4630,7 @@ async def get_messages(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
@@ -4713,6 +4738,7 @@ async def run_thread(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
@@ -4835,6 +4861,7 @@ async def create_batch(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
@@ -4930,6 +4957,7 @@ async def retrieve_batch(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
@@ -5148,6 +5176,7 @@ async def moderations(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
@@ -5317,6 +5346,7 @@ async def anthropic_response(
                 api_base=api_base,
                 version=version,
                 response_cost=response_cost,
+                request_data=data,
             )
         )