Merge pull request #3790 from BerriAI/litellm_set_team_member_budgets

[Feat] Set Budgets for Users within a Team
This commit is contained in:
Ishaan Jaff 2024-05-22 19:44:04 -07:00 committed by GitHub
commit a8b64a01dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 270 additions and 9 deletions

View file

@ -473,4 +473,75 @@ curl --location 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data '{"models": ["azure-models"], "user_id": "krrish@berri.ai"}'
```
```
## Advanced
### Set Budgets for Members within a Team
Use this when you want to budget a user's spend within a Team
#### Step 1. Create User
Create a user with `user_id=ishaan`
```shell
curl --location 'http://0.0.0.0:4000/user/new' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"user_id": "ishaan"
}'
```
#### Step 2. Add User to an existing Team - set `max_budget_in_team`
Set `max_budget_in_team` when adding a User to a team. We use the same `user_id` we set in Step 1
```shell
curl -X POST 'http://0.0.0.0:4000/team/member_add' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{"team_id": "e8d1460f-846c-45d7-9b43-55f3cc52ac32", "max_budget_in_team": 0.000000000001, "member": {"role": "user", "user_id": "ishaan"}}'
```
#### Step 3. Create a Key for user from Step 1
Set `user_id=ishaan` from step 1
```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"user_id": "ishaan"
}'
```
Response from `/key/generate`
We use the `key` from this response in Step 4
```shell
{"models":[],"spend":0.0,"max_budget":null,"user_id":"ishaan","team_id":null,"max_parallel_requests":null,"metadata":{},"tpm_limit":null,"rpm_limit":null,"budget_duration":null,"allowed_cache_controls":[],"soft_budget":null,"key_alias":null,"duration":null,"aliases":{},"config":{},"permissions":{},"model_max_budget":{},"key":"sk-RV-l2BJEZ_LYNChSx2EueQ","key_name":null,"expires":null,"token_id":null}
```
#### Step 4. Make /chat/completions requests for user
Use the key from step 3 for this request. After 2-3 requests, expect to see the following error: `ExceededBudget: Crossed spend within team`
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-RV-l2BJEZ_LYNChSx2EueQ' \
--header 'Content-Type: application/json' \
--data '{
"model": "llama3",
"messages": [
{
"role": "user",
"content": "tes4"
}
]
}'
```

View file

@ -648,6 +648,20 @@ class LiteLLM_BudgetTable(LiteLLMBase):
protected_namespaces = ()
class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
    """
    Spend record for one user inside one team.

    Adds the identifiers that tie a spend figure to a single
    (user_id, team_id) pair on top of the fields inherited from
    LiteLLM_BudgetTable. Every field declared here is optional and
    defaults to None.
    """

    # Amount spent so far by `user_id` within `team_id`.
    spend: Optional[float] = None
    # The team member this record belongs to.
    user_id: Optional[str] = None
    # The team the spend is tracked against.
    team_id: Optional[str] = None
    # Presumably the id of the budget row attached to this membership -- TODO confirm.
    budget_id: Optional[str] = None

    class Config:
        # Empty tuple: no pydantic protected namespaces are enforced for this model.
        protected_namespaces = ()
class NewOrganizationRequest(LiteLLM_BudgetTable):
organization_id: Optional[str] = None
organization_alias: str
@ -942,6 +956,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
team_blocked: bool = False
soft_budget: Optional[float] = None
team_model_aliases: Optional[Dict] = None
team_member_spend: Optional[float] = None
# End User Params
end_user_id: Optional[str] = None

View file

@ -797,12 +797,13 @@ async def user_api_key_auth(
# Run checks for
# 1. If token can call model
# 2. If user_id for this token is in budget
# 3. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
# 4. If token is expired
# 5. If token spend is under Budget for the token
# 6. If token spend per model is under budget per model
# 7. If token spend is under team budget
# 8. If team spend is under team budget
# 3. If the user spend within their own team is within budget
# 4. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
# 5. If token is expired
# 6. If token spend is under Budget for the token
# 7. If token spend per model is under budget per model
# 8. If token spend is under team budget
# 9. If team spend is under team budget
# Check 1. If token can call model
_model_alias_map = {}
@ -1000,6 +1001,43 @@ async def user_api_key_auth(
raise Exception(
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
)
# Check 3. Check if user is in their team budget
if valid_token.team_member_spend is not None:
if prisma_client is not None:
_cache_key = f"{valid_token.team_id}_{valid_token.user_id}"
team_member_info = await user_api_key_cache.async_get_cache(
key=_cache_key
)
if team_member_info is None:
team_member_info = (
await prisma_client.db.litellm_teammembership.find_first(
where={
"user_id": valid_token.user_id,
"team_id": valid_token.team_id,
}, # type: ignore
include={"litellm_budget_table": True},
)
)
await user_api_key_cache.async_set_cache(
key=_cache_key,
value=team_member_info,
ttl=UserAPIKeyCacheTTLEnum.user_information_cache.value,
)
if (
team_member_info is not None
and team_member_info.litellm_budget_table is not None
):
team_member_budget = (
team_member_info.litellm_budget_table.max_budget
)
if team_member_budget is not None and team_member_budget > 0:
if valid_token.team_member_spend > team_member_budget:
raise Exception(
f"ExceededBudget: Crossed spend within team. UserID: {valid_token.user_id}, in team {valid_token.team_id} has exceeded their budget. Current spend: {valid_token.team_member_spend}; Max Budget: {team_member_budget}"
)
# Check 3. If token is expired
if valid_token.expires is not None:
@ -1701,6 +1739,19 @@ async def update_database(
response_cost
+ prisma_client.team_list_transactons.get(team_id, 0)
)
try:
# Track spend of the team member within this team
# key is "team_id::<value>::user_id::<value>"
team_member_key = f"team_id::{team_id}::user_id::{user_id}"
prisma_client.team_member_list_transactons[team_member_key] = (
response_cost
+ prisma_client.team_member_list_transactons.get(
team_member_key, 0
)
)
except:
pass
except Exception as e:
verbose_proxy_logger.info(
f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}"
@ -1832,6 +1883,16 @@ async def update_cache(
# Calculate the new cost by adding the existing cost and response_cost
existing_spend_obj.team_spend = existing_team_spend + response_cost
if (
existing_spend_obj is not None
and getattr(existing_spend_obj, "team_member_spend", None) is not None
):
existing_team_member_spend = existing_spend_obj.team_member_spend or 0
# Calculate the new cost by adding the existing cost and response_cost
existing_spend_obj.team_member_spend = (
existing_team_member_spend + response_cost
)
# Update the cost column for the given token
existing_spend_obj.spend = new_spend
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)

View file

@ -551,6 +551,7 @@ class PrismaClient:
end_user_list_transactons: dict = {}
key_list_transactons: dict = {}
team_list_transactons: dict = {}
team_member_list_transactons: dict = {} # key is "team_id::<value>::user_id::<value>"
org_list_transactons: dict = {}
spend_log_transactions: List = []
@ -1096,9 +1097,11 @@ class PrismaClient:
t.models AS team_models,
t.blocked AS team_blocked,
t.team_alias AS team_alias,
tm.spend AS team_member_spend,
m.aliases as team_model_aliases
FROM "LiteLLM_VerificationToken" AS v
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id
LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
WHERE v.token = '{token}'
"""
@ -2262,6 +2265,56 @@ async def update_spend(
)
raise e
### UPDATE TEAM Membership TABLE with spend ###
if len(prisma_client.team_member_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
key,
response_cost,
) in prisma_client.team_member_list_transactons.items():
# key is "team_id::<value>::user_id::<value>"
team_id = key.split("::")[1]
user_id = key.split("::")[3]
batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists
where={"team_id": team_id, "user_id": user_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.team_member_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
import traceback
error_msg = (
f"LiteLLM Prisma Client Exception - update team spend: {str(e)}"
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
### UPDATE ORG TABLE ###
if len(prisma_client.org_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):

View file

@ -8,7 +8,13 @@ from openai import AsyncOpenAI
async def new_user(
session, i, user_id=None, budget=None, budget_duration=None, models=["azure-models"]
session,
i,
user_id=None,
budget=None,
budget_duration=None,
models=["azure-models"],
team_id=None,
):
url = "http://0.0.0.0:4000/user/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
@ -23,6 +29,9 @@ async def new_user(
if user_id is not None:
data["user_id"] = user_id
if team_id is not None:
data["team_id"] = team_id
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
@ -37,7 +46,9 @@ async def new_user(
return await response.json()
async def add_member(session, i, team_id, user_id=None, user_email=None):
async def add_member(
session, i, team_id, user_id=None, user_email=None, max_budget=None
):
url = "http://0.0.0.0:4000/team/member_add"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {"team_id": team_id, "member": {"role": "user"}}
@ -46,6 +57,9 @@ async def add_member(session, i, team_id, user_id=None, user_email=None):
elif user_id is not None:
data["member"]["user_id"] = user_id
if max_budget is not None:
data["max_budget_in_team"] = max_budget
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
@ -475,3 +489,50 @@ async def test_team_alias():
key = key_gen["key"]
## Test key
response = await chat_completion(session=session, key=key, model="cheap-model")
@pytest.mark.asyncio
async def test_users_in_team_budget():
    """
    Check that a member's budget inside a team is enforced.

    Steps:
    - Create a team
    - Create a user
    - Add the user to the team with max_budget = 0.0000001
    - Call 1 -> expected to pass
    - Call 2 -> expected to fail with "Crossed spend within team"
    """
    # Timestamp suffix keeps the user id unique across runs.
    member_id = f"krrish_{time.time()}@berri.ai"

    async with aiohttp.ClientSession() as session:
        created_team = await new_team(session, 0, user_id=member_id)
        print("New team=", created_team)
        team_id = created_team["team_id"]

        user_response = await new_user(
            session,
            0,
            user_id=member_id,
            budget=10,
            budget_duration="5s",
            team_id=team_id,
            models=["fake-openai-endpoint"],
        )
        api_key = user_response["key"]

        # Attach the user to the team with a tiny in-team budget.
        await add_member(
            session, 0, team_id=team_id, user_id=member_id, max_budget=0.0000001
        )

        # Call 1
        first_response = await chat_completion(
            session, api_key, model="fake-openai-endpoint"
        )
        print("Call 1 passed", first_response)

        # Pause between calls -- presumably so the spend from call 1 is
        # recorded before call 2 is checked; TODO confirm.
        await asyncio.sleep(2)

        # Call 2
        try:
            await chat_completion(session, api_key, model="fake-openai-endpoint")
        except Exception as e:
            print("got exception, this is expected")
            print(e)
            assert "Crossed spend within team" in str(e)
        else:
            # Reached only when call 2 did not raise.
            pytest.fail(
                "Call 2 should have failed. The user crossed their budget within their team"
            )