diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index b1af153e8..a835f87e2 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -228,7 +228,7 @@ class LiteLLM_JWTAuth(LiteLLMBase): "info_routes", ] team_jwt_scope: str = "litellm_team" - team_id_jwt_field: str = "client_id" + team_id_jwt_field: Optional[str] = None team_allowed_routes: List[ Literal["openai_routes", "info_routes", "management_routes"] ] = ["openai_routes", "info_routes"] diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 62e5eba01..b5eb0c4b3 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -26,7 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes def common_checks( request_body: dict, - team_object: LiteLLM_TeamTable, + team_object: Optional[LiteLLM_TeamTable], user_object: Optional[LiteLLM_UserTable], end_user_object: Optional[LiteLLM_EndUserTable], global_proxy_spend: Optional[float], @@ -45,13 +45,14 @@ def common_checks( 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget """ _model = request_body.get("model", None) - if team_object.blocked == True: + if team_object is not None and team_object.blocked == True: raise Exception( f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin." ) # 2. If user can call model if ( _model is not None + and team_object is not None and len(team_object.models) > 0 and _model not in team_object.models ): @@ -65,7 +66,8 @@ def common_checks( ) # 3. If team is in budget if ( - team_object.max_budget is not None + team_object is not None + and team_object.max_budget is not None and team_object.spend is not None and team_object.spend > team_object.max_budget ): @@ -305,6 +307,9 @@ async def get_team_object( if response is None: raise Exception + # save the team object to cache + await user_api_key_cache.async_set_cache(key=response.team_id, value=response) + return LiteLLM_TeamTable(**response.dict()) except Exception as e: raise Exception( diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 18c0d7b2c..6dbe214de 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -60,7 +60,9 @@ class JWTHandler: return True return False - def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str: + def get_end_user_id( + self, token: dict, default_value: Optional[str] + ) -> Optional[str]: try: if self.litellm_jwtauth.end_user_id_jwt_field is not None: user_id = token[self.litellm_jwtauth.end_user_id_jwt_field] @@ -70,6 +72,16 @@ class JWTHandler: user_id = default_value return user_id + def is_required_team_id(self) -> bool: + """ + Returns: + - True: if 'team_id_jwt_field' is set + - False: if not + """ + if self.litellm_jwtauth.team_id_jwt_field is None: + return False + return True + def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: team_id = token[self.litellm_jwtauth.team_id_jwt_field] @@ -165,7 +177,7 @@ class JWTHandler: decode_options = None if audience is None: decode_options = {"verify_aud": False} - + from jwt.algorithms import RSAAlgorithm header = jwt.get_unverified_header(token) @@ -207,12 +219,14 @@ class JWTHandler: raise Exception(f"Validation fails: {str(e)}") elif public_key is not None and isinstance(public_key, str): try: - cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend()) + cert = x509.load_pem_x509_certificate( + public_key.encode(), default_backend() + ) # Extract public key key = cert.public_key().public_bytes( serialization.Encoding.PEM, - serialization.PublicFormat.SubjectPublicKeyInfo + serialization.PublicFormat.SubjectPublicKeyInfo, ) # decode the token using the public key @@ -221,7 +235,7 @@ class JWTHandler: key, algorithms=algorithms, audience=audience, - options=decode_options + options=decode_options, ) return payload diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e9324dd96..e66f2d6db 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -440,29 +440,32 @@ async def user_api_key_auth( # get team id team_id = jwt_handler.get_team_id(token=valid_token, default_value=None) - if team_id is None: + if team_id is None and jwt_handler.is_required_team_id() == True: raise Exception( f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'" ) - # check allowed team routes - is_allowed = allowed_routes_check( - user_role="team", - user_route=route, - litellm_proxy_roles=jwt_handler.litellm_jwtauth, - ) - if is_allowed == False: - allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore - actual_routes = get_actual_routes(allowed_routes=allowed_routes) - raise Exception( - f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}" - ) - # check if team in db - team_object = await get_team_object( - team_id=team_id, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - ) + team_object: Optional[LiteLLM_TeamTable] = None + if team_id is not None: + # check allowed team routes + is_allowed = allowed_routes_check( + user_role="team", + user_route=route, + litellm_proxy_roles=jwt_handler.litellm_jwtauth, + ) + if is_allowed == False: + allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore + actual_routes = get_actual_routes(allowed_routes=allowed_routes) + raise Exception( + f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}" + ) + + # check if team in db + team_object = await get_team_object( + team_id=team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + ) # [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable` org_id = jwt_handler.get_org_id(token=valid_token, default_value=None) @@ -547,18 +550,18 @@ async def user_api_key_auth( global_proxy_spend=global_proxy_spend, route=route, ) - # save team object in cache - await user_api_key_cache.async_set_cache( - key=team_object.team_id, value=team_object - ) # return UserAPIKeyAuth object return UserAPIKeyAuth( api_key=None, - team_id=team_object.team_id, - team_tpm_limit=team_object.tpm_limit, - team_rpm_limit=team_object.rpm_limit, - team_models=team_object.models, + team_id=team_object.team_id if team_object is not None else None, + team_tpm_limit=( + team_object.tpm_limit if team_object is not None else None + ), + team_rpm_limit=( + team_object.rpm_limit if team_object is not None else None + ), + team_models=team_object.models if team_object is not None else [], user_role="app_owner", user_id=user_id, org_id=org_id, diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index b3af9913f..6a0e5c427 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -24,6 +24,7 @@ public_key = { "alg": "RS256", } + def test_load_config_with_custom_role_names(): config = { "general_settings": { @@ -77,7 +78,8 @@ async def test_token_single_public_key(): == "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ" ) -@pytest.mark.parametrize('audience', [None, "litellm-proxy"]) + +@pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.asyncio async def test_valid_invalid_token(audience): """ @@ -90,7 +92,7 @@ async def test_valid_invalid_token(audience): from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.backends import default_backend - os.environ.pop('JWT_AUDIENCE', None) + os.environ.pop("JWT_AUDIENCE", None) if audience: os.environ["JWT_AUDIENCE"] = audience @@ -138,7 +140,7 @@ async def test_valid_invalid_token(audience): "sub": "user123", "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm-proxy-admin", - "aud": audience + "aud": audience, } # Generate the JWT token @@ -166,7 +168,7 @@ async def test_valid_invalid_token(audience): "sub": "user123", "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm-NO-SCOPE", - "aud": audience + "aud": audience, } # Generate the JWT token @@ -183,6 +185,7 @@ async def test_valid_invalid_token(audience): except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") + @pytest.fixture def prisma_client(): import litellm @@ -205,7 +208,7 @@ def prisma_client(): return prisma_client -@pytest.mark.parametrize('audience', [None, "litellm-proxy"]) +@pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.asyncio async def test_team_token_output(prisma_client, audience): import jwt, json @@ -222,7 +225,7 @@ async def test_team_token_output(prisma_client, audience): setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) await litellm.proxy.proxy_server.prisma_client.connect() - os.environ.pop('JWT_AUDIENCE', None) + os.environ.pop("JWT_AUDIENCE", None) if audience: os.environ["JWT_AUDIENCE"] = audience @@ -261,7 +264,7 @@ async def test_team_token_output(prisma_client, audience): jwt_handler.user_api_key_cache = cache - jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id") # VALID TOKEN ## GENERATE A TOKEN @@ -274,7 +277,7 @@ async def test_team_token_output(prisma_client, audience): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_team", "client_id": team_id, - "aud": audience + "aud": audience, } # Generate the JWT token @@ -289,7 +292,7 @@ async def test_team_token_output(prisma_client, audience): "sub": "user123", "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_proxy_admin", - "aud": audience + "aud": audience, } admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") @@ -315,7 +318,13 @@ async def test_team_token_output(prisma_client, audience): ## 1. INITIAL TEAM CALL - should fail # use generated key to auth in - setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True}) + setattr( + litellm.proxy.proxy_server, + "general_settings", + { + "enable_jwt_auth": True, + }, + ) setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler) try: result = await user_api_key_auth(request=request, api_key=bearer_token) @@ -358,9 +367,10 @@ async def test_team_token_output(prisma_client, audience): assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"] -@pytest.mark.parametrize('audience', [None, "litellm-proxy"]) +@pytest.mark.parametrize("audience", [None, "litellm-proxy"]) +@pytest.mark.parametrize("team_id_set", [True, False]) @pytest.mark.asyncio -async def test_user_token_output(prisma_client, audience): +async def test_user_token_output(prisma_client, audience, team_id_set): """ - If user required, check if it exists - fail initial request (when user doesn't exist) @@ -381,7 +391,7 @@ async def test_user_token_output(prisma_client, audience): setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) await litellm.proxy.proxy_server.prisma_client.connect() - os.environ.pop('JWT_AUDIENCE', None) + os.environ.pop("JWT_AUDIENCE", None) if audience: os.environ["JWT_AUDIENCE"] = audience @@ -423,6 +433,8 @@ async def test_user_token_output(prisma_client, audience): jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth() jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub" + if team_id_set: + jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id" # VALID TOKEN ## GENERATE A TOKEN @@ -436,7 +448,7 @@ async def test_user_token_output(prisma_client, audience): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_team", "client_id": team_id, - "aud": audience + "aud": audience, } # Generate the JWT token @@ -451,7 +463,7 @@ async def test_user_token_output(prisma_client, audience): "sub": user_id, "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_proxy_admin", - "aud": audience + "aud": audience, } admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") @@ -543,7 +555,8 @@ async def test_user_token_output(prisma_client, audience): ## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking) - assert team_result.team_tpm_limit == 100 - assert team_result.team_rpm_limit == 99 - assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"] + if team_id_set: + assert team_result.team_tpm_limit == 100 + assert team_result.team_rpm_limit == 99 + assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"] assert team_result.user_id == user_id