diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 7ff209094..fbb714211 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -12,3 +12,10 @@ model_list: vertex_ai_project: "adroit-crow-413218" vertex_ai_location: "us-east5" +litellm_settings: + key_generation_settings: + team_key_generation: + allowed_team_member_roles: ["admin"] + required_params: ["tags"] # require team admins to set tags for cost-tracking when generating a team key + personal_key_generation: # maps to 'Default Team' on UI + allowed_user_roles: ["proxy_admin"] \ No newline at end of file diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 511e5a940..f7a383183 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -29,6 +29,7 @@ from litellm.proxy.auth.auth_checks import ( _cache_key_object, _delete_cache_key_object, get_key_object, + get_team_object, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks @@ -46,7 +47,19 @@ def _is_team_key(data: GenerateKeyRequest): return data.team_id is not None +def _get_user_in_team( + team_table: LiteLLM_TeamTableCachedObj, user_id: Optional[str] +) -> Optional[Member]: + if user_id is None: + return None + for member in team_table.members_with_roles: + if member.user_id is not None and member.user_id == user_id: + return member + return None + + def _team_key_generation_team_member_check( + team_table: LiteLLM_TeamTableCachedObj, user_api_key_dict: UserAPIKeyAuth, team_key_generation: Optional[TeamUIKeyGenerationConfig], ): @@ -56,17 +69,19 @@ def _team_key_generation_team_member_check( ): return True - if user_api_key_dict.team_member is None: + user_in_team = _get_user_in_team( + team_table=team_table, user_id=user_api_key_dict.user_id + ) + if user_in_team is None: raise HTTPException( status_code=400, - detail=f"User not assigned to team. Got team_member={user_api_key_dict.team_member}", + detail=f"User={user_api_key_dict.user_id} not assigned to team={team_table.team_id}", ) - team_member_role = user_api_key_dict.team_member.role - if team_member_role not in team_key_generation["allowed_team_member_roles"]: + if user_in_team.role not in team_key_generation["allowed_team_member_roles"]: raise HTTPException( status_code=400, - detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}", # type: ignore + detail=f"Team member role {user_in_team.role} not in allowed_team_member_roles={team_key_generation['allowed_team_member_roles']}", ) return True @@ -88,7 +103,9 @@ def _key_generation_required_param_check( def _team_key_generation_check( - user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest + team_table: LiteLLM_TeamTableCachedObj, + user_api_key_dict: UserAPIKeyAuth, + data: GenerateKeyRequest, ): if ( litellm.key_generation_settings is None @@ -99,7 +116,8 @@ def _team_key_generation_check( _team_key_generation = litellm.key_generation_settings["team_key_generation"] # type: ignore _team_key_generation_team_member_check( - user_api_key_dict, + team_table=team_table, + user_api_key_dict=user_api_key_dict, team_key_generation=_team_key_generation, ) _key_generation_required_param_check( @@ -155,7 +173,9 @@ def _personal_key_generation_check( def key_generation_check( - user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest + team_table: Optional[LiteLLM_TeamTableCachedObj], + user_api_key_dict: UserAPIKeyAuth, + data: GenerateKeyRequest, ) -> bool: """ Check if admin has restricted key creation to certain roles for teams or individuals @@ -170,8 +190,15 @@ def key_generation_check( is_team_key = _is_team_key(data=data) if is_team_key: + if team_table is None: + raise HTTPException( + status_code=400, + detail=f"Unable to find team object in database. Team ID: {data.team_id}", + ) return _team_key_generation_check( - user_api_key_dict=user_api_key_dict, data=data + team_table=team_table, + user_api_key_dict=user_api_key_dict, + data=data, ) else: return _personal_key_generation_check( @@ -254,6 +281,7 @@ async def generate_key_fn( # noqa: PLR0915 litellm_proxy_admin_name, prisma_client, proxy_logging_obj, + user_api_key_cache, user_custom_key_generate, ) @@ -271,7 +299,20 @@ async def generate_key_fn( # noqa: PLR0915 status_code=status.HTTP_403_FORBIDDEN, detail=message ) elif litellm.key_generation_settings is not None: - key_generation_check(user_api_key_dict=user_api_key_dict, data=data) + if data.team_id is None: + team_table: Optional[LiteLLM_TeamTableCachedObj] = None + else: + team_table = await get_team_object( + team_id=data.team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=user_api_key_dict.parent_otel_span, + ) + key_generation_check( + team_table=team_table, + user_api_key_dict=user_api_key_dict, + data=data, + ) # check if user set default key/generate params on config.yaml if litellm.default_key_generate_params is not None: for elem in data: diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index 0b392a268..d0b1ab294 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -556,13 +556,22 @@ def test_team_key_generation_team_member_check(): _team_key_generation_check, ) from fastapi import HTTPException + from litellm.proxy._types import LiteLLM_TeamTableCachedObj litellm.key_generation_settings = { "team_key_generation": {"allowed_team_member_roles": ["admin"]} } + team_table = LiteLLM_TeamTableCachedObj( + team_id="test_team_id", + team_alias="test_team_alias", + members_with_roles=[Member(role="admin", user_id="test_user_id")], + ) + assert _team_key_generation_check( + team_table=team_table, user_api_key_dict=UserAPIKeyAuth( + user_id="test_user_id", user_role=LitellmUserRoles.INTERNAL_USER, api_key="sk-1234", team_member=Member(role="admin", user_id="test_user_id"), @@ -570,8 +579,15 @@ def test_team_key_generation_team_member_check(): data=GenerateKeyRequest(), ) + team_table = LiteLLM_TeamTableCachedObj( + team_id="test_team_id", + team_alias="test_team_alias", + members_with_roles=[Member(role="user", user_id="test_user_id")], + ) + with pytest.raises(HTTPException): _team_key_generation_check( + team_table=team_table, user_api_key_dict=UserAPIKeyAuth( user_role=LitellmUserRoles.INTERNAL_USER, api_key="sk-1234", @@ -607,6 +623,7 @@ def test_key_generation_required_params_check( StandardKeyGenerationConfig, PersonalUIKeyGenerationConfig, ) + from litellm.proxy._types import LiteLLM_TeamTableCachedObj from fastapi import HTTPException user_api_key_dict = UserAPIKeyAuth( @@ -614,7 +631,13 @@ def test_key_generation_required_params_check( api_key="sk-1234", user_id="test_user_id", team_id="test_team_id", - team_member=Member(role="admin", user_id="test_user_id"), + team_member=None, + ) + + team_table = LiteLLM_TeamTableCachedObj( + team_id="test_team_id", + team_alias="test_team_alias", + members_with_roles=[Member(role="admin", user_id="test_user_id")], ) if key_type == "team_key": @@ -632,13 +655,13 @@ def test_key_generation_required_params_check( if expected_result: if key_type == "team_key": - assert _team_key_generation_check(user_api_key_dict, input_data) + assert _team_key_generation_check(team_table, user_api_key_dict, input_data) elif key_type == "personal_key": assert _personal_key_generation_check(user_api_key_dict, input_data) else: if key_type == "team_key": with pytest.raises(HTTPException): - _team_key_generation_check(user_api_key_dict, input_data) + _team_key_generation_check(team_table, user_api_key_dict, input_data) elif key_type == "personal_key": with pytest.raises(HTTPException): _personal_key_generation_check(user_api_key_dict, input_data)