mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
support default fallback models
This commit is contained in:
parent
31c995a1a4
commit
ebd763287a
5 changed files with 61 additions and 33 deletions
|
@ -2795,36 +2795,45 @@ def completion_with_config(*args, config: Union[dict, str], **kwargs):
|
|||
models_with_config = completion_config["model"].keys()
|
||||
model = args[0] if len(args) > 0 else kwargs["model"]
|
||||
messages = args[1] if len(args) > 1 else kwargs["messages"]
|
||||
if model in models_with_config:
|
||||
## Moderation check
|
||||
if completion_config["model"][model].get("needs_moderation"):
|
||||
input = " ".join(message["content"] for message in messages)
|
||||
response = litellm.moderation(input=input)
|
||||
flagged = response["results"][0]["flagged"]
|
||||
if flagged:
|
||||
raise Exception("This response was flagged as inappropriate")
|
||||
|
||||
## Load Error Handling Logic
|
||||
error_handling = None
|
||||
if completion_config["model"][model].get("error_handling"):
|
||||
error_handling = completion_config["model"][model]["error_handling"]
|
||||
|
||||
try:
|
||||
response = litellm.completion(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
exception_name = type(e).__name__
|
||||
fallback_model = None
|
||||
if error_handling and exception_name in error_handling:
|
||||
error_handler = error_handling[exception_name]
|
||||
# either switch model or api key
|
||||
fallback_model = error_handler.get("fallback_model", None)
|
||||
if fallback_model:
|
||||
kwargs["model"] = fallback_model
|
||||
return litellm.completion(*args, **kwargs)
|
||||
raise e
|
||||
else:
|
||||
return litellm.completion(*args, **kwargs)
|
||||
## Default fallback models
|
||||
fallback_models = completion_config.get("default_fallback_models")
|
||||
try:
|
||||
if model in models_with_config:
|
||||
## Moderation check
|
||||
if completion_config["model"][model].get("needs_moderation"):
|
||||
input = " ".join(message["content"] for message in messages)
|
||||
response = litellm.moderation(input=input)
|
||||
flagged = response["results"][0]["flagged"]
|
||||
if flagged:
|
||||
raise Exception("This response was flagged as inappropriate")
|
||||
|
||||
## Model-specific Error Handling
|
||||
error_handling = None
|
||||
if completion_config["model"][model].get("error_handling"):
|
||||
error_handling = completion_config["model"][model]["error_handling"]
|
||||
|
||||
try:
|
||||
response = litellm.completion(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
exception_name = type(e).__name__
|
||||
fallback_model = None
|
||||
if error_handling and exception_name in error_handling:
|
||||
error_handler = error_handling[exception_name]
|
||||
# either switch model or api key
|
||||
fallback_model = error_handler.get("fallback_model", None)
|
||||
if fallback_model:
|
||||
kwargs["model"] = fallback_model
|
||||
return litellm.completion(*args, **kwargs)
|
||||
raise e
|
||||
else:
|
||||
return litellm.completion(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if fallback_models:
|
||||
model = fallback_models.pop(0)
|
||||
return completion_with_fallbacks(model=model, messages=messages, fallbacks=fallback_models)
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
|
@ -2924,8 +2933,7 @@ def completion_with_fallbacks(**kwargs):
|
|||
# delete model from kwargs if it exists
|
||||
if kwargs.get("model"):
|
||||
del kwargs["model"]
|
||||
|
||||
print("making completion call", model)
|
||||
|
||||
response = litellm.completion(**kwargs, model=model)
|
||||
|
||||
if response != None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue