forked from phoenix/litellm-mirror
Merge pull request #3790 from BerriAI/litellm_set_team_member_budgets
[Feat] Set Budgets for Users within a Team
This commit is contained in:
commit
a8b64a01dc
5 changed files with 270 additions and 9 deletions
|
@ -473,4 +473,75 @@ curl --location 'http://0.0.0.0:4000/key/generate' \
|
||||||
--header 'Authorization: Bearer <your-master-key>' \
|
--header 'Authorization: Bearer <your-master-key>' \
|
||||||
--header 'Content-Type: application/json' \
|
--header 'Content-Type: application/json' \
|
||||||
--data '{"models": ["azure-models"], "user_id": "krrish@berri.ai"}'
|
--data '{"models": ["azure-models"], "user_id": "krrish@berri.ai"}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Advanced
|
||||||
|
### Set Budgets for Members within a Team
|
||||||
|
|
||||||
|
Use this when you want to budget a users spend within a Team
|
||||||
|
|
||||||
|
|
||||||
|
#### Step 1. Create User
|
||||||
|
|
||||||
|
Create a user with `user_id=ishaan`
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/user/new' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"user_id": "ishaan"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 2. Add User to an existing Team - set `max_budget_in_team`
|
||||||
|
|
||||||
|
Set `max_budget_in_team` when adding a User to a team. We use the same `user_id` we set in Step 1
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl -X POST 'http://0.0.0.0:4000/team/member_add' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{"team_id": "e8d1460f-846c-45d7-9b43-55f3cc52ac32", "max_budget_in_team": 0.000000000001, "member": {"role": "user", "user_id": "ishaan"}}'
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 3. Create a Key for user from Step 1
|
||||||
|
|
||||||
|
Set `user_id=ishaan` from step 1
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/key/generate' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"user_id": "ishaan"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
Response from `/key/generate`
|
||||||
|
|
||||||
|
We use the `key` from this response in Step 4
|
||||||
|
```shell
|
||||||
|
{"models":[],"spend":0.0,"max_budget":null,"user_id":"ishaan","team_id":null,"max_parallel_requests":null,"metadata":{},"tpm_limit":null,"rpm_limit":null,"budget_duration":null,"allowed_cache_controls":[],"soft_budget":null,"key_alias":null,"duration":null,"aliases":{},"config":{},"permissions":{},"model_max_budget":{},"key":"sk-RV-l2BJEZ_LYNChSx2EueQ","key_name":null,"expires":null,"token_id":null}%
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 4. Make /chat/completions requests for user
|
||||||
|
|
||||||
|
Use the key from step 3 for this request. After 2-3 requests expect to see The following error `ExceededBudget: Crossed spend within team`
|
||||||
|
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://localhost:4000/chat/completions' \
|
||||||
|
--header 'Authorization: Bearer sk-RV-l2BJEZ_LYNChSx2EueQ' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"model": "llama3",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "tes4"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
|
|
@ -648,6 +648,20 @@ class LiteLLM_BudgetTable(LiteLLMBase):
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
|
||||||
|
"""
|
||||||
|
Used to track spend of a user_id within a team_id
|
||||||
|
"""
|
||||||
|
|
||||||
|
spend: Optional[float] = None
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
team_id: Optional[str] = None
|
||||||
|
budget_id: Optional[str] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
class NewOrganizationRequest(LiteLLM_BudgetTable):
|
class NewOrganizationRequest(LiteLLM_BudgetTable):
|
||||||
organization_id: Optional[str] = None
|
organization_id: Optional[str] = None
|
||||||
organization_alias: str
|
organization_alias: str
|
||||||
|
@ -942,6 +956,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||||
team_blocked: bool = False
|
team_blocked: bool = False
|
||||||
soft_budget: Optional[float] = None
|
soft_budget: Optional[float] = None
|
||||||
team_model_aliases: Optional[Dict] = None
|
team_model_aliases: Optional[Dict] = None
|
||||||
|
team_member_spend: Optional[float] = None
|
||||||
|
|
||||||
# End User Params
|
# End User Params
|
||||||
end_user_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
|
|
|
@ -797,12 +797,13 @@ async def user_api_key_auth(
|
||||||
# Run checks for
|
# Run checks for
|
||||||
# 1. If token can call model
|
# 1. If token can call model
|
||||||
# 2. If user_id for this token is in budget
|
# 2. If user_id for this token is in budget
|
||||||
# 3. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
|
# 3. If the user spend within their own team is within budget
|
||||||
# 4. If token is expired
|
# 4. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
|
||||||
# 5. If token spend is under Budget for the token
|
# 5. If token is expired
|
||||||
# 6. If token spend per model is under budget per model
|
# 6. If token spend is under Budget for the token
|
||||||
# 7. If token spend is under team budget
|
# 7. If token spend per model is under budget per model
|
||||||
# 8. If team spend is under team budget
|
# 8. If token spend is under team budget
|
||||||
|
# 9. If team spend is under team budget
|
||||||
|
|
||||||
# Check 1. If token can call model
|
# Check 1. If token can call model
|
||||||
_model_alias_map = {}
|
_model_alias_map = {}
|
||||||
|
@ -1000,6 +1001,43 @@ async def user_api_key_auth(
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
|
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
|
||||||
)
|
)
|
||||||
|
# Check 3. Check if user is in their team budget
|
||||||
|
if valid_token.team_member_spend is not None:
|
||||||
|
if prisma_client is not None:
|
||||||
|
|
||||||
|
_cache_key = f"{valid_token.team_id}_{valid_token.user_id}"
|
||||||
|
|
||||||
|
team_member_info = await user_api_key_cache.async_get_cache(
|
||||||
|
key=_cache_key
|
||||||
|
)
|
||||||
|
if team_member_info is None:
|
||||||
|
team_member_info = (
|
||||||
|
await prisma_client.db.litellm_teammembership.find_first(
|
||||||
|
where={
|
||||||
|
"user_id": valid_token.user_id,
|
||||||
|
"team_id": valid_token.team_id,
|
||||||
|
}, # type: ignore
|
||||||
|
include={"litellm_budget_table": True},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await user_api_key_cache.async_set_cache(
|
||||||
|
key=_cache_key,
|
||||||
|
value=team_member_info,
|
||||||
|
ttl=UserAPIKeyCacheTTLEnum.user_information_cache.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
team_member_info is not None
|
||||||
|
and team_member_info.litellm_budget_table is not None
|
||||||
|
):
|
||||||
|
team_member_budget = (
|
||||||
|
team_member_info.litellm_budget_table.max_budget
|
||||||
|
)
|
||||||
|
if team_member_budget is not None and team_member_budget > 0:
|
||||||
|
if valid_token.team_member_spend > team_member_budget:
|
||||||
|
raise Exception(
|
||||||
|
f"ExceededBudget: Crossed spend within team. UserID: {valid_token.user_id}, in team {valid_token.team_id} has exceeded their budget. Current spend: {valid_token.team_member_spend}; Max Budget: {team_member_budget}"
|
||||||
|
)
|
||||||
|
|
||||||
# Check 3. If token is expired
|
# Check 3. If token is expired
|
||||||
if valid_token.expires is not None:
|
if valid_token.expires is not None:
|
||||||
|
@ -1701,6 +1739,19 @@ async def update_database(
|
||||||
response_cost
|
response_cost
|
||||||
+ prisma_client.team_list_transactons.get(team_id, 0)
|
+ prisma_client.team_list_transactons.get(team_id, 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Track spend of the team member within this team
|
||||||
|
# key is "team_id::<value>::user_id::<value>"
|
||||||
|
team_member_key = f"team_id::{team_id}::user_id::{user_id}"
|
||||||
|
prisma_client.team_member_list_transactons[team_member_key] = (
|
||||||
|
response_cost
|
||||||
|
+ prisma_client.team_member_list_transactons.get(
|
||||||
|
team_member_key, 0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
||||||
|
@ -1832,6 +1883,16 @@ async def update_cache(
|
||||||
# Calculate the new cost by adding the existing cost and response_cost
|
# Calculate the new cost by adding the existing cost and response_cost
|
||||||
existing_spend_obj.team_spend = existing_team_spend + response_cost
|
existing_spend_obj.team_spend = existing_team_spend + response_cost
|
||||||
|
|
||||||
|
if (
|
||||||
|
existing_spend_obj is not None
|
||||||
|
and getattr(existing_spend_obj, "team_member_spend", None) is not None
|
||||||
|
):
|
||||||
|
existing_team_member_spend = existing_spend_obj.team_member_spend or 0
|
||||||
|
# Calculate the new cost by adding the existing cost and response_cost
|
||||||
|
existing_spend_obj.team_member_spend = (
|
||||||
|
existing_team_member_spend + response_cost
|
||||||
|
)
|
||||||
|
|
||||||
# Update the cost column for the given token
|
# Update the cost column for the given token
|
||||||
existing_spend_obj.spend = new_spend
|
existing_spend_obj.spend = new_spend
|
||||||
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
|
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
|
||||||
|
|
|
@ -551,6 +551,7 @@ class PrismaClient:
|
||||||
end_user_list_transactons: dict = {}
|
end_user_list_transactons: dict = {}
|
||||||
key_list_transactons: dict = {}
|
key_list_transactons: dict = {}
|
||||||
team_list_transactons: dict = {}
|
team_list_transactons: dict = {}
|
||||||
|
team_member_list_transactons: dict = {} # key is ["team_id" + "user_id"]
|
||||||
org_list_transactons: dict = {}
|
org_list_transactons: dict = {}
|
||||||
spend_log_transactions: List = []
|
spend_log_transactions: List = []
|
||||||
|
|
||||||
|
@ -1096,9 +1097,11 @@ class PrismaClient:
|
||||||
t.models AS team_models,
|
t.models AS team_models,
|
||||||
t.blocked AS team_blocked,
|
t.blocked AS team_blocked,
|
||||||
t.team_alias AS team_alias,
|
t.team_alias AS team_alias,
|
||||||
|
tm.spend AS team_member_spend,
|
||||||
m.aliases as team_model_aliases
|
m.aliases as team_model_aliases
|
||||||
FROM "LiteLLM_VerificationToken" AS v
|
FROM "LiteLLM_VerificationToken" AS v
|
||||||
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
|
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
|
||||||
|
LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id
|
||||||
LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
|
LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
|
||||||
WHERE v.token = '{token}'
|
WHERE v.token = '{token}'
|
||||||
"""
|
"""
|
||||||
|
@ -2262,6 +2265,56 @@ async def update_spend(
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
### UPDATE TEAM Membership TABLE with spend ###
|
||||||
|
if len(prisma_client.team_member_list_transactons.keys()) > 0:
|
||||||
|
for i in range(n_retry_times + 1):
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
async with prisma_client.db.tx(
|
||||||
|
timeout=timedelta(seconds=60)
|
||||||
|
) as transaction:
|
||||||
|
async with transaction.batch_() as batcher:
|
||||||
|
for (
|
||||||
|
key,
|
||||||
|
response_cost,
|
||||||
|
) in prisma_client.team_member_list_transactons.items():
|
||||||
|
# key is "team_id::<value>::user_id::<value>"
|
||||||
|
team_id = key.split("::")[1]
|
||||||
|
user_id = key.split("::")[3]
|
||||||
|
|
||||||
|
batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||||
|
where={"team_id": team_id, "user_id": user_id},
|
||||||
|
data={"spend": {"increment": response_cost}},
|
||||||
|
)
|
||||||
|
prisma_client.team_member_list_transactons = (
|
||||||
|
{}
|
||||||
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
|
break
|
||||||
|
except httpx.ReadTimeout:
|
||||||
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
|
raise # Re-raise the last exception
|
||||||
|
# Optionally, sleep for a bit before retrying
|
||||||
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
error_msg = (
|
||||||
|
f"LiteLLM Prisma Client Exception - update team spend: {str(e)}"
|
||||||
|
)
|
||||||
|
print_verbose(error_msg)
|
||||||
|
error_traceback = error_msg + "\n" + traceback.format_exc()
|
||||||
|
end_time = time.time()
|
||||||
|
_duration = end_time - start_time
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.failure_handler(
|
||||||
|
original_exception=e,
|
||||||
|
duration=_duration,
|
||||||
|
call_type="update_spend",
|
||||||
|
traceback_str=error_traceback,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
### UPDATE ORG TABLE ###
|
### UPDATE ORG TABLE ###
|
||||||
if len(prisma_client.org_list_transactons.keys()) > 0:
|
if len(prisma_client.org_list_transactons.keys()) > 0:
|
||||||
for i in range(n_retry_times + 1):
|
for i in range(n_retry_times + 1):
|
||||||
|
|
|
@ -8,7 +8,13 @@ from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
|
||||||
async def new_user(
|
async def new_user(
|
||||||
session, i, user_id=None, budget=None, budget_duration=None, models=["azure-models"]
|
session,
|
||||||
|
i,
|
||||||
|
user_id=None,
|
||||||
|
budget=None,
|
||||||
|
budget_duration=None,
|
||||||
|
models=["azure-models"],
|
||||||
|
team_id=None,
|
||||||
):
|
):
|
||||||
url = "http://0.0.0.0:4000/user/new"
|
url = "http://0.0.0.0:4000/user/new"
|
||||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||||
|
@ -23,6 +29,9 @@ async def new_user(
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
data["user_id"] = user_id
|
data["user_id"] = user_id
|
||||||
|
|
||||||
|
if team_id is not None:
|
||||||
|
data["team_id"] = team_id
|
||||||
|
|
||||||
async with session.post(url, headers=headers, json=data) as response:
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
status = response.status
|
status = response.status
|
||||||
response_text = await response.text()
|
response_text = await response.text()
|
||||||
|
@ -37,7 +46,9 @@ async def new_user(
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
|
||||||
|
|
||||||
async def add_member(session, i, team_id, user_id=None, user_email=None):
|
async def add_member(
|
||||||
|
session, i, team_id, user_id=None, user_email=None, max_budget=None
|
||||||
|
):
|
||||||
url = "http://0.0.0.0:4000/team/member_add"
|
url = "http://0.0.0.0:4000/team/member_add"
|
||||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||||
data = {"team_id": team_id, "member": {"role": "user"}}
|
data = {"team_id": team_id, "member": {"role": "user"}}
|
||||||
|
@ -46,6 +57,9 @@ async def add_member(session, i, team_id, user_id=None, user_email=None):
|
||||||
elif user_id is not None:
|
elif user_id is not None:
|
||||||
data["member"]["user_id"] = user_id
|
data["member"]["user_id"] = user_id
|
||||||
|
|
||||||
|
if max_budget is not None:
|
||||||
|
data["max_budget_in_team"] = max_budget
|
||||||
|
|
||||||
async with session.post(url, headers=headers, json=data) as response:
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
status = response.status
|
status = response.status
|
||||||
response_text = await response.text()
|
response_text = await response.text()
|
||||||
|
@ -475,3 +489,50 @@ async def test_team_alias():
|
||||||
key = key_gen["key"]
|
key = key_gen["key"]
|
||||||
## Test key
|
## Test key
|
||||||
response = await chat_completion(session=session, key=key, model="cheap-model")
|
response = await chat_completion(session=session, key=key, model="cheap-model")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_users_in_team_budget():
|
||||||
|
"""
|
||||||
|
- Create Team
|
||||||
|
- Create User
|
||||||
|
- Add User to team with budget = 0.0000001
|
||||||
|
- Make Call 1 -> pass
|
||||||
|
- Make Call 2 -> fail
|
||||||
|
"""
|
||||||
|
get_user = f"krrish_{time.time()}@berri.ai"
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
team = await new_team(session, 0, user_id=get_user)
|
||||||
|
print("New team=", team)
|
||||||
|
key_gen = await new_user(
|
||||||
|
session,
|
||||||
|
0,
|
||||||
|
user_id=get_user,
|
||||||
|
budget=10,
|
||||||
|
budget_duration="5s",
|
||||||
|
team_id=team["team_id"],
|
||||||
|
models=["fake-openai-endpoint"],
|
||||||
|
)
|
||||||
|
key = key_gen["key"]
|
||||||
|
|
||||||
|
# Add user to team
|
||||||
|
await add_member(
|
||||||
|
session, 0, team_id=team["team_id"], user_id=get_user, max_budget=0.0000001
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call 1
|
||||||
|
result = await chat_completion(session, key, model="fake-openai-endpoint")
|
||||||
|
print("Call 1 passed", result)
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# Call 2
|
||||||
|
try:
|
||||||
|
await chat_completion(session, key, model="fake-openai-endpoint")
|
||||||
|
pytest.fail(
|
||||||
|
"Call 2 should have failed. The user crossed their budget within their team"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print("got exception, this is expected")
|
||||||
|
print(e)
|
||||||
|
assert "Crossed spend within team" in str(e)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue