fix(router.py): check for fallbacks in completion params for router
parent 793d3ecf81
commit fa713abfc3
2 changed files with 57 additions and 16 deletions
@@ -272,6 +272,8 @@ class Router:
         If it fails after num_retries, fall back to another model group
         """
         model_group = kwargs.get("model")
+        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
+        context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
         try:
             response = await self.async_function_with_retries(*args, **kwargs)
             self.print_verbose(f'Async Response: {response}')
@@ -281,9 +283,9 @@ class Router:
             original_exception = e
             try:
                 self.print_verbose(f"Trying to fallback b/w models")
-                if isinstance(e, litellm.ContextWindowExceededError) and self.context_window_fallbacks is not None:
+                if isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None:
                     fallback_model_group = None
-                    for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                    for item in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
@@ -301,9 +303,9 @@ class Router:
                             return response
                         except Exception as e:
                             pass
-                elif self.fallbacks is not None:
-                    self.print_verbose(f"inside model fallbacks: {self.fallbacks}")
-                    for item in self.fallbacks:
+                elif fallbacks is not None:
+                    self.print_verbose(f"inside model fallbacks: {fallbacks}")
+                    for item in fallbacks:
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
@@ -365,7 +367,8 @@ class Router:
         If it fails after num_retries, fall back to another model group
         """
         model_group = kwargs.get("model")
+        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
+        context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
         try:
             response = self.function_with_retries(*args, **kwargs)
             return response
@@ -374,11 +377,11 @@ class Router:
             self.print_verbose(f"An exception occurs {original_exception}")
             try:
                 self.print_verbose(f"Trying to fallback b/w models. Initial model group: {model_group}")
-                if isinstance(e, litellm.ContextWindowExceededError) and self.context_window_fallbacks is not None:
+                if isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None:
-                    self.print_verbose(f"inside context window fallbacks: {self.context_window_fallbacks}")
+                    self.print_verbose(f"inside context window fallbacks: {context_window_fallbacks}")
                     fallback_model_group = None

-                    for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                    for item in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
@@ -396,10 +399,10 @@ class Router:
                             return response
                         except Exception as e:
                             pass
-                elif self.fallbacks is not None:
-                    self.print_verbose(f"inside model fallbacks: {self.fallbacks}")
+                elif fallbacks is not None:
+                    self.print_verbose(f"inside model fallbacks: {fallbacks}")
                     fallback_model_group = None
-                    for item in self.fallbacks:
+                    for item in fallbacks:
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
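
The router.py hunks above make fallbacks and context_window_fallbacks per-call parameters: completion() and acompletion() pop them from kwargs and fall back to the router-level settings when the caller does not supply them. Below is a minimal usage sketch of that override, mirroring the new tests; model_list is assumed to be defined elsewhere, and the model names and fallback mapping are illustrative, not taken verbatim from router.py.

    from litellm import Router

    router = Router(
        model_list=model_list,                            # assumed to be defined elsewhere
        fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-4"]}],   # router-level default
    )

    # The per-call value is popped from kwargs first, so it takes precedence
    # over the router-level default for this request only.
    response = router.completion(
        model="azure/gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
    )

Because the values are popped from kwargs rather than read with get(), they are consumed by the router and not forwarded to the underlying completion call.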

@@ -82,7 +82,7 @@ def test_sync_fallbacks():
         router.flush_cache()
     except Exception as e:
         print(e)
-test_sync_fallbacks()
+# test_sync_fallbacks()

 def test_async_fallbacks():
     litellm.set_verbose = False
@@ -101,7 +101,7 @@ def test_async_fallbacks():

     asyncio.run(test_get_response())

-test_async_fallbacks()
+# test_async_fallbacks()

 def test_sync_context_window_fallbacks():
     try:
@@ -110,8 +110,46 @@ def test_sync_context_window_fallbacks():
         kwargs["messages"] = [{"role": "user", "content": sample_text}]
         response = router.completion(**kwargs)
         print(f"response: {response}")
-        router.flush_cache()
+        router.reset()
     except Exception as e:
         print(e)

 # test_sync_context_window_fallbacks()
+
+def test_dynamic_fallbacks_sync():
+    """
+    Allow setting the fallback in the router.completion() call.
+    """
+    try:
+        router = Router(model_list=model_list, set_verbose=True)
+        kwargs = {}
+        kwargs["model"] = "azure/gpt-3.5-turbo"
+        kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
+        kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}]
+        response = router.completion(**kwargs)
+        print(f"response: {response}")
+        router.reset()
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {e}")
+
+# test_dynamic_fallbacks_sync()
+
+def test_dynamic_fallbacks_async():
+    """
+    Allow setting the fallback in the router.completion() call.
+    """
+    async def test_get_response():
+        try:
+            router = Router(model_list=model_list, set_verbose=True)
+            kwargs = {}
+            kwargs["model"] = "azure/gpt-3.5-turbo"
+            kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
+            kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}]
+            response = await router.acompletion(**kwargs)
+            print(f"response: {response}")
+            router.reset()
+        except Exception as e:
+            pytest.fail(f"An exception occurred - {e}")
+    asyncio.run(test_get_response())
+
+# test_dynamic_fallbacks_async()