fix(router.py): check for fallbacks in completion params for router

This commit is contained in:
Krrish Dholakia 2023-11-25 18:46:45 -08:00
parent 5c22108868
commit 67fe8824b3
2 changed files with 57 additions and 16 deletions

View file

@ -272,6 +272,8 @@ class Router:
If it fails after num_retries, fall back to another model group If it fails after num_retries, fall back to another model group
""" """
model_group = kwargs.get("model") model_group = kwargs.get("model")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
try: try:
response = await self.async_function_with_retries(*args, **kwargs) response = await self.async_function_with_retries(*args, **kwargs)
self.print_verbose(f'Async Response: {response}') self.print_verbose(f'Async Response: {response}')
@ -281,9 +283,9 @@ class Router:
original_exception = e original_exception = e
try: try:
self.print_verbose(f"Trying to fallback b/w models") self.print_verbose(f"Trying to fallback b/w models")
if isinstance(e, litellm.ContextWindowExceededError) and self.context_window_fallbacks is not None: if isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None:
fallback_model_group = None fallback_model_group = None
for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] for item in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
if list(item.keys())[0] == model_group: if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group] fallback_model_group = item[model_group]
break break
@ -301,9 +303,9 @@ class Router:
return response return response
except Exception as e: except Exception as e:
pass pass
elif self.fallbacks is not None: elif fallbacks is not None:
self.print_verbose(f"inside model fallbacks: {self.fallbacks}") self.print_verbose(f"inside model fallbacks: {fallbacks}")
for item in self.fallbacks: for item in fallbacks:
if list(item.keys())[0] == model_group: if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group] fallback_model_group = item[model_group]
break break
@ -365,7 +367,8 @@ class Router:
If it fails after num_retries, fall back to another model group If it fails after num_retries, fall back to another model group
""" """
model_group = kwargs.get("model") model_group = kwargs.get("model")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
try: try:
response = self.function_with_retries(*args, **kwargs) response = self.function_with_retries(*args, **kwargs)
return response return response
@ -374,11 +377,11 @@ class Router:
self.print_verbose(f"An exception occurs {original_exception}") self.print_verbose(f"An exception occurs {original_exception}")
try: try:
self.print_verbose(f"Trying to fallback b/w models. Initial model group: {model_group}") self.print_verbose(f"Trying to fallback b/w models. Initial model group: {model_group}")
if isinstance(e, litellm.ContextWindowExceededError) and self.context_window_fallbacks is not None: if isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None:
self.print_verbose(f"inside context window fallbacks: {self.context_window_fallbacks}") self.print_verbose(f"inside context window fallbacks: {context_window_fallbacks}")
fallback_model_group = None fallback_model_group = None
for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] for item in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
if list(item.keys())[0] == model_group: if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group] fallback_model_group = item[model_group]
break break
@ -396,10 +399,10 @@ class Router:
return response return response
except Exception as e: except Exception as e:
pass pass
elif self.fallbacks is not None: elif fallbacks is not None:
self.print_verbose(f"inside model fallbacks: {self.fallbacks}") self.print_verbose(f"inside model fallbacks: {fallbacks}")
fallback_model_group = None fallback_model_group = None
for item in self.fallbacks: for item in fallbacks:
if list(item.keys())[0] == model_group: if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group] fallback_model_group = item[model_group]
break break

View file

@ -82,7 +82,7 @@ def test_sync_fallbacks():
router.flush_cache() router.flush_cache()
except Exception as e: except Exception as e:
print(e) print(e)
test_sync_fallbacks() # test_sync_fallbacks()
def test_async_fallbacks(): def test_async_fallbacks():
litellm.set_verbose = False litellm.set_verbose = False
@ -101,7 +101,7 @@ def test_async_fallbacks():
asyncio.run(test_get_response()) asyncio.run(test_get_response())
test_async_fallbacks() # test_async_fallbacks()
def test_sync_context_window_fallbacks(): def test_sync_context_window_fallbacks():
try: try:
@ -110,8 +110,46 @@ def test_sync_context_window_fallbacks():
kwargs["messages"] = [{"role": "user", "content": sample_text}] kwargs["messages"] = [{"role": "user", "content": sample_text}]
response = router.completion(**kwargs) response = router.completion(**kwargs)
print(f"response: {response}") print(f"response: {response}")
router.flush_cache() router.reset()
except Exception as e: except Exception as e:
print(e) print(e)
# test_sync_context_window_fallbacks() # test_sync_context_window_fallbacks()
def test_dynamic_fallbacks_sync():
    """
    Check that fallbacks can be passed per request to router.completion().

    The Router is constructed without a `fallbacks` argument; the fallback
    mapping is supplied in the completion kwargs instead. Any exception
    raised along the way fails the test.
    """
    try:
        router = Router(model_list=model_list, set_verbose=True)
        request = {
            "model": "azure/gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
            # Per-call fallback mapping: model group -> list of fallback groups.
            "fallbacks": [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
        }
        response = router.completion(**request)
        print(f"response: {response}")
        router.reset()
    except Exception as e:
        pytest.fail(f"An exception occurred - {e}")
# test_dynamic_fallbacks_sync()
def test_dynamic_fallbacks_async():
    """
    Check that fallbacks can be passed per request to router.acompletion().

    The Router is constructed without a `fallbacks` argument; the fallback
    mapping is supplied in the acompletion kwargs instead. The coroutine is
    driven with asyncio.run(), and any exception raised inside it fails the
    test.
    """

    async def _call_router():
        try:
            router = Router(model_list=model_list, set_verbose=True)
            request = {
                "model": "azure/gpt-3.5-turbo",
                "messages": [{"role": "user", "content": "Hey, how's it going?"}],
                # Per-call fallback mapping: model group -> list of fallback groups.
                "fallbacks": [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
            }
            response = await router.acompletion(**request)
            print(f"response: {response}")
            router.reset()
        except Exception as e:
            pytest.fail(f"An exception occurred - {e}")

    asyncio.run(_call_router())
# test_dynamic_fallbacks_async()