Mirror of https://github.com/BerriAI/litellm.git
(core sdk fix) - fix fallbacks stuck in infinite loop (#7751)
* test_acompletion_fallbacks_basic
* use common run_async_function (sketched below, before the diff)
* fix completion_with_fallbacks
* fix completion with fallbacks
* fix fallback utils
* test_acompletion_fallbacks_basic
* test_completion_fallbacks_sync
* huggingface/mistralai/Mistral-7B-Instruct-v0.3
Parent: 970e9c7507
Commit: f1335362cf
6 changed files with 222 additions and 156 deletions
litellm/utils.py | 128
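The commit message's "use common run_async_function" bullet refers to a shared sync-over-async helper. As a minimal sketch only (the name comes from the commit message; the signature and body here are assumptions, not litellm's actual implementation):

    import asyncio

    def run_async_function(async_fn, *args, **kwargs):
        # Assumed shape: execute a coroutine function from synchronous code
        # by running it on a fresh event loop for the duration of the call.
        return asyncio.run(async_fn(*args, **kwargs))

A helper like this would let a sync entry point such as completion_with_fallbacks delegate to the async implementation instead of maintaining two near-identical copies of the retry loop, which is what the removed code below does.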
@@ -5243,134 +5243,6 @@ def read_config_args(config_path) -> dict:

########## experimental completion variants ############################


def completion_with_fallbacks(**kwargs):
    nested_kwargs = kwargs.pop("kwargs", {})
    response = None
    rate_limited_models = set()
    model_expiration_times = {}
    start_time = time.time()
    original_model = kwargs["model"]
    fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", [])
    if "fallbacks" in nested_kwargs:
        del nested_kwargs["fallbacks"]  # remove fallbacks so it's not recursive
    litellm_call_id = str(uuid.uuid4())

    # max time to process a request with fallbacks: default 45s
    while response is None and time.time() - start_time < 45:
        for model in fallbacks:
            # loop thru all models
            try:
                # check if it's dict or new model string
                if isinstance(
                    model, dict
                ):  # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}])
                    kwargs["api_key"] = model.get("api_key", None)
                    kwargs["api_base"] = model.get("api_base", None)
                    model = model.get("model", original_model)
                elif (
                    model in rate_limited_models
                ):  # check if model is currently cooling down
                    if (
                        model_expiration_times.get(model)
                        and time.time() >= model_expiration_times[model]
                    ):
                        rate_limited_models.remove(
                            model
                        )  # check if it's been 60s of cool down and remove model
                    else:
                        continue  # skip model

                # delete model from kwargs if it exists
                if kwargs.get("model"):
                    del kwargs["model"]

                print_verbose(f"trying to make completion call with model: {model}")
                kwargs["litellm_call_id"] = litellm_call_id
                kwargs = {
                    **kwargs,
                    **nested_kwargs,
                }  # combine the openai + litellm params at the same level
                response = litellm.completion(**kwargs, model=model)
                print_verbose(f"response: {response}")
                if response is not None:
                    return response

            except Exception as e:
                print_verbose(e)
                rate_limited_models.add(model)
                model_expiration_times[model] = (
                    time.time() + 60
                )  # cool down this selected model
                pass
    return response
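For reference, the inline comment in the function above documents the call shape this helper serves; a minimal illustration (the messages value and the dict contents are invented for the example):

    # Fallbacks can be model-name strings or dicts carrying per-model credentials.
    response = litellm.completion(
        model="gpt-4",
        messages=[{"role": "user", "content": "Hello"}],
        fallbacks=[{"api_key": "", "api_base": "", "model": "gpt-3.5-turbo"}],
    )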
async def async_completion_with_fallbacks(**kwargs):
    nested_kwargs = kwargs.pop("kwargs", {})
    response = None
    rate_limited_models = set()
    model_expiration_times = {}
    start_time = time.time()
    original_model = kwargs["model"]
    fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", [])
    if "fallbacks" in nested_kwargs:
        del nested_kwargs["fallbacks"]  # remove fallbacks so it's not recursive
    if "acompletion" in kwargs:
        del kwargs[
            "acompletion"
        ]  # remove acompletion so it doesn't lead to keyword errors
    litellm_call_id = str(uuid.uuid4())

    # max time to process a request with fallbacks: default 45s
    while response is None and time.time() - start_time < 45:
        for model in fallbacks:
            # loop thru all models
            try:
                # check if it's dict or new model string
                if isinstance(
                    model, dict
                ):  # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}])
                    kwargs["api_key"] = model.get("api_key", None)
                    kwargs["api_base"] = model.get("api_base", None)
                    model = model.get("model", original_model)
                elif (
                    model in rate_limited_models
                ):  # check if model is currently cooling down
                    if (
                        model_expiration_times.get(model)
                        and time.time() >= model_expiration_times[model]
                    ):
                        rate_limited_models.remove(
                            model
                        )  # check if it's been 60s of cool down and remove model
                    else:
                        continue  # skip model

                # delete model from kwargs if it exists
                if kwargs.get("model"):
                    del kwargs["model"]

                print_verbose(f"trying to make completion call with model: {model}")
                kwargs["litellm_call_id"] = litellm_call_id
                kwargs = {
                    **kwargs,
                    **nested_kwargs,
                }  # combine the openai + litellm params at the same level
                response = await litellm.acompletion(**kwargs, model=model)
                print_verbose(f"response: {response}")
                if response is not None:
                    return response

            except Exception as e:
                print_verbose(f"error: {e}")
                rate_limited_models.add(model)
                model_expiration_times[model] = (
                    time.time() + 60
                )  # cool down this selected model
                pass
    return response
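The commit title points at this wall-clock retry loop as the source of the hang. As a sketch of the defensive direction only (the bound, the function name, and the structure are assumptions, not the code this commit actually landed), termination can be made explicit by capping the number of passes over the fallback list instead of looping on elapsed time:

    import litellm  # same dependency the surrounding helpers use

    MAX_PASSES = 2  # assumed bound; not taken from the commit

    async def acompletion_with_bounded_fallbacks(**kwargs):
        # Sketch: walk the fallback list a fixed number of times, so the
        # coroutine always returns or raises after a known number of attempts
        # rather than spinning until a 45s deadline expires.
        fallbacks = [kwargs.pop("model")] + kwargs.pop("fallbacks", [])
        last_error = None
        for _ in range(MAX_PASSES):
            for model in fallbacks:
                try:
                    return await litellm.acompletion(model=model, **kwargs)
                except Exception as e:
                    last_error = e
        raise last_error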
def process_system_message(system_message, max_tokens, model):
    system_message_event = {"role": "system", "content": system_message}
    system_message_tokens = get_token_count([system_message_event], model)