diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
index fa3c175d7..65a6d204d 100644
--- a/litellm/tests/test_router_fallbacks.py
+++ b/litellm/tests/test_router_fallbacks.py
@@ -701,86 +701,98 @@ async def test_async_fallbacks_max_retries_per_request():
 def test_usage_based_routing_fallbacks():
-    import os
-    import litellm
-    from litellm import Router
-    from dotenv import load_dotenv
+    try:
+        # [Prod Test]
+        # It tests usage-based routing with fallbacks
+        # The request should fail on azure/gpt-4-fast, then fall back -> "azure/gpt-4-basic" -> "openai-gpt-4"
+        # It should work with "openai-gpt-4"
+        import os
+        import litellm
+        from litellm import Router
+        from dotenv import load_dotenv

-    load_dotenv()
+        load_dotenv()

-    # Constants for TPM and RPM allocation
-    AZURE_FAST_TPM = 3
-    AZURE_BASIC_TPM = 4
-    OPENAI_TPM = 2000
-    ANTHROPIC_TPM = 100000
+        # Constants for TPM and RPM allocation
+        AZURE_FAST_TPM = 3
+        AZURE_BASIC_TPM = 4
+        OPENAI_TPM = 2000
+        ANTHROPIC_TPM = 100000

-    def get_azure_params(deployment_name: str):
-        params = {
-            "model": f"azure/{deployment_name}",
-            "api_key": os.environ["AZURE_API_KEY"],
-            "api_version": os.environ["AZURE_API_VERSION"],
-            "api_base": os.environ["AZURE_API_BASE"],
-        }
-        return params
+        def get_azure_params(deployment_name: str):
+            params = {
+                "model": f"azure/{deployment_name}",
+                "api_key": os.environ["AZURE_API_KEY"],
+                "api_version": os.environ["AZURE_API_VERSION"],
+                "api_base": os.environ["AZURE_API_BASE"],
+            }
+            return params

-    def get_openai_params(model: str):
-        params = {
-            "model": model,
-            "api_key": os.environ["OPENAI_API_KEY"],
-        }
-        return params
+        def get_openai_params(model: str):
+            params = {
+                "model": model,
+                "api_key": os.environ["OPENAI_API_KEY"],
+            }
+            return params

-    def get_anthropic_params(model: str):
-        params = {
-            "model": model,
-            "api_key": os.environ["ANTHROPIC_API_KEY"],
-        }
-        return params
+        def get_anthropic_params(model: str):
+            params = {
+                "model": model,
+                "api_key": os.environ["ANTHROPIC_API_KEY"],
+            }
+            return params

-    model_list = [
-        {
-            "model_name": "azure/gpt-4-fast",
-            "litellm_params": get_azure_params("chatgpt-v-2"),
-            "tpm": AZURE_FAST_TPM,
-        },
-        {
-            "model_name": "azure/gpt-4-basic",
-            "litellm_params": get_azure_params("chatgpt-v-2"),
-            "tpm": AZURE_BASIC_TPM,
-        },
-        {
-            "model_name": "openai-gpt-4",
-            "litellm_params": get_openai_params("gpt-3.5-turbo"),
-            "tpm": OPENAI_TPM,
-        },
-        {
-            "model_name": "anthropic-claude-instant-1.2",
-            "litellm_params": get_anthropic_params("claude-instant-1.2"),
-            "tpm": ANTHROPIC_TPM,
-        },
-    ]
-    # litellm.set_verbose=True
-    fallbacks_list = [
-        {"azure/gpt-4-fast": ["azure/gpt-4-basic"]},
-        {"azure/gpt-4-basic": ["openai-gpt-4"]},
-        {"openai-gpt-4": ["anthropic-claude-instant-1.2"]},
-    ]
+        model_list = [
+            {
+                "model_name": "azure/gpt-4-fast",
+                "litellm_params": get_azure_params("chatgpt-v-2"),
+                "tpm": AZURE_FAST_TPM,
+            },
+            {
+                "model_name": "azure/gpt-4-basic",
+                "litellm_params": get_azure_params("chatgpt-v-2"),
+                "tpm": AZURE_BASIC_TPM,
+            },
+            {
+                "model_name": "openai-gpt-4",
+                "litellm_params": get_openai_params("gpt-3.5-turbo"),
+                "tpm": OPENAI_TPM,
+            },
+            {
+                "model_name": "anthropic-claude-instant-1.2",
+                "litellm_params": get_anthropic_params("claude-instant-1.2"),
+                "tpm": ANTHROPIC_TPM,
+            },
+        ]
+        # litellm.set_verbose=True
+        fallbacks_list = [
+            {"azure/gpt-4-fast": ["azure/gpt-4-basic"]},
+            {"azure/gpt-4-basic": ["openai-gpt-4"]},
+            {"openai-gpt-4": ["anthropic-claude-instant-1.2"]},
+        ]

-    router = Router(
-        model_list=model_list,
-        fallbacks=fallbacks_list,
-        set_verbose=True,
-        routing_strategy="usage-based-routing",
-        redis_host=os.environ["REDIS_HOST"],
-        redis_port=os.environ["REDIS_PORT"],
-    )
+        router = Router(
+            model_list=model_list,
+            fallbacks=fallbacks_list,
+            set_verbose=True,
+            routing_strategy="usage-based-routing",
+            redis_host=os.environ["REDIS_HOST"],
+            redis_port=os.environ["REDIS_PORT"],
+        )

-    messages = [
-        {"content": "Tell me a joke.", "role": "user"},
-    ]
+        messages = [
+            {"content": "Tell me a joke.", "role": "user"},
+        ]

-    response = router.completion(
-        model="azure/gpt-4-fast", messages=messages, n=10, timeout=5
-    )
+        response = router.completion(
+            model="azure/gpt-4-fast", messages=messages, timeout=5
+        )
+        print("response: ", response)
+        print("response._hidden_params: ", response._hidden_params)

-    print("response: ", response)
+        # in this test, we expect azure/gpt-4-fast to fail, then azure/gpt-4-basic to fail, and then openai-gpt-4 to pass
+        # the token count of this message is > AZURE_FAST_TPM and > AZURE_BASIC_TPM
+        assert response._hidden_params["custom_llm_provider"] == "openai"
+
+    except Exception as e:
+        pytest.fail(f"An exception occurred {e}")
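
For readers skimming the patch, a minimal standalone sketch (not part of the patch) of the fallback chain this test exercises. It only reuses the Router arguments already shown above; the environment variables, deployment names, and TPM values are the ones from the test and are illustrative, not a definitive setup.

# Hedged sketch of the usage-based fallback chain the test relies on.
# Assumes the same env vars the test reads; not a definitive implementation.
import os
from litellm import Router

azure_params = {
    "model": "azure/chatgpt-v-2",
    "api_key": os.environ["AZURE_API_KEY"],
    "api_version": os.environ["AZURE_API_VERSION"],
    "api_base": os.environ["AZURE_API_BASE"],
}

model_list = [
    # Deliberately tiny TPM budgets on the Azure groups force the cascade.
    {"model_name": "azure/gpt-4-fast", "litellm_params": azure_params, "tpm": 3},
    {"model_name": "azure/gpt-4-basic", "litellm_params": azure_params, "tpm": 4},
    # Large enough budget that the request should land here.
    {"model_name": "openai-gpt-4",
     "litellm_params": {"model": "gpt-3.5-turbo", "api_key": os.environ["OPENAI_API_KEY"]},
     "tpm": 2000},
]

router = Router(
    model_list=model_list,
    fallbacks=[
        {"azure/gpt-4-fast": ["azure/gpt-4-basic"]},
        {"azure/gpt-4-basic": ["openai-gpt-4"]},
    ],
    routing_strategy="usage-based-routing",
    redis_host=os.environ["REDIS_HOST"],  # usage counters are shared via Redis
    redis_port=os.environ["REDIS_PORT"],
)

response = router.completion(
    model="azure/gpt-4-fast",
    messages=[{"content": "Tell me a joke.", "role": "user"}],
    timeout=5,
)
# The prompt's token count exceeds both Azure TPM budgets, so the request
# should cascade and be served by the OpenAI deployment.
assert response._hidden_params["custom_llm_provider"] == "openai"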