forked from phoenix/litellm-mirror
Merge pull request #3790 from BerriAI/litellm_set_team_member_budgets
[Feat] Set Budgets for Users within a Team
This commit is contained in:
commit
a8b64a01dc
5 changed files with 270 additions and 9 deletions
|
@ -473,4 +473,75 @@ curl --location 'http://0.0.0.0:4000/key/generate' \
|
|||
--header 'Authorization: Bearer <your-master-key>' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{"models": ["azure-models"], "user_id": "krrish@berri.ai"}'
|
||||
```
|
||||
```
|
||||
|
||||
|
||||
## Advanced
|
||||
### Set Budgets for Members within a Team
|
||||
|
||||
Use this when you want to set a budget on a user's spend within a team.
|
||||
|
||||
|
||||
#### Step 1. Create User
|
||||
|
||||
Create a user with `user_id=ishaan`
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/user/new' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"user_id": "ishaan"
|
||||
}'
|
||||
```
|
||||
|
||||
#### Step 2. Add User to an existing Team - set `max_budget_in_team`
|
||||
|
||||
Set `max_budget_in_team` when adding a user to a team. Use the same `user_id` you created in Step 1.
|
||||
|
||||
```shell
|
||||
curl -X POST 'http://0.0.0.0:4000/team/member_add' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"team_id": "e8d1460f-846c-45d7-9b43-55f3cc52ac32", "max_budget_in_team": 0.000000000001, "member": {"role": "user", "user_id": "ishaan"}}'
|
||||
```
|
||||
|
||||
#### Step 3. Create a Key for user from Step 1
|
||||
|
||||
Set `user_id=ishaan` from step 1
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/key/generate' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"user_id": "ishaan"
|
||||
}'
|
||||
```
|
||||
Response from `/key/generate`
|
||||
|
||||
We use the `key` from this response in Step 4
|
||||
```shell
|
||||
{"models":[],"spend":0.0,"max_budget":null,"user_id":"ishaan","team_id":null,"max_parallel_requests":null,"metadata":{},"tpm_limit":null,"rpm_limit":null,"budget_duration":null,"allowed_cache_controls":[],"soft_budget":null,"key_alias":null,"duration":null,"aliases":{},"config":{},"permissions":{},"model_max_budget":{},"key":"sk-RV-l2BJEZ_LYNChSx2EueQ","key_name":null,"expires":null,"token_id":null}
|
||||
```
|
||||
|
||||
#### Step 4. Make /chat/completions requests for user
|
||||
|
||||
Use the key from Step 3 for this request. After 2-3 requests, expect to see the following error: `ExceededBudget: Crossed spend within team`
|
||||
|
||||
|
||||
```shell
|
||||
curl --location 'http://localhost:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-RV-l2BJEZ_LYNChSx2EueQ' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "llama3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "tes4"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
|
|
|
@ -648,6 +648,20 @@ class LiteLLM_BudgetTable(LiteLLMBase):
|
|||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
|
||||
"""
|
||||
Used to track spend of a user_id within a team_id
|
||||
"""
|
||||
|
||||
spend: Optional[float] = None
|
||||
user_id: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
budget_id: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class NewOrganizationRequest(LiteLLM_BudgetTable):
|
||||
organization_id: Optional[str] = None
|
||||
organization_alias: str
|
||||
|
@ -942,6 +956,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
|||
team_blocked: bool = False
|
||||
soft_budget: Optional[float] = None
|
||||
team_model_aliases: Optional[Dict] = None
|
||||
team_member_spend: Optional[float] = None
|
||||
|
||||
# End User Params
|
||||
end_user_id: Optional[str] = None
|
||||
|
|
|
@ -797,12 +797,13 @@ async def user_api_key_auth(
|
|||
# Run checks for
|
||||
# 1. If token can call model
|
||||
# 2. If user_id for this token is in budget
|
||||
# 3. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
|
||||
# 4. If token is expired
|
||||
# 5. If token spend is under Budget for the token
|
||||
# 6. If token spend per model is under budget per model
|
||||
# 7. If token spend is under team budget
|
||||
# 8. If team spend is under team budget
|
||||
# 3. If the user's spend within their own team is within budget
|
||||
# 4. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
|
||||
# 5. If token is expired
|
||||
# 6. If token spend is under Budget for the token
|
||||
# 7. If token spend per model is under budget per model
|
||||
# 8. If token spend is under team budget
|
||||
# 9. If team spend is under team budget
|
||||
|
||||
# Check 1. If token can call model
|
||||
_model_alias_map = {}
|
||||
|
@ -1000,6 +1001,43 @@ async def user_api_key_auth(
|
|||
raise Exception(
|
||||
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
|
||||
)
|
||||
# Check 3. Check if user is in their team budget
|
||||
if valid_token.team_member_spend is not None:
|
||||
if prisma_client is not None:
|
||||
|
||||
_cache_key = f"{valid_token.team_id}_{valid_token.user_id}"
|
||||
|
||||
team_member_info = await user_api_key_cache.async_get_cache(
|
||||
key=_cache_key
|
||||
)
|
||||
if team_member_info is None:
|
||||
team_member_info = (
|
||||
await prisma_client.db.litellm_teammembership.find_first(
|
||||
where={
|
||||
"user_id": valid_token.user_id,
|
||||
"team_id": valid_token.team_id,
|
||||
}, # type: ignore
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
)
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=_cache_key,
|
||||
value=team_member_info,
|
||||
ttl=UserAPIKeyCacheTTLEnum.user_information_cache.value,
|
||||
)
|
||||
|
||||
if (
|
||||
team_member_info is not None
|
||||
and team_member_info.litellm_budget_table is not None
|
||||
):
|
||||
team_member_budget = (
|
||||
team_member_info.litellm_budget_table.max_budget
|
||||
)
|
||||
if team_member_budget is not None and team_member_budget > 0:
|
||||
if valid_token.team_member_spend > team_member_budget:
|
||||
raise Exception(
|
||||
f"ExceededBudget: Crossed spend within team. UserID: {valid_token.user_id}, in team {valid_token.team_id} has exceeded their budget. Current spend: {valid_token.team_member_spend}; Max Budget: {team_member_budget}"
|
||||
)
|
||||
|
||||
# Check 5. If token is expired
|
||||
if valid_token.expires is not None:
|
||||
|
@ -1701,6 +1739,19 @@ async def update_database(
|
|||
response_cost
|
||||
+ prisma_client.team_list_transactons.get(team_id, 0)
|
||||
)
|
||||
|
||||
try:
|
||||
# Track spend of the team member within this team
|
||||
# key is "team_id::<value>::user_id::<value>"
|
||||
team_member_key = f"team_id::{team_id}::user_id::{user_id}"
|
||||
prisma_client.team_member_list_transactons[team_member_key] = (
|
||||
response_cost
|
||||
+ prisma_client.team_member_list_transactons.get(
|
||||
team_member_key, 0
|
||||
)
|
||||
)
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
||||
|
@ -1832,6 +1883,16 @@ async def update_cache(
|
|||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
existing_spend_obj.team_spend = existing_team_spend + response_cost
|
||||
|
||||
if (
|
||||
existing_spend_obj is not None
|
||||
and getattr(existing_spend_obj, "team_member_spend", None) is not None
|
||||
):
|
||||
existing_team_member_spend = existing_spend_obj.team_member_spend or 0
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
existing_spend_obj.team_member_spend = (
|
||||
existing_team_member_spend + response_cost
|
||||
)
|
||||
|
||||
# Update the cost column for the given token
|
||||
existing_spend_obj.spend = new_spend
|
||||
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
|
||||
|
|
|
@ -551,6 +551,7 @@ class PrismaClient:
|
|||
end_user_list_transactons: dict = {}
|
||||
key_list_transactons: dict = {}
|
||||
team_list_transactons: dict = {}
|
||||
team_member_list_transactons: dict = {} # key is "team_id::<value>::user_id::<value>"
|
||||
org_list_transactons: dict = {}
|
||||
spend_log_transactions: List = []
|
||||
|
||||
|
@ -1096,9 +1097,11 @@ class PrismaClient:
|
|||
t.models AS team_models,
|
||||
t.blocked AS team_blocked,
|
||||
t.team_alias AS team_alias,
|
||||
tm.spend AS team_member_spend,
|
||||
m.aliases as team_model_aliases
|
||||
FROM "LiteLLM_VerificationToken" AS v
|
||||
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
|
||||
LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id
|
||||
LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
|
||||
WHERE v.token = '{token}'
|
||||
"""
|
||||
|
@ -2262,6 +2265,56 @@ async def update_spend(
|
|||
)
|
||||
raise e
|
||||
|
||||
### UPDATE TEAM Membership TABLE with spend ###
|
||||
if len(prisma_client.team_member_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
key,
|
||||
response_cost,
|
||||
) in prisma_client.team_member_list_transactons.items():
|
||||
# key is "team_id::<value>::user_id::<value>"
|
||||
team_id = key.split("::")[1]
|
||||
user_id = key.split("::")[3]
|
||||
|
||||
batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"team_id": team_id, "user_id": user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.team_member_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except httpx.ReadTimeout:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
raise # Re-raise the last exception
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_msg = (
|
||||
f"LiteLLM Prisma Client Exception - update team spend: {str(e)}"
|
||||
)
|
||||
print_verbose(error_msg)
|
||||
error_traceback = error_msg + "\n" + traceback.format_exc()
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.failure_handler(
|
||||
original_exception=e,
|
||||
duration=_duration,
|
||||
call_type="update_spend",
|
||||
traceback_str=error_traceback,
|
||||
)
|
||||
)
|
||||
raise e
|
||||
|
||||
### UPDATE ORG TABLE ###
|
||||
if len(prisma_client.org_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
|
|
|
@ -8,7 +8,13 @@ from openai import AsyncOpenAI
|
|||
|
||||
|
||||
async def new_user(
|
||||
session, i, user_id=None, budget=None, budget_duration=None, models=["azure-models"]
|
||||
session,
|
||||
i,
|
||||
user_id=None,
|
||||
budget=None,
|
||||
budget_duration=None,
|
||||
models=["azure-models"],
|
||||
team_id=None,
|
||||
):
|
||||
url = "http://0.0.0.0:4000/user/new"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
|
@ -23,6 +29,9 @@ async def new_user(
|
|||
if user_id is not None:
|
||||
data["user_id"] = user_id
|
||||
|
||||
if team_id is not None:
|
||||
data["team_id"] = team_id
|
||||
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
status = response.status
|
||||
response_text = await response.text()
|
||||
|
@ -37,7 +46,9 @@ async def new_user(
|
|||
return await response.json()
|
||||
|
||||
|
||||
async def add_member(session, i, team_id, user_id=None, user_email=None):
|
||||
async def add_member(
|
||||
session, i, team_id, user_id=None, user_email=None, max_budget=None
|
||||
):
|
||||
url = "http://0.0.0.0:4000/team/member_add"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {"team_id": team_id, "member": {"role": "user"}}
|
||||
|
@ -46,6 +57,9 @@ async def add_member(session, i, team_id, user_id=None, user_email=None):
|
|||
elif user_id is not None:
|
||||
data["member"]["user_id"] = user_id
|
||||
|
||||
if max_budget is not None:
|
||||
data["max_budget_in_team"] = max_budget
|
||||
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
status = response.status
|
||||
response_text = await response.text()
|
||||
|
@ -475,3 +489,50 @@ async def test_team_alias():
|
|||
key = key_gen["key"]
|
||||
## Test key
|
||||
response = await chat_completion(session=session, key=key, model="cheap-model")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_users_in_team_budget():
|
||||
"""
|
||||
- Create Team
|
||||
- Create User
|
||||
- Add User to team with budget = 0.0000001
|
||||
- Make Call 1 -> pass
|
||||
- Make Call 2 -> fail
|
||||
"""
|
||||
get_user = f"krrish_{time.time()}@berri.ai"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
team = await new_team(session, 0, user_id=get_user)
|
||||
print("New team=", team)
|
||||
key_gen = await new_user(
|
||||
session,
|
||||
0,
|
||||
user_id=get_user,
|
||||
budget=10,
|
||||
budget_duration="5s",
|
||||
team_id=team["team_id"],
|
||||
models=["fake-openai-endpoint"],
|
||||
)
|
||||
key = key_gen["key"]
|
||||
|
||||
# Add user to team
|
||||
await add_member(
|
||||
session, 0, team_id=team["team_id"], user_id=get_user, max_budget=0.0000001
|
||||
)
|
||||
|
||||
# Call 1
|
||||
result = await chat_completion(session, key, model="fake-openai-endpoint")
|
||||
print("Call 1 passed", result)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Call 2
|
||||
try:
|
||||
await chat_completion(session, key, model="fake-openai-endpoint")
|
||||
pytest.fail(
|
||||
"Call 2 should have failed. The user crossed their budget within their team"
|
||||
)
|
||||
except Exception as e:
|
||||
print("got exception, this is expected")
|
||||
print(e)
|
||||
assert "Crossed spend within team" in str(e)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue