forked from phoenix/litellm-mirror
unit testing for provider budget routing
This commit is contained in:
parent
b3b237a597
commit
d14347df95
2 changed files with 87 additions and 4 deletions
|
@ -72,6 +72,8 @@ class ProviderBudgetLimiting(CustomLogger):
|
|||
_provider_configs: Dict[str, Optional[ProviderBudgetInfo]] = {}
|
||||
for deployment in healthy_deployments:
|
||||
provider = self._get_llm_provider_for_deployment(deployment)
|
||||
if provider is None:
|
||||
continue
|
||||
budget_config = self._get_budget_config_for_provider(provider)
|
||||
_provider_configs[provider] = budget_config
|
||||
|
||||
|
@ -102,6 +104,8 @@ class ProviderBudgetLimiting(CustomLogger):
|
|||
# Filter healthy deployments based on budget constraints
|
||||
for deployment in healthy_deployments:
|
||||
provider = self._get_llm_provider_for_deployment(deployment)
|
||||
if provider is None:
|
||||
continue
|
||||
budget_config = provider_configs.get(provider)
|
||||
|
||||
if not budget_config:
|
||||
|
@ -179,17 +183,20 @@ class ProviderBudgetLimiting(CustomLogger):
|
|||
) -> Optional[ProviderBudgetInfo]:
|
||||
return self.provider_budget_config.get(provider, None)
|
||||
|
||||
def _get_llm_provider_for_deployment(self, deployment: Dict) -> Optional[str]:
    """
    Resolve the LLM provider name for a single deployment.

    Builds ``LiteLLM_Params`` from the deployment's ``"litellm_params"``
    entry (falling back to an empty model name when the key is absent) and
    asks ``litellm.get_llm_provider`` for the provider.

    Args:
        deployment: Deployment dict; expected to carry a ``"litellm_params"``
            mapping.

    Returns:
        The provider name, or ``None`` when it could not be determined. Any
        failure is logged instead of raised, so callers can skip the
        deployment by checking for ``None``.
    """
    try:
        _litellm_params: LiteLLM_Params = LiteLLM_Params(
            **deployment.get("litellm_params", {"model": ""})
        )
        _, custom_llm_provider, _, _ = litellm.get_llm_provider(
            model=_litellm_params.model,
            litellm_params=_litellm_params,
        )
    except Exception as e:
        # Best-effort lookup: swallow the failure, but keep the cause in the
        # log so a misconfigured deployment can actually be diagnosed.
        # NOTE(review): the whole deployment dict is logged and may include
        # credentials such as api_key - confirm this is acceptable.
        verbose_router_logger.error(
            f"Error getting LLM provider for deployment: {deployment}. Error: {e}"
        )
        return None
    return custom_llm_provider
|
||||
|
||||
def get_ttl_seconds(self, time_period: str) -> int:
|
||||
|
|
|
@ -133,3 +133,79 @@ async def test_provider_budgets_e2e_test_expect_to_fail():
|
|||
|
||||
await asyncio.sleep(0.5)
|
||||
# Verify the error is related to budget exceeded
|
||||
|
||||
|
||||
def test_get_ttl_seconds():
    """
    Test the get_ttl_seconds helper method

    Day-based periods map to their length in seconds; an hour-based period
    is rejected with a ValueError.
    """
    limiter = ProviderBudgetLimiting(
        router_cache=DualCache(), provider_budget_config={}
    )

    # time period -> expected TTL in seconds
    expected_ttls = {
        "1d": 86400,  # 1 day
        "7d": 604800,  # 7 days
        "30d": 2592000,  # 30 days
    }
    for time_period, ttl_seconds in expected_ttls.items():
        assert limiter.get_ttl_seconds(time_period) == ttl_seconds

    with pytest.raises(ValueError, match="Unsupported time period format"):
        limiter.get_ttl_seconds("1h")
|
||||
|
||||
|
||||
def test_get_llm_provider_for_deployment():
    """
    Test the _get_llm_provider_for_deployment helper method

    Covers an OpenAI deployment, an Azure deployment, and an empty
    deployment dict, which must yield None instead of raising.
    """
    limiter = ProviderBudgetLimiting(
        router_cache=DualCache(), provider_budget_config={}
    )
    resolve_provider = limiter._get_llm_provider_for_deployment

    # OpenAI deployment
    assert resolve_provider({"litellm_params": {"model": "openai/gpt-4"}}) == "openai"

    # Azure deployment
    azure_params = {
        "model": "azure/gpt-4",
        "api_key": "test",
        "api_base": "test",
    }
    assert resolve_provider({"litellm_params": azure_params}) == "azure"

    # Unknown (empty) deployment: no error, just None
    assert resolve_provider({}) is None
|
||||
|
||||
|
||||
def test_get_budget_config_for_provider():
    """
    Test the _get_budget_config_for_provider helper method

    Configured providers return their budget info; a provider missing from
    the config returns None.
    """
    # provider -> (time_period, budget_limit)
    expected = {
        "openai": ("1d", 100),
        "anthropic": ("7d", 500),
    }

    limiter = ProviderBudgetLimiting(
        router_cache=DualCache(),
        provider_budget_config={
            name: ProviderBudgetInfo(time_period=period, budget_limit=limit)
            for name, (period, limit) in expected.items()
        },
    )

    # Existing providers
    for name, (period, limit) in expected.items():
        found = limiter._get_budget_config_for_provider(name)
        assert found is not None
        assert found.time_period == period
        assert found.budget_limit == limit

    # Non-existent provider
    assert limiter._get_budget_config_for_provider("unknown") is None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue