forked from phoenix/litellm-mirror
fix(key_management_endpoints.py): fix user-membership check when creating team key
This commit is contained in:
parent
3d8c0bad58
commit
177acd1c93
3 changed files with 84 additions and 13 deletions
|
@ -12,3 +12,10 @@ model_list:
|
||||||
vertex_ai_project: "adroit-crow-413218"
|
vertex_ai_project: "adroit-crow-413218"
|
||||||
vertex_ai_location: "us-east5"
|
vertex_ai_location: "us-east5"
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
key_generation_settings:
|
||||||
|
team_key_generation:
|
||||||
|
allowed_team_member_roles: ["admin"]
|
||||||
|
required_params: ["tags"] # require team admins to set tags for cost-tracking when generating a team key
|
||||||
|
personal_key_generation: # maps to 'Default Team' on UI
|
||||||
|
allowed_user_roles: ["proxy_admin"]
|
|
@ -29,6 +29,7 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
_cache_key_object,
|
_cache_key_object,
|
||||||
_delete_cache_key_object,
|
_delete_cache_key_object,
|
||||||
get_key_object,
|
get_key_object,
|
||||||
|
get_team_object,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||||
|
@ -46,7 +47,19 @@ def _is_team_key(data: GenerateKeyRequest):
|
||||||
return data.team_id is not None
|
return data.team_id is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_in_team(
|
||||||
|
team_table: LiteLLM_TeamTableCachedObj, user_id: Optional[str]
|
||||||
|
) -> Optional[Member]:
|
||||||
|
if user_id is None:
|
||||||
|
return None
|
||||||
|
for member in team_table.members_with_roles:
|
||||||
|
if member.user_id is not None and member.user_id == user_id:
|
||||||
|
return member
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _team_key_generation_team_member_check(
|
def _team_key_generation_team_member_check(
|
||||||
|
team_table: LiteLLM_TeamTableCachedObj,
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
team_key_generation: Optional[TeamUIKeyGenerationConfig],
|
team_key_generation: Optional[TeamUIKeyGenerationConfig],
|
||||||
):
|
):
|
||||||
|
@ -56,17 +69,19 @@ def _team_key_generation_team_member_check(
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if user_api_key_dict.team_member is None:
|
user_in_team = _get_user_in_team(
|
||||||
|
team_table=team_table, user_id=user_api_key_dict.user_id
|
||||||
|
)
|
||||||
|
if user_in_team is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"User not assigned to team. Got team_member={user_api_key_dict.team_member}",
|
detail=f"User={user_api_key_dict.user_id} not assigned to team={team_table.team_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
team_member_role = user_api_key_dict.team_member.role
|
if user_in_team.role not in team_key_generation["allowed_team_member_roles"]:
|
||||||
if team_member_role not in team_key_generation["allowed_team_member_roles"]:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}", # type: ignore
|
detail=f"Team member role {user_in_team.role} not in allowed_team_member_roles={team_key_generation['allowed_team_member_roles']}",
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -88,7 +103,9 @@ def _key_generation_required_param_check(
|
||||||
|
|
||||||
|
|
||||||
def _team_key_generation_check(
|
def _team_key_generation_check(
|
||||||
user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
|
team_table: LiteLLM_TeamTableCachedObj,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
data: GenerateKeyRequest,
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
litellm.key_generation_settings is None
|
litellm.key_generation_settings is None
|
||||||
|
@ -99,7 +116,8 @@ def _team_key_generation_check(
|
||||||
_team_key_generation = litellm.key_generation_settings["team_key_generation"] # type: ignore
|
_team_key_generation = litellm.key_generation_settings["team_key_generation"] # type: ignore
|
||||||
|
|
||||||
_team_key_generation_team_member_check(
|
_team_key_generation_team_member_check(
|
||||||
user_api_key_dict,
|
team_table=team_table,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
team_key_generation=_team_key_generation,
|
team_key_generation=_team_key_generation,
|
||||||
)
|
)
|
||||||
_key_generation_required_param_check(
|
_key_generation_required_param_check(
|
||||||
|
@ -155,7 +173,9 @@ def _personal_key_generation_check(
|
||||||
|
|
||||||
|
|
||||||
def key_generation_check(
|
def key_generation_check(
|
||||||
user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
|
team_table: Optional[LiteLLM_TeamTableCachedObj],
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
data: GenerateKeyRequest,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if admin has restricted key creation to certain roles for teams or individuals
|
Check if admin has restricted key creation to certain roles for teams or individuals
|
||||||
|
@ -170,8 +190,15 @@ def key_generation_check(
|
||||||
is_team_key = _is_team_key(data=data)
|
is_team_key = _is_team_key(data=data)
|
||||||
|
|
||||||
if is_team_key:
|
if is_team_key:
|
||||||
|
if team_table is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Unable to find team object in database. Team ID: {data.team_id}",
|
||||||
|
)
|
||||||
return _team_key_generation_check(
|
return _team_key_generation_check(
|
||||||
user_api_key_dict=user_api_key_dict, data=data
|
team_table=team_table,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
data=data,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return _personal_key_generation_check(
|
return _personal_key_generation_check(
|
||||||
|
@ -254,6 +281,7 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
litellm_proxy_admin_name,
|
litellm_proxy_admin_name,
|
||||||
prisma_client,
|
prisma_client,
|
||||||
proxy_logging_obj,
|
proxy_logging_obj,
|
||||||
|
user_api_key_cache,
|
||||||
user_custom_key_generate,
|
user_custom_key_generate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -271,7 +299,20 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=message
|
status_code=status.HTTP_403_FORBIDDEN, detail=message
|
||||||
)
|
)
|
||||||
elif litellm.key_generation_settings is not None:
|
elif litellm.key_generation_settings is not None:
|
||||||
key_generation_check(user_api_key_dict=user_api_key_dict, data=data)
|
if data.team_id is None:
|
||||||
|
team_table: Optional[LiteLLM_TeamTableCachedObj] = None
|
||||||
|
else:
|
||||||
|
team_table = await get_team_object(
|
||||||
|
team_id=data.team_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||||
|
)
|
||||||
|
key_generation_check(
|
||||||
|
team_table=team_table,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
# check if user set default key/generate params on config.yaml
|
# check if user set default key/generate params on config.yaml
|
||||||
if litellm.default_key_generate_params is not None:
|
if litellm.default_key_generate_params is not None:
|
||||||
for elem in data:
|
for elem in data:
|
||||||
|
|
|
@ -556,13 +556,22 @@ def test_team_key_generation_team_member_check():
|
||||||
_team_key_generation_check,
|
_team_key_generation_check,
|
||||||
)
|
)
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
from litellm.proxy._types import LiteLLM_TeamTableCachedObj
|
||||||
|
|
||||||
litellm.key_generation_settings = {
|
litellm.key_generation_settings = {
|
||||||
"team_key_generation": {"allowed_team_member_roles": ["admin"]}
|
"team_key_generation": {"allowed_team_member_roles": ["admin"]}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
team_table = LiteLLM_TeamTableCachedObj(
|
||||||
|
team_id="test_team_id",
|
||||||
|
team_alias="test_team_alias",
|
||||||
|
members_with_roles=[Member(role="admin", user_id="test_user_id")],
|
||||||
|
)
|
||||||
|
|
||||||
assert _team_key_generation_check(
|
assert _team_key_generation_check(
|
||||||
|
team_table=team_table,
|
||||||
user_api_key_dict=UserAPIKeyAuth(
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_id="test_user_id",
|
||||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||||
api_key="sk-1234",
|
api_key="sk-1234",
|
||||||
team_member=Member(role="admin", user_id="test_user_id"),
|
team_member=Member(role="admin", user_id="test_user_id"),
|
||||||
|
@ -570,8 +579,15 @@ def test_team_key_generation_team_member_check():
|
||||||
data=GenerateKeyRequest(),
|
data=GenerateKeyRequest(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
team_table = LiteLLM_TeamTableCachedObj(
|
||||||
|
team_id="test_team_id",
|
||||||
|
team_alias="test_team_alias",
|
||||||
|
members_with_roles=[Member(role="user", user_id="test_user_id")],
|
||||||
|
)
|
||||||
|
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
_team_key_generation_check(
|
_team_key_generation_check(
|
||||||
|
team_table=team_table,
|
||||||
user_api_key_dict=UserAPIKeyAuth(
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||||
api_key="sk-1234",
|
api_key="sk-1234",
|
||||||
|
@ -607,6 +623,7 @@ def test_key_generation_required_params_check(
|
||||||
StandardKeyGenerationConfig,
|
StandardKeyGenerationConfig,
|
||||||
PersonalUIKeyGenerationConfig,
|
PersonalUIKeyGenerationConfig,
|
||||||
)
|
)
|
||||||
|
from litellm.proxy._types import LiteLLM_TeamTableCachedObj
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
user_api_key_dict = UserAPIKeyAuth(
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
@ -614,7 +631,13 @@ def test_key_generation_required_params_check(
|
||||||
api_key="sk-1234",
|
api_key="sk-1234",
|
||||||
user_id="test_user_id",
|
user_id="test_user_id",
|
||||||
team_id="test_team_id",
|
team_id="test_team_id",
|
||||||
team_member=Member(role="admin", user_id="test_user_id"),
|
team_member=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
team_table = LiteLLM_TeamTableCachedObj(
|
||||||
|
team_id="test_team_id",
|
||||||
|
team_alias="test_team_alias",
|
||||||
|
members_with_roles=[Member(role="admin", user_id="test_user_id")],
|
||||||
)
|
)
|
||||||
|
|
||||||
if key_type == "team_key":
|
if key_type == "team_key":
|
||||||
|
@ -632,13 +655,13 @@ def test_key_generation_required_params_check(
|
||||||
|
|
||||||
if expected_result:
|
if expected_result:
|
||||||
if key_type == "team_key":
|
if key_type == "team_key":
|
||||||
assert _team_key_generation_check(user_api_key_dict, input_data)
|
assert _team_key_generation_check(team_table, user_api_key_dict, input_data)
|
||||||
elif key_type == "personal_key":
|
elif key_type == "personal_key":
|
||||||
assert _personal_key_generation_check(user_api_key_dict, input_data)
|
assert _personal_key_generation_check(user_api_key_dict, input_data)
|
||||||
else:
|
else:
|
||||||
if key_type == "team_key":
|
if key_type == "team_key":
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
_team_key_generation_check(user_api_key_dict, input_data)
|
_team_key_generation_check(team_table, user_api_key_dict, input_data)
|
||||||
elif key_type == "personal_key":
|
elif key_type == "personal_key":
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
_personal_key_generation_check(user_api_key_dict, input_data)
|
_personal_key_generation_check(user_api_key_dict, input_data)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue