unit testing for provider budget routing

This commit is contained in:
Ishaan Jaff 2024-11-19 14:34:23 -08:00
parent b3b237a597
commit d14347df95
2 changed files with 87 additions and 4 deletions

View file

@ -72,6 +72,8 @@ class ProviderBudgetLimiting(CustomLogger):
_provider_configs: Dict[str, Optional[ProviderBudgetInfo]] = {}
for deployment in healthy_deployments:
provider = self._get_llm_provider_for_deployment(deployment)
if provider is None:
continue
budget_config = self._get_budget_config_for_provider(provider)
_provider_configs[provider] = budget_config
@ -102,6 +104,8 @@ class ProviderBudgetLimiting(CustomLogger):
# Filter healthy deployments based on budget constraints
for deployment in healthy_deployments:
provider = self._get_llm_provider_for_deployment(deployment)
if provider is None:
continue
budget_config = provider_configs.get(provider)
if not budget_config:
@ -179,17 +183,20 @@ class ProviderBudgetLimiting(CustomLogger):
) -> Optional[ProviderBudgetInfo]:
return self.provider_budget_config.get(provider, None)
def _get_llm_provider_for_deployment(self, deployment: Dict) -> Optional[str]:
    """
    Resolve the LLM provider name for a single deployment.

    Args:
        deployment: Deployment dict. Its "litellm_params" entry, when
            present, is unpacked into ``LiteLLM_Params``; when absent, a
            placeholder with an empty model name is used instead.

    Returns:
        The provider reported by ``litellm.get_llm_provider``, or ``None``
        if the params could not be built or the provider lookup failed.
        Callers are expected to skip deployments that yield ``None``.
    """
    try:
        # Fall back to an empty model so a deployment without
        # "litellm_params" goes through the same error path as any other
        # unresolvable deployment instead of raising KeyError.
        _litellm_params: LiteLLM_Params = LiteLLM_Params(
            **deployment.get("litellm_params", {"model": ""})
        )
        _, custom_llm_provider, _, _ = litellm.get_llm_provider(
            model=_litellm_params.model,
            litellm_params=_litellm_params,
        )
    except Exception:
        # Best effort: a deployment whose provider cannot be determined must
        # not break routing for the remaining deployments.
        # NOTE(review): the whole deployment dict is logged; "litellm_params"
        # can carry credentials such as "api_key" -- consider redacting.
        verbose_router_logger.error(
            f"Error getting LLM provider for deployment: {deployment}"
        )
        return None
    return custom_llm_provider
def get_ttl_seconds(self, time_period: str) -> int:

View file

@ -133,3 +133,79 @@ async def test_provider_budgets_e2e_test_expect_to_fail():
await asyncio.sleep(0.5)
# Verify the error is related to budget exceeded
def test_get_ttl_seconds():
    """
    Check that get_ttl_seconds converts day-based time periods to seconds
    and raises ValueError for a period given in hours.
    """
    limiter = ProviderBudgetLimiting(
        router_cache=DualCache(), provider_budget_config={}
    )

    seconds_per_day = 86400
    expected_ttls = {
        "1d": seconds_per_day,
        "7d": 7 * seconds_per_day,
        "30d": 30 * seconds_per_day,
    }
    for time_period, ttl in expected_ttls.items():
        assert limiter.get_ttl_seconds(time_period) == ttl

    # "1h" is expected to be rejected with the message matched below.
    with pytest.raises(ValueError, match="Unsupported time period format"):
        limiter.get_ttl_seconds("1h")
def test_get_llm_provider_for_deployment():
    """
    Check that _get_llm_provider_for_deployment returns the provider for
    OpenAI and Azure deployments and None for an empty deployment.
    """
    limiter = ProviderBudgetLimiting(
        router_cache=DualCache(), provider_budget_config={}
    )
    resolve_provider = limiter._get_llm_provider_for_deployment

    # OpenAI: only the model name is supplied.
    assert resolve_provider({"litellm_params": {"model": "openai/gpt-4"}}) == "openai"

    # Azure: model plus key and base URL.
    azure_params = {
        "model": "azure/gpt-4",
        "api_key": "test",
        "api_base": "test",
    }
    assert resolve_provider({"litellm_params": azure_params}) == "azure"

    # A deployment without any params must yield None, not an exception.
    assert resolve_provider({}) is None
def test_get_budget_config_for_provider():
    """
    Check that _get_budget_config_for_provider returns the configured
    budget for known providers and None for an unconfigured one.
    """
    # provider -> (time_period, budget_limit)
    expected_budgets = {
        "openai": ("1d", 100),
        "anthropic": ("7d", 500),
    }
    budget_config = {
        provider: ProviderBudgetInfo(time_period=period, budget_limit=limit)
        for provider, (period, limit) in expected_budgets.items()
    }
    limiter = ProviderBudgetLimiting(
        router_cache=DualCache(), provider_budget_config=budget_config
    )

    # Each configured provider returns its own settings.
    for provider, (period, limit) in expected_budgets.items():
        found = limiter._get_budget_config_for_provider(provider)
        assert found is not None
        assert found.time_period == period
        assert found.budget_limit == limit

    # A provider missing from the config has no budget.
    assert limiter._get_budget_config_for_provider("unknown") is None