forked from phoenix/litellm-mirror
fix(management_helpers/utils.py): use user_default max_budget, budget duration on new user upsert during team member add
Fixes https://github.com/BerriAI/litellm/issues/5106
This commit is contained in:
parent
6af9d9d2b3
commit
1d39c0fb7d
3 changed files with 109 additions and 8 deletions
|
@ -420,7 +420,6 @@ async def update_team(
|
||||||
@management_endpoint_wrapper
|
@management_endpoint_wrapper
|
||||||
async def team_member_add(
|
async def team_member_add(
|
||||||
data: TeamMemberAddRequest,
|
data: TeamMemberAddRequest,
|
||||||
http_request: Request,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -442,8 +441,11 @@ async def team_member_add(
|
||||||
from litellm.proxy.proxy_server import (
|
from litellm.proxy.proxy_server import (
|
||||||
_duration_in_seconds,
|
_duration_in_seconds,
|
||||||
create_audit_log_for_update,
|
create_audit_log_for_update,
|
||||||
|
get_team_object,
|
||||||
litellm_proxy_admin_name,
|
litellm_proxy_admin_name,
|
||||||
prisma_client,
|
prisma_client,
|
||||||
|
proxy_logging_obj,
|
||||||
|
user_api_key_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
|
@ -457,8 +459,13 @@ async def team_member_add(
|
||||||
status_code=400, detail={"error": "No member/members passed in"}
|
status_code=400, detail={"error": "No member/members passed in"}
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
existing_team_row = await get_team_object(
|
||||||
where={"team_id": data.team_id}
|
team_id=data.team_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=None,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
check_cache_only=False,
|
||||||
)
|
)
|
||||||
if existing_team_row is None:
|
if existing_team_row is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
@ -7,6 +7,7 @@ from typing import Optional
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
|
import litellm
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types
|
from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types
|
||||||
DeleteCustomerRequest,
|
DeleteCustomerRequest,
|
||||||
|
@ -16,6 +17,7 @@ from litellm.proxy._types import ( # key request types; user request types; tea
|
||||||
LiteLLM_TeamTable,
|
LiteLLM_TeamTable,
|
||||||
ManagementEndpointLoggingPayload,
|
ManagementEndpointLoggingPayload,
|
||||||
Member,
|
Member,
|
||||||
|
SSOUserDefinedValues,
|
||||||
UpdateCustomerRequest,
|
UpdateCustomerRequest,
|
||||||
UpdateKeyRequest,
|
UpdateKeyRequest,
|
||||||
UpdateTeamRequest,
|
UpdateTeamRequest,
|
||||||
|
@ -26,6 +28,25 @@ from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
|
|
||||||
|
|
||||||
|
def get_new_internal_user_defaults(
|
||||||
|
user_id: str, user_email: Optional[str] = None
|
||||||
|
) -> SSOUserDefinedValues:
|
||||||
|
user_info = litellm.default_user_params or {}
|
||||||
|
|
||||||
|
returned_dict: SSOUserDefinedValues = {
|
||||||
|
"models": user_info.get("models", None),
|
||||||
|
"max_budget": user_info.get("max_budget", litellm.max_internal_user_budget),
|
||||||
|
"budget_duration": user_info.get(
|
||||||
|
"budget_duration", litellm.internal_user_budget_duration
|
||||||
|
),
|
||||||
|
"user_email": user_email or user_info.get("user_email", None),
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_role": "internal_user",
|
||||||
|
}
|
||||||
|
|
||||||
|
return returned_dict
|
||||||
|
|
||||||
|
|
||||||
async def add_new_member(
|
async def add_new_member(
|
||||||
new_member: Member,
|
new_member: Member,
|
||||||
max_budget_in_team: Optional[float],
|
max_budget_in_team: Optional[float],
|
||||||
|
@ -42,15 +63,18 @@ async def add_new_member(
|
||||||
"""
|
"""
|
||||||
## ADD TEAM ID, to USER TABLE IF NEW ##
|
## ADD TEAM ID, to USER TABLE IF NEW ##
|
||||||
if new_member.user_id is not None:
|
if new_member.user_id is not None:
|
||||||
|
new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id)
|
||||||
await prisma_client.db.litellm_usertable.upsert(
|
await prisma_client.db.litellm_usertable.upsert(
|
||||||
where={"user_id": new_member.user_id},
|
where={"user_id": new_member.user_id},
|
||||||
data={
|
data={
|
||||||
"update": {"teams": {"push": [team_id]}},
|
"update": {"teams": {"push": [team_id]}},
|
||||||
"create": {"user_id": new_member.user_id, "teams": [team_id]},
|
"create": {"teams": [team_id], **new_user_defaults},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
elif new_member.user_email is not None:
|
elif new_member.user_email is not None:
|
||||||
user_data = {"user_id": str(uuid.uuid4()), "user_email": new_member.user_email}
|
new_user_defaults = get_new_internal_user_defaults(
|
||||||
|
user_id=str(uuid.uuid4()), user_email=new_member.user_email
|
||||||
|
)
|
||||||
## user email is not unique acc. to prisma schema -> future improvement
|
## user email is not unique acc. to prisma schema -> future improvement
|
||||||
### for now: check if it exists in db, if not - insert it
|
### for now: check if it exists in db, if not - insert it
|
||||||
existing_user_row = await prisma_client.get_data(
|
existing_user_row = await prisma_client.get_data(
|
||||||
|
@ -62,7 +86,7 @@ async def add_new_member(
|
||||||
isinstance(existing_user_row, list) and len(existing_user_row) == 0
|
isinstance(existing_user_row, list) and len(existing_user_row) == 0
|
||||||
):
|
):
|
||||||
|
|
||||||
await prisma_client.insert_data(data=user_data, table_name="user")
|
await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore
|
||||||
|
|
||||||
# Check if trying to set a budget for team member
|
# Check if trying to set a budget for team member
|
||||||
if max_budget_in_team is not None and new_member.user_id is not None:
|
if max_budget_in_team is not None and new_member.user_id is not None:
|
||||||
|
|
|
@ -804,10 +804,16 @@ async def test_get_team_redis(client_no_auth):
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import uuid
|
import uuid
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||||
|
|
||||||
from litellm.proxy._types import LitellmUserRoles, NewUserRequest, UserAPIKeyAuth
|
from litellm.proxy._types import (
|
||||||
|
LitellmUserRoles,
|
||||||
|
NewUserRequest,
|
||||||
|
TeamMemberAddRequest,
|
||||||
|
UserAPIKeyAuth,
|
||||||
|
)
|
||||||
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
|
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
|
||||||
|
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
||||||
from litellm.tests.test_key_generate_prisma import prisma_client
|
from litellm.tests.test_key_generate_prisma import prisma_client
|
||||||
|
|
||||||
|
|
||||||
|
@ -852,3 +858,67 @@ async def test_create_user_default_budget(prisma_client, user_role):
|
||||||
else:
|
else:
|
||||||
assert mock_client.call_args.kwargs["data"]["max_budget"] is None
|
assert mock_client.call_args.kwargs["data"]["max_budget"] is None
|
||||||
assert mock_client.call_args.kwargs["data"]["budget_duration"] is None
|
assert mock_client.call_args.kwargs["data"]["budget_duration"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_team_member_add(prisma_client, new_member_method):
|
||||||
|
import time
|
||||||
|
|
||||||
|
from litellm.proxy._types import LiteLLM_TeamTableCachedObj
|
||||||
|
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
setattr(litellm, "max_internal_user_budget", 10)
|
||||||
|
setattr(litellm, "internal_user_budget_duration", "5m")
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
user = f"ishaan {uuid.uuid4().hex}"
|
||||||
|
_team_id = "litellm-test-client-id-new"
|
||||||
|
team_obj = LiteLLM_TeamTableCachedObj(
|
||||||
|
team_id=_team_id,
|
||||||
|
blocked=False,
|
||||||
|
last_refreshed_at=time.time(),
|
||||||
|
metadata={"guardrails": {"modify_guardrails": False}},
|
||||||
|
)
|
||||||
|
# user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
|
||||||
|
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||||
|
if new_member_method == "user_id":
|
||||||
|
data = {
|
||||||
|
"team_id": _team_id,
|
||||||
|
"member": [{"role": "user", "user_id": user}],
|
||||||
|
}
|
||||||
|
elif new_member_method == "user_email":
|
||||||
|
data = {
|
||||||
|
"team_id": _team_id,
|
||||||
|
"member": [{"role": "user", "user_email": user}],
|
||||||
|
}
|
||||||
|
team_member_add_request = TeamMemberAddRequest(**data)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_litellm_usertable:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_litellm_usertable.upsert = mock_client
|
||||||
|
mock_litellm_usertable.find_many = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
await team_member_add(
|
||||||
|
data=team_member_add_request, user_api_key_dict=UserAPIKeyAuth()
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.assert_called()
|
||||||
|
|
||||||
|
print(f"mock_client.call_args: {mock_client.call_args}")
|
||||||
|
print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))
|
||||||
|
|
||||||
|
assert (
|
||||||
|
mock_client.call_args.kwargs["data"]["create"]["max_budget"]
|
||||||
|
== litellm.max_internal_user_budget
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
|
||||||
|
== litellm.internal_user_budget_duration
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue