support default fallback models

Krrish Dholakia 2023-09-21 14:28:07 -07:00
parent 31c995a1a4
commit ebd763287a
5 changed files with 61 additions and 33 deletions
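
For orientation, a minimal usage sketch of the feature this commit adds, assembled from the test changes below. It is not part of the commit itself; the config keys and call signature are taken directly from the diff.

from litellm import completion_with_config

config = {
    "function": "completion",
    # New in this commit: models tried, in order, when the primary call fails
    "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
    "model": {
        "claude-instant-1": {"needs_moderation": True},
    },
}

messages = [{"role": "user", "content": "Hey, how's it going?"}]
# If the primary call raises (e.g. a bad API key), completion_with_config
# now hands the remaining fallback list to completion_with_fallbacks.
response = completion_with_config(model="claude-instant-1", messages=messages, config=config)
print(response)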

View file

@@ -14,6 +14,7 @@ from litellm import completion_with_config
 config = {
     "function": "completion",
+    "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
     "model": {
         "claude-instant-1": {
             "needs_moderation": True
@@ -26,12 +27,20 @@ config = {
     }
 }
 
-def test_config():
+def test_config_context_window_exceeded():
     try:
         sample_text = "how does a court case get to the Supreme Court?" * 1000
         messages = [{"content": sample_text, "role": "user"}]
         response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
         print(response)
+    except Exception as e:
+        print(f"Exception: {e}")
+        pytest.fail(f"An exception occurred: {e}")
+
+# test_config_context_window_exceeded()
+
+def test_config_context_moderation():
+    try:
         messages=[{"role": "user", "content": "I want to kill them."}]
         response = completion_with_config(model="claude-instant-1", messages=messages, config=config)
         print(response)
@@ -39,4 +48,15 @@ def test_config():
         print(f"Exception: {e}")
         pytest.fail(f"An exception occurred: {e}")
 
-# test_config()
+# test_config_context_moderation()
+
+def test_config_context_default_fallback():
+    try:
+        messages=[{"role": "user", "content": "Hey, how's it going?"}]
+        response = completion_with_config(model="claude-instant-1", messages=messages, config=config, api_key="bad-key")
+        print(response)
+    except Exception as e:
+        print(f"Exception: {e}")
+        pytest.fail(f"An exception occurred: {e}")
+
+test_config_context_default_fallback()
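
Taken together, the renamed and new tests exercise three recovery paths: a context-window overflow on gpt-3.5-turbo, a moderation flag on claude-instant-1, and a deliberately bad api_key that should fall through to the new default fallback list.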

View file

@@ -2795,36 +2795,45 @@ def completion_with_config(*args, config: Union[dict, str], **kwargs):
     models_with_config = completion_config["model"].keys()
     model = args[0] if len(args) > 0 else kwargs["model"]
     messages = args[1] if len(args) > 1 else kwargs["messages"]
-    if model in models_with_config:
-        ## Moderation check
-        if completion_config["model"][model].get("needs_moderation"):
-            input = " ".join(message["content"] for message in messages)
-            response = litellm.moderation(input=input)
-            flagged = response["results"][0]["flagged"]
-            if flagged:
-                raise Exception("This response was flagged as inappropriate")
-
-        ## Load Error Handling Logic
-        error_handling = None
-        if completion_config["model"][model].get("error_handling"):
-            error_handling = completion_config["model"][model]["error_handling"]
-
-        try:
-            response = litellm.completion(*args, **kwargs)
-            return response
-        except Exception as e:
-            exception_name = type(e).__name__
-            fallback_model = None
-            if error_handling and exception_name in error_handling:
-                error_handler = error_handling[exception_name]
-                # either switch model or api key
-                fallback_model = error_handler.get("fallback_model", None)
-            if fallback_model:
-                kwargs["model"] = fallback_model
-                return litellm.completion(*args, **kwargs)
-            raise e
-    else:
-        return litellm.completion(*args, **kwargs)
+
+    ## Default fallback models
+    fallback_models = completion_config.get("default_fallback_models")
+    try:
+        if model in models_with_config:
+            ## Moderation check
+            if completion_config["model"][model].get("needs_moderation"):
+                input = " ".join(message["content"] for message in messages)
+                response = litellm.moderation(input=input)
+                flagged = response["results"][0]["flagged"]
+                if flagged:
+                    raise Exception("This response was flagged as inappropriate")
+
+            ## Model-specific Error Handling
+            error_handling = None
+            if completion_config["model"][model].get("error_handling"):
+                error_handling = completion_config["model"][model]["error_handling"]
+
+            try:
+                response = litellm.completion(*args, **kwargs)
+                return response
+            except Exception as e:
+                exception_name = type(e).__name__
+                fallback_model = None
+                if error_handling and exception_name in error_handling:
+                    error_handler = error_handling[exception_name]
+                    # either switch model or api key
+                    fallback_model = error_handler.get("fallback_model", None)
+                if fallback_model:
+                    kwargs["model"] = fallback_model
+                    return litellm.completion(*args, **kwargs)
+                raise e
+        else:
+            return litellm.completion(*args, **kwargs)
+    except Exception as e:
+        if fallback_models:
+            model = fallback_models.pop(0)
+            return completion_with_fallbacks(model=model, messages=messages, fallbacks=fallback_models)
+        raise e
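
A note on the new control flow: the outer try/except wraps both the configured and unconfigured paths, so any failure — including a model-specific fallback that itself errors, or a moderation flag — falls through to the default fallback list. Also worth noting: completion_config.get("default_fallback_models") returns a reference to the list inside the caller's config dict, and fallback_models.pop(0) mutates it, so a config object reused across calls will have already consumed its first fallback.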
@@ -2924,8 +2933,7 @@ def completion_with_fallbacks(**kwargs):
             # delete model from kwargs if it exists
             if kwargs.get("model"):
                 del kwargs["model"]
 
-            print("making completion call", model)
             response = litellm.completion(**kwargs, model=model)
 
             if response != None:
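
For orientation, a simplified sketch of the loop this last hunk sits inside, reconstructed only from the visible context — the real completion_with_fallbacks does more (the full function body is not shown in this commit), so treat the structure and names below as assumptions.

import litellm

def completion_with_fallbacks_sketch(fallbacks, **kwargs):
    # Try each candidate model in order; return the first real response.
    for model in fallbacks:
        try:
            # delete model from kwargs if it exists (as in the hunk above)
            kwargs.pop("model", None)
            response = litellm.completion(**kwargs, model=model)
            if response is not None:
                return response
        except Exception:
            continue  # move on to the next fallback model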