fix(router.py): add support for context window fallbacks on router

This commit is contained in:
Krrish Dholakia 2023-11-23 16:41:45 -08:00
parent a1bb880872
commit c273d6f0d6
3 changed files with 65 additions and 104 deletions

View file

@ -280,10 +280,15 @@ class Router:
try:
self.print_verbose(f"Trying to fallback b/w models")
if isinstance(e, litellm.ContextWindowExceededError):
for item in self.context_window_fallback_model_group: # [{"gpt-3.5-turbo": ["gpt-4"]}]
fallback_model_group = None
for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
@ -360,6 +365,7 @@ class Router:
If it fails after num_retries, fall back to another model group
"""
model_group = kwargs.get("model")
try:
response = self.function_with_retries(*args, **kwargs)
self.print_verbose(f'Response: {response}')
@ -368,36 +374,47 @@ class Router:
original_exception = e
self.print_verbose(f"An exception occurs{original_exception}")
try:
self.print_verbose(f"Trying to fallback b/w models")
fallback_model_group = []
self.print_verbose(f"Trying to fallback b/w models. Initial model group: {model_group}")
self.print_verbose(f"Type of exception: {type(e)}; error_message: {str(e)}")
if isinstance(e, litellm.ContextWindowExceededError):
for item in self.context_window_fallback_model_group: # [{"gpt-3.5-turbo": ["gpt-4"]}]
self.print_verbose(f"inside context window fallbacks: {self.context_window_fallbacks}")
fallback_model_group = None
for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
"""
try:
kwargs["model"] = mg
response = self.function_with_retries(*args, **kwargs)
response = self.function_with_fallbacks(*args, **kwargs)
return response
except Exception as e:
pass
else:
self.print_verbose(f"inside model fallbacks: {self.fallbacks}")
fallback_model_group = None
for item in self.fallbacks:
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
"""
try:
kwargs["model"] = mg
response = self.function_with_retries(*args, **kwargs)
response = self.function_with_fallbacks(*args, **kwargs)
return response
except Exception as e:
pass