fix(router.py): set cooldown_time: per model

This commit is contained in:
Krrish Dholakia 2024-06-25 16:51:55 -07:00
parent e813e984f7
commit d98e00d1e0
6 changed files with 72 additions and 11 deletions

View file

@ -650,6 +650,7 @@ def completion(
headers = kwargs.get("headers", None) or extra_headers
num_retries = kwargs.get("num_retries", None) ## deprecated
max_retries = kwargs.get("max_retries", None)
cooldown_time = kwargs.get("cooldown_time", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
organization = kwargs.get("organization", None)
### CUSTOM MODEL COST ###
@ -763,6 +764,7 @@ def completion(
"allowed_model_region",
"model_config",
"fastest_response",
"cooldown_time",
]
default_params = openai_params + litellm_params
@ -947,6 +949,7 @@ def completion(
input_cost_per_token=input_cost_per_token,
output_cost_per_second=output_cost_per_second,
output_cost_per_token=output_cost_per_token,
cooldown_time=cooldown_time,
)
logging.update_environment_variables(
model=model,
@ -3030,6 +3033,7 @@ def embedding(
client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
cooldown_time = kwargs.get("cooldown_time", None)
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
@ -3105,6 +3109,7 @@ def embedding(
"region_name",
"allowed_model_region",
"model_config",
"cooldown_time",
]
default_params = openai_params + litellm_params
non_default_params = {
@ -3165,6 +3170,7 @@ def embedding(
"aembedding": aembedding,
"preset_cache_key": None,
"stream_response": {},
"cooldown_time": cooldown_time,
},
)
if azure == True or custom_llm_provider == "azure":