fix(router.py): check for fallbacks in completion params for router

This commit is contained in:
Krrish Dholakia 2023-11-25 18:46:45 -08:00
parent 793d3ecf81
commit fa713abfc3
2 changed files with 57 additions and 16 deletions

View file

@ -82,7 +82,7 @@ def test_sync_fallbacks():
router.flush_cache()
except Exception as e:
print(e)
test_sync_fallbacks()
# test_sync_fallbacks()
def test_async_fallbacks():
litellm.set_verbose = False
@ -101,7 +101,7 @@ def test_async_fallbacks():
asyncio.run(test_get_response())
test_async_fallbacks()
# test_async_fallbacks()
def test_sync_context_window_fallbacks():
try:
@ -110,8 +110,46 @@ def test_sync_context_window_fallbacks():
kwargs["messages"] = [{"role": "user", "content": sample_text}]
response = router.completion(**kwargs)
print(f"response: {response}")
router.flush_cache()
router.reset()
except Exception as e:
print(e)
# test_sync_context_window_fallbacks()
# test_sync_context_window_fallbacks()
def test_dynamic_fallbacks_sync():
"""
Allow setting the fallback in the router.completion() call.
"""
try:
router = Router(model_list=model_list, set_verbose=True)
kwargs = {}
kwargs["model"] = "azure/gpt-3.5-turbo"
kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}]
response = router.completion(**kwargs)
print(f"response: {response}")
router.reset()
except Exception as e:
pytest.fail(f"An exception occurred - {e}")
# test_dynamic_fallbacks_sync()
def test_dynamic_fallbacks_async():
"""
Allow setting the fallback in the router.completion() call.
"""
async def test_get_response():
try:
router = Router(model_list=model_list, set_verbose=True)
kwargs = {}
kwargs["model"] = "azure/gpt-3.5-turbo"
kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}]
response = await router.acompletion(**kwargs)
print(f"response: {response}")
router.reset()
except Exception as e:
pytest.fail(f"An exception occurred - {e}")
asyncio.run(test_get_response())
# test_dynamic_fallbacks_async()