feat: return remaining tokens for model for API key

This commit is contained in:
Ishaan Jaff 2024-08-17 12:35:10 -07:00
parent 5985c7e933
commit ee0f772b5c
3 changed files with 73 additions and 6 deletions

View file

@ -148,7 +148,10 @@ from litellm.proxy.common_utils.admin_ui_utils import (
html_form,
show_missing_vars_in_env,
)
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
from litellm.proxy.common_utils.callback_utils import (
get_remaining_tokens_and_requests_from_request_data,
initialize_callbacks_on_proxy,
)
from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
@ -503,6 +506,7 @@ def get_custom_headers(
model_region: Optional[str] = None,
response_cost: Optional[Union[float, str]] = None,
fastest_response_batch_completion: Optional[bool] = None,
request_data: Optional[dict] = {},
**kwargs,
) -> dict:
exclude_values = {"", None}
@ -523,6 +527,12 @@ def get_custom_headers(
),
**{k: str(v) for k, v in kwargs.items()},
}
if request_data:
remaining_tokens_header = get_remaining_tokens_and_requests_from_request_data(
request_data
)
headers.update(remaining_tokens_header)
try:
return {
key: value for key, value in headers.items() if value not in exclude_values
@ -3106,6 +3116,7 @@ async def chat_completion(
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
request_data=data,
**additional_headers,
)
selected_data_generator = select_data_generator(
@ -3140,6 +3151,7 @@ async def chat_completion(
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
request_data=data,
**additional_headers,
)
)
@ -3323,6 +3335,7 @@ async def completion(
api_base=api_base,
version=version,
response_cost=response_cost,
request_data=data,
)
selected_data_generator = select_data_generator(
response=response,
@ -3344,6 +3357,7 @@ async def completion(
api_base=api_base,
version=version,
response_cost=response_cost,
request_data=data,
)
)
await check_response_size_is_safe(response=response)
@ -3551,6 +3565,7 @@ async def embeddings(
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
call_id=litellm_call_id,
request_data=data,
)
)
await check_response_size_is_safe(response=response)
@ -3678,6 +3693,7 @@ async def image_generation(
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
call_id=litellm_call_id,
request_data=data,
)
)
@ -3799,6 +3815,7 @@ async def audio_speech(
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=None,
call_id=litellm_call_id,
request_data=data,
)
selected_data_generator = select_data_generator(
@ -3936,6 +3953,7 @@ async def audio_transcriptions(
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
call_id=litellm_call_id,
request_data=data,
)
)
@ -4039,6 +4057,7 @@ async def get_assistants(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@ -4134,6 +4153,7 @@ async def create_assistant(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@ -4229,6 +4249,7 @@ async def delete_assistant(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@ -4324,6 +4345,7 @@ async def create_threads(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@ -4418,6 +4440,7 @@ async def get_thread(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@ -4515,6 +4538,7 @@ async def add_messages(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@ -4608,6 +4632,7 @@ async def get_messages(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@ -4715,6 +4740,7 @@ async def run_thread(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@ -4837,6 +4863,7 @@ async def create_batch(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@ -4932,6 +4959,7 @@ async def retrieve_batch(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@ -5150,6 +5178,7 @@ async def moderations(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@ -5319,6 +5348,7 @@ async def anthropic_response(
api_base=api_base,
version=version,
response_cost=response_cost,
request_data=data,
)
)