fix(acompletion): support fallbacks on acompletion (#7184)

* fix(acompletion): support fallbacks on acompletion

allows health checks for wildcard routes to use fallback models

* test: update cohere generate api testing

* add max tokens to health check (#7000)

* fix: fix health check test

* test: update testing

---------

Co-authored-by: Cameron <561860+wallies@users.noreply.github.com>
This commit is contained in:
Krish Dholakia 2024-12-11 19:20:54 -08:00 committed by GitHub
parent 5fe77499d2
commit a9aeb21d0b
8 changed files with 240 additions and 69 deletions

View file

@ -5697,6 +5697,72 @@ def completion_with_fallbacks(**kwargs):
return response
async def async_completion_with_fallbacks(**kwargs):
nested_kwargs = kwargs.pop("kwargs", {})
response = None
rate_limited_models = set()
model_expiration_times = {}
start_time = time.time()
original_model = kwargs["model"]
fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", [])
if "fallbacks" in nested_kwargs:
del nested_kwargs["fallbacks"] # remove fallbacks so it's not recursive
if "acompletion" in kwargs:
del kwargs[
"acompletion"
] # remove acompletion so it doesn't lead to keyword errors
litellm_call_id = str(uuid.uuid4())
# max time to process a request with fallbacks: default 45s
while response is None and time.time() - start_time < 45:
for model in fallbacks:
# loop thru all models
try:
# check if it's dict or new model string
if isinstance(
model, dict
): # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}])
kwargs["api_key"] = model.get("api_key", None)
kwargs["api_base"] = model.get("api_base", None)
model = model.get("model", original_model)
elif (
model in rate_limited_models
): # check if model is currently cooling down
if (
model_expiration_times.get(model)
and time.time() >= model_expiration_times[model]
):
rate_limited_models.remove(
model
) # check if it's been 60s of cool down and remove model
else:
continue # skip model
# delete model from kwargs if it exists
if kwargs.get("model"):
del kwargs["model"]
print_verbose(f"trying to make completion call with model: {model}")
kwargs["litellm_call_id"] = litellm_call_id
kwargs = {
**kwargs,
**nested_kwargs,
} # combine the openai + litellm params at the same level
response = await litellm.acompletion(**kwargs, model=model)
print_verbose(f"response: {response}")
if response is not None:
return response
except Exception as e:
print_verbose(f"error: {e}")
rate_limited_models.add(model)
model_expiration_times[model] = (
time.time() + 60
) # cool down this selected model
pass
return response
def process_system_message(system_message, max_tokens, model):
system_message_event = {"role": "system", "content": system_message}
system_message_tokens = get_token_count([system_message_event], model)