From ed4315af38d349d201ce47b1ca4481e25637393c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 15 May 2024 21:33:35 -0700 Subject: [PATCH] feat(handle_jwt.py): add support for 'team_id_default allows admin to set a default team id for spend-tracking + permissions --- litellm/proxy/_super_secret_config.yaml | 3 +++ litellm/proxy/_types.py | 8 +++++++- litellm/proxy/auth/handle_jwt.py | 12 ++++++------ litellm/tests/test_jwt.py | 23 ++++++++++++++++++++--- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index 98191aa043..eae0cbf4ad 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -36,6 +36,9 @@ litellm_settings: general_settings: enable_jwt_auth: True + litellm_jwtauth: + team_id_default: "1234" + upsert_users: True disable_reset_budget: True proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds) routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle" diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index a835f87e27..6af8145d1d 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -227,13 +227,19 @@ class LiteLLM_JWTAuth(LiteLLMBase): "global_spend_tracking_routes", "info_routes", ] - team_jwt_scope: str = "litellm_team" team_id_jwt_field: Optional[str] = None team_allowed_routes: List[ Literal["openai_routes", "info_routes", "management_routes"] ] = ["openai_routes", "info_routes"] + team_id_default: Optional[str] = Field( + default=None, + description="If no team_id given, default permissions/spend-tracking to this team.s", + ) org_id_jwt_field: Optional[str] = None user_id_jwt_field: Optional[str] = None + user_id_upsert: bool = Field( + default=False, description="If user doesn't exist, upsert them into the db." + ) end_user_id_jwt_field: Optional[str] = None public_key_ttl: float = 600 diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 6dbe214dee..de357030d2 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -55,11 +55,6 @@ class JWTHandler: return True return False - def is_team(self, scopes: list) -> bool: - if self.litellm_jwtauth.team_jwt_scope in scopes: - return True - return False - def get_end_user_id( self, token: dict, default_value: Optional[str] ) -> Optional[str]: @@ -84,7 +79,12 @@ class JWTHandler: def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: - team_id = token[self.litellm_jwtauth.team_id_jwt_field] + if self.litellm_jwtauth.team_id_jwt_field is not None: + team_id = token[self.litellm_jwtauth.team_id_jwt_field] + elif self.litellm_jwtauth.team_id_default is not None: + team_id = self.litellm_jwtauth.team_id_default + else: + team_id = None except KeyError: team_id = default_value return team_id diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 6a0e5c427f..dd89f18e9d 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -368,9 +368,14 @@ async def test_team_token_output(prisma_client, audience): @pytest.mark.parametrize("audience", [None, "litellm-proxy"]) -@pytest.mark.parametrize("team_id_set", [True, False]) +@pytest.mark.parametrize( + "team_id_set, default_team_id", [(True, None), (False, "1234")] +) +@pytest.mark.parametrize("user_id_upsert", [True, False]) @pytest.mark.asyncio -async def test_user_token_output(prisma_client, audience, team_id_set): +async def test_user_token_output( + prisma_client, audience, team_id_set, default_team_id, user_id_upsert +): """ - If user required, check if it exists - fail initial request (when user doesn't exist) @@ -433,6 +438,8 @@ async def test_user_token_output(prisma_client, audience, team_id_set): jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth() jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub" + jwt_handler.litellm_jwtauth.team_id_default = default_team_id + if team_id_set: jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id" @@ -515,6 +522,16 @@ async def test_user_token_output(prisma_client, audience, team_id_set): ), user_api_key_dict=result, ) + if default_team_id is not None: + await new_team( + data=NewTeamRequest( + team_id=default_team_id, + tpm_limit=100, + rpm_limit=99, + models=["gpt-3.5-turbo", "gpt-4"], + ), + user_api_key_dict=result, + ) except Exception as e: pytest.fail(f"This should not fail - {str(e)}") @@ -555,7 +572,7 @@ async def test_user_token_output(prisma_client, audience, team_id_set): ## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking) - if team_id_set: + if team_id_set or default_team_id is not None: assert team_result.team_tpm_limit == 100 assert team_result.team_rpm_limit == 99 assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]