From d14347df95f3334f0544cca941ae4f0b6df180d7 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 19 Nov 2024 14:34:23 -0800 Subject: [PATCH] unit testing for provider budget routing --- litellm/router_strategy/provider_budgets.py | 15 ++-- tests/local_testing/test_provider_budgets.py | 76 ++++++++++++++++++++ 2 files changed, 87 insertions(+), 4 deletions(-) diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index c5c6b36fa..de8eda19c 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -72,6 +72,8 @@ class ProviderBudgetLimiting(CustomLogger): _provider_configs: Dict[str, Optional[ProviderBudgetInfo]] = {} for deployment in healthy_deployments: provider = self._get_llm_provider_for_deployment(deployment) + if provider is None: + continue budget_config = self._get_budget_config_for_provider(provider) _provider_configs[provider] = budget_config @@ -102,6 +104,8 @@ class ProviderBudgetLimiting(CustomLogger): # Filter healthy deployments based on budget constraints for deployment in healthy_deployments: provider = self._get_llm_provider_for_deployment(deployment) + if provider is None: + continue budget_config = provider_configs.get(provider) if not budget_config: @@ -179,17 +183,20 @@ class ProviderBudgetLimiting(CustomLogger): ) -> Optional[ProviderBudgetInfo]: return self.provider_budget_config.get(provider, None) - def _get_llm_provider_for_deployment(self, deployment: Dict) -> str: + def _get_llm_provider_for_deployment(self, deployment: Dict) -> Optional[str]: try: _litellm_params: LiteLLM_Params = LiteLLM_Params( - **deployment["litellm_params"] + **deployment.get("litellm_params", {"model": ""}) ) _, custom_llm_provider, _, _ = litellm.get_llm_provider( model=_litellm_params.model, litellm_params=_litellm_params, ) - except Exception as e: - raise e + except Exception: + verbose_router_logger.error( + f"Error getting LLM provider for deployment: {deployment}" + ) + return None return custom_llm_provider def get_ttl_seconds(self, time_period: str) -> int: diff --git a/tests/local_testing/test_provider_budgets.py b/tests/local_testing/test_provider_budgets.py index 17371da54..2622a79e0 100644 --- a/tests/local_testing/test_provider_budgets.py +++ b/tests/local_testing/test_provider_budgets.py @@ -133,3 +133,79 @@ async def test_provider_budgets_e2e_test_expect_to_fail(): await asyncio.sleep(0.5) # Verify the error is related to budget exceeded + + +def test_get_ttl_seconds(): + """ + Test the get_ttl_seconds helper method" + + """ + provider_budget = ProviderBudgetLimiting( + router_cache=DualCache(), provider_budget_config={} + ) + + assert provider_budget.get_ttl_seconds("1d") == 86400 # 1 day in seconds + assert provider_budget.get_ttl_seconds("7d") == 604800 # 7 days in seconds + assert provider_budget.get_ttl_seconds("30d") == 2592000 # 30 days in seconds + + with pytest.raises(ValueError, match="Unsupported time period format"): + provider_budget.get_ttl_seconds("1h") + + +def test_get_llm_provider_for_deployment(): + """ + Test the _get_llm_provider_for_deployment helper method + + """ + provider_budget = ProviderBudgetLimiting( + router_cache=DualCache(), provider_budget_config={} + ) + + # Test OpenAI deployment + openai_deployment = {"litellm_params": {"model": "openai/gpt-4"}} + assert ( + provider_budget._get_llm_provider_for_deployment(openai_deployment) == "openai" + ) + + # Test Azure deployment + azure_deployment = { + "litellm_params": { + "model": "azure/gpt-4", + "api_key": "test", + "api_base": "test", + } + } + assert provider_budget._get_llm_provider_for_deployment(azure_deployment) == "azure" + + # should not raise error for unknown deployment + unknown_deployment = {} + assert provider_budget._get_llm_provider_for_deployment(unknown_deployment) is None + + +def test_get_budget_config_for_provider(): + """ + Test the _get_budget_config_for_provider helper method + + """ + config = { + "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100), + "anthropic": ProviderBudgetInfo(time_period="7d", budget_limit=500), + } + + provider_budget = ProviderBudgetLimiting( + router_cache=DualCache(), provider_budget_config=config + ) + + # Test existing providers + openai_config = provider_budget._get_budget_config_for_provider("openai") + assert openai_config is not None + assert openai_config.time_period == "1d" + assert openai_config.budget_limit == 100 + + anthropic_config = provider_budget._get_budget_config_for_provider("anthropic") + assert anthropic_config is not None + assert anthropic_config.time_period == "7d" + assert anthropic_config.budget_limit == 500 + + # Test non-existent provider + assert provider_budget._get_budget_config_for_provider("unknown") is None