mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(user_api_key_auth.py): respect team budgets over user budget, if key belongs to team
Closes https://github.com/BerriAI/litellm/issues/5097
This commit is contained in:
parent
f579aef740
commit
d832327ccf
3 changed files with 92 additions and 110 deletions
|
@ -55,11 +55,11 @@ def common_checks(
|
||||||
1. If team is blocked
|
1. If team is blocked
|
||||||
2. If team can call model
|
2. If team can call model
|
||||||
3. If team is in budget
|
3. If team is in budget
|
||||||
5. If user passed in (JWT or key.user_id) - is in budget
|
4. If user passed in (JWT or key.user_id) - is in budget
|
||||||
4. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
|
5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
|
6. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
|
||||||
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
||||||
7. [OPTIONAL] If guardrails modified - is request allowed to change this
|
8. [OPTIONAL] If guardrails modified - is request allowed to change this
|
||||||
"""
|
"""
|
||||||
_model = request_body.get("model", None)
|
_model = request_body.get("model", None)
|
||||||
if team_object is not None and team_object.blocked is True:
|
if team_object is not None and team_object.blocked is True:
|
||||||
|
@ -91,12 +91,19 @@ def common_checks(
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
|
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
|
||||||
)
|
)
|
||||||
if user_object is not None and user_object.max_budget is not None:
|
# 4. If user is in budget
|
||||||
|
## 4.1 check personal budget, if personal key
|
||||||
|
if (
|
||||||
|
(team_object is None or team_object.team_id is None)
|
||||||
|
and user_object is not None
|
||||||
|
and user_object.max_budget is not None
|
||||||
|
):
|
||||||
user_budget = user_object.max_budget
|
user_budget = user_object.max_budget
|
||||||
if user_budget > user_object.spend:
|
if user_budget < user_object.spend:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}"
|
f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}"
|
||||||
)
|
)
|
||||||
|
## 4.2 check team member budget, if team key
|
||||||
# 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
# 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
|
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
|
||||||
end_user_budget = end_user_object.litellm_budget_table.max_budget
|
end_user_budget = end_user_object.litellm_budget_table.max_budget
|
||||||
|
|
|
@ -552,6 +552,7 @@ async def user_api_key_auth(
|
||||||
key=api_key
|
key=api_key
|
||||||
)
|
)
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
|
user_obj: Optional[LiteLLM_UserTable] = None
|
||||||
## check db
|
## check db
|
||||||
verbose_proxy_logger.debug("api key: %s", api_key)
|
verbose_proxy_logger.debug("api key: %s", api_key)
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
|
@ -650,114 +651,26 @@ async def user_api_key_auth(
|
||||||
valid_token=valid_token,
|
valid_token=valid_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check 2. If user_id for this token is in budget
|
# Check 2. If user_id for this token is in budget - done in common_checks()
|
||||||
if valid_token.user_id is not None:
|
if valid_token.user_id is not None:
|
||||||
user_id_list = [valid_token.user_id]
|
user_obj = await get_user_object(
|
||||||
for id in user_id_list:
|
user_id=valid_token.user_id,
|
||||||
value = user_api_key_cache.get_cache(key=id)
|
prisma_client=prisma_client,
|
||||||
if value is not None:
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
user_id_upsert=False,
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
# value = user_api_key_cache.get_cache(key=id)
|
||||||
|
if user_obj is not None:
|
||||||
if user_id_information is None:
|
if user_id_information is None:
|
||||||
user_id_information = []
|
user_id_information = []
|
||||||
user_id_information.append(value)
|
user_id_information.append(user_obj.model_dump())
|
||||||
if user_id_information is None or (
|
|
||||||
isinstance(user_id_information, list)
|
|
||||||
and len(user_id_information) < 1
|
|
||||||
):
|
|
||||||
if prisma_client is not None:
|
|
||||||
user_id_information = await prisma_client.get_data(
|
|
||||||
user_id_list=[
|
|
||||||
valid_token.user_id,
|
|
||||||
],
|
|
||||||
table_name="user",
|
|
||||||
query_type="find_all",
|
|
||||||
)
|
|
||||||
if user_id_information is not None:
|
|
||||||
for _id in user_id_information:
|
|
||||||
await user_api_key_cache.async_set_cache(
|
|
||||||
key=_id["user_id"],
|
|
||||||
value=_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"user_id_information: {user_id_information}"
|
f"user_id_information: {user_id_information}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_id_information is not None:
|
|
||||||
if isinstance(user_id_information, list):
|
|
||||||
## Check if user in budget
|
|
||||||
for _user in user_id_information:
|
|
||||||
if _user is None:
|
|
||||||
continue
|
|
||||||
assert isinstance(_user, dict)
|
|
||||||
# check if user is admin #
|
|
||||||
|
|
||||||
# Token exists, not expired now check if its in budget for the user
|
|
||||||
user_max_budget = _user.get("max_budget", None)
|
|
||||||
user_current_spend = _user.get("spend", None)
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"user_id: {_user.get('user_id', None)}; user_max_budget: {user_max_budget}; user_current_spend: {user_current_spend}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
user_max_budget is not None
|
|
||||||
and user_current_spend is not None
|
|
||||||
):
|
|
||||||
call_info = CallInfo(
|
|
||||||
token=valid_token.token,
|
|
||||||
spend=user_current_spend,
|
|
||||||
max_budget=user_max_budget,
|
|
||||||
user_id=_user.get("user_id", None),
|
|
||||||
user_email=_user.get("user_email", None),
|
|
||||||
key_alias=valid_token.key_alias,
|
|
||||||
)
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.budget_alerts(
|
|
||||||
type="user_budget",
|
|
||||||
user_info=call_info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
_user_id = _user.get("user_id", None)
|
|
||||||
if user_current_spend > user_max_budget:
|
|
||||||
raise litellm.BudgetExceededError(
|
|
||||||
current_cost=user_current_spend,
|
|
||||||
max_budget=user_max_budget,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Token exists, not expired now check if its in budget for the user
|
|
||||||
user_max_budget = getattr(
|
|
||||||
user_id_information, "max_budget", None
|
|
||||||
)
|
|
||||||
user_current_spend = getattr(user_id_information, "spend", None)
|
|
||||||
|
|
||||||
if (
|
|
||||||
user_max_budget is not None
|
|
||||||
and user_current_spend is not None
|
|
||||||
):
|
|
||||||
call_info = CallInfo(
|
|
||||||
token=valid_token.token,
|
|
||||||
spend=user_current_spend,
|
|
||||||
max_budget=user_max_budget,
|
|
||||||
user_id=getattr(user_id_information, "user_id", None),
|
|
||||||
user_email=getattr(
|
|
||||||
user_id_information, "user_email", None
|
|
||||||
),
|
|
||||||
key_alias=valid_token.key_alias,
|
|
||||||
)
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.budget_alerts(
|
|
||||||
type="user_budget",
|
|
||||||
user_info=call_info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_current_spend > user_max_budget:
|
|
||||||
raise litellm.BudgetExceededError(
|
|
||||||
current_cost=user_current_spend,
|
|
||||||
max_budget=user_max_budget,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check 3. Check if user is in their team budget
|
# Check 3. Check if user is in their team budget
|
||||||
if valid_token.team_member_spend is not None:
|
if valid_token.team_member_spend is not None:
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
|
@ -983,7 +896,7 @@ async def user_api_key_auth(
|
||||||
_ = common_checks(
|
_ = common_checks(
|
||||||
request_body=request_data,
|
request_body=request_data,
|
||||||
team_object=_team_obj,
|
team_object=_team_obj,
|
||||||
user_object=None,
|
user_object=user_obj,
|
||||||
end_user_object=_end_user_object,
|
end_user_object=_end_user_object,
|
||||||
general_settings=general_settings,
|
general_settings=general_settings,
|
||||||
global_proxy_spend=global_proxy_spend,
|
global_proxy_spend=global_proxy_spend,
|
||||||
|
|
|
@ -116,3 +116,65 @@ def test_returned_user_api_key_auth(user_role):
|
||||||
assert new_obj.user_role == user_role
|
assert new_obj.user_role == user_role
|
||||||
else:
|
else:
|
||||||
assert new_obj.user_role == "internal_user"
|
assert new_obj.user_role == "internal_user"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("key_ownership", ["user_key", "team_key"])
@pytest.mark.asyncio
async def test_user_personal_budgets(key_ownership):
    """
    Check that a user's personal budget is only enforced for personal keys.

    The cached user is over their personal budget (spend=11, max_budget=10).

    - "user_key": the key has no team -> the auth call is expected to raise
    - "team_key": the key belongs to a team (team_max_budget=100) -> the auth
      call is expected to go through
    """
    import asyncio
    import time

    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy._types import LiteLLM_UserTable, UserAPIKeyAuth
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    _user_id = "1234"
    user_key = "sk-12345678"

    # The two key variants differ only in the team fields they carry.
    if key_ownership == "user_key":
        team_fields = {}
    elif key_ownership == "team_key":
        team_fields = {"team_id": "my-special-team", "team_max_budget": 100}

    valid_token = UserAPIKeyAuth(
        token=hash_token(user_key),
        last_refreshed_at=time.time(),
        user_id=_user_id,
        spend=20,
        **team_fields,
    )
    await asyncio.sleep(1)

    # User whose spend is already above their personal max_budget.
    over_budget_user = LiteLLM_UserTable(
        user_id=_user_id, spend=11, max_budget=10, user_email=""
    )

    # Seed the cache with both the key and the user record.
    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
    user_api_key_cache.set_cache(key="{}".format(_user_id), value=over_budget_user)

    # Patch the proxy server module state used by this test.
    proxy_overrides = {
        "user_api_key_cache": user_api_key_cache,
        "master_key": "sk-1234",
        "prisma_client": "hello-world",
    }
    for attr_name, attr_value in proxy_overrides.items():
        setattr(litellm.proxy.proxy_server, attr_name, attr_value)

    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    # Record only whether the auth call raised; the expectation is checked
    # afterwards so the failure reports are never raised inside the try body.
    try:
        await user_api_key_auth(request=request, api_key="Bearer " + user_key)
    except Exception:
        auth_passed = False
    else:
        auth_passed = True

    if auth_passed:
        if key_ownership == "user_key":
            pytest.fail("Expected this call to fail. User is over limit.")
    elif key_ownership == "team_key":
        pytest.fail("Expected this call to work. Key is below team budget.")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue