fix: fix tests

This commit is contained in:
Krrish Dholakia 2024-08-07 15:02:04 -07:00
parent d832327ccf
commit ff373663a3
4 changed files with 24 additions and 14 deletions

View file

@ -730,10 +730,15 @@ LITELLM_EXCEPTION_TYPES = [
class BudgetExceededError(Exception): class BudgetExceededError(Exception):
def __init__(self, current_cost, max_budget): def __init__(
self, current_cost: float, max_budget: float, message: Optional[str] = None
):
self.current_cost = current_cost self.current_cost = current_cost
self.max_budget = max_budget self.max_budget = max_budget
message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}" message = (
message
or f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
)
self.message = message self.message = message
super().__init__(message) super().__init__(message)

View file

@ -88,8 +88,10 @@ def common_checks(
and team_object.spend is not None and team_object.spend is not None
and team_object.spend > team_object.max_budget and team_object.spend > team_object.max_budget
): ):
raise Exception( raise litellm.BudgetExceededError(
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}" current_cost=team_object.spend,
max_budget=team_object.max_budget,
message=f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}",
) )
# 4. If user is in budget # 4. If user is in budget
## 4.1 check personal budget, if personal key ## 4.1 check personal budget, if personal key
@ -100,16 +102,20 @@ def common_checks(
): ):
user_budget = user_object.max_budget user_budget = user_object.max_budget
if user_budget < user_object.spend: if user_budget < user_object.spend:
raise Exception( raise litellm.BudgetExceededError(
f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}" current_cost=user_object.spend,
max_budget=user_budget,
message=f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}",
) )
## 4.2 check team member budget, if team key ## 4.2 check team member budget, if team key
# 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
if end_user_object is not None and end_user_object.litellm_budget_table is not None: if end_user_object is not None and end_user_object.litellm_budget_table is not None:
end_user_budget = end_user_object.litellm_budget_table.max_budget end_user_budget = end_user_object.litellm_budget_table.max_budget
if end_user_budget is not None and end_user_object.spend > end_user_budget: if end_user_budget is not None and end_user_object.spend > end_user_budget:
raise Exception( raise litellm.BudgetExceededError(
f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}" current_cost=end_user_object.spend,
max_budget=end_user_budget,
message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
) )
# 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
if ( if (

View file

@ -552,7 +552,6 @@ async def user_api_key_auth(
key=api_key key=api_key
) )
if valid_token is None: if valid_token is None:
user_obj: Optional[LiteLLM_UserTable] = None
## check db ## check db
verbose_proxy_logger.debug("api key: %s", api_key) verbose_proxy_logger.debug("api key: %s", api_key)
if prisma_client is not None: if prisma_client is not None:
@ -584,6 +583,7 @@ async def user_api_key_auth(
user_id_information: Optional[List] = None user_id_information: Optional[List] = None
if valid_token is not None: if valid_token is not None:
user_obj: Optional[LiteLLM_UserTable] = None
# Got Valid Token from Cache, DB # Got Valid Token from Cache, DB
# Run checks for # Run checks for
# 1. If token can call model # 1. If token can call model

View file

@ -504,7 +504,7 @@ def test_call_with_user_over_budget(prisma_client):
asyncio.run(test()) asyncio.run(test())
except Exception as e: except Exception as e:
error_detail = e.message error_detail = e.message
assert "Budget has been exceeded" in error_detail assert "ExceededBudget:" in error_detail
assert isinstance(e, ProxyException) assert isinstance(e, ProxyException)
assert e.type == ProxyErrorTypes.budget_exceeded assert e.type == ProxyErrorTypes.budget_exceeded
print(vars(e)) print(vars(e))
@ -607,7 +607,7 @@ def test_call_with_end_user_over_budget(prisma_client):
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)
pytest.fail(f"This should have failed!. They key crossed it's budget") pytest.fail("This should have failed!. They key crossed it's budget")
asyncio.run(test()) asyncio.run(test())
except Exception as e: except Exception as e:
@ -779,12 +779,12 @@ def test_call_with_user_over_budget_stream(prisma_client):
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)
pytest.fail(f"This should have failed!. They key crossed it's budget") pytest.fail("This should have failed!. They key crossed it's budget")
asyncio.run(test()) asyncio.run(test())
except Exception as e: except Exception as e:
error_detail = e.message error_detail = e.message
assert "Budget has been exceeded" in error_detail assert "ExceededBudget:" in error_detail
assert isinstance(e, ProxyException) assert isinstance(e, ProxyException)
assert e.type == ProxyErrorTypes.budget_exceeded assert e.type == ProxyErrorTypes.budget_exceeded
print(vars(e)) print(vars(e))
@ -2511,7 +2511,6 @@ async def test_update_user_role(prisma_client):
Tests if we update user role, incorrect values are not stored in cache Tests if we update user role, incorrect values are not stored in cache
-> create a user with role == INTERNAL_USER -> create a user with role == INTERNAL_USER
-> access an Admin only route -> expect to fail -> access an Admin only route -> expect to fail
-> update user role to == PROXY_ADMIN -> update user role to == PROXY_ADMIN
-> access an Admin only route -> expect to succeed -> access an Admin only route -> expect to succeed
""" """