From 93cb65dfee86c25aa7a3a88236880c9cb16c01c8 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 15 May 2024 22:19:59 -0700 Subject: [PATCH] feat(proxy_server.py): support 'user_id_upsert' flag for jwt_auth --- litellm/proxy/_super_secret_config.yaml | 3 +- litellm/proxy/auth/auth_checks.py | 29 +++++++++---- litellm/proxy/auth/handle_jwt.py | 8 ++++ litellm/proxy/proxy_server.py | 6 +-- litellm/tests/test_jwt.py | 57 ++++++++++++++++++------- 5 files changed, 74 insertions(+), 29 deletions(-) diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index eae0cbf4ad..3fa3d1e9a6 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -38,7 +38,8 @@ general_settings: enable_jwt_auth: True litellm_jwtauth: team_id_default: "1234" - upsert_users: True + user_id_jwt_field: + user_id_upsert: True disable_reset_budget: True proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds) routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle" diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index b5eb0c4b39..08da25556d 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -241,6 +241,7 @@ async def get_user_object( user_id: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, + user_id_upsert: bool, ) -> Optional[LiteLLM_UserTable]: """ - Check if user id in proxy User Table @@ -254,7 +255,7 @@ async def get_user_object( return None # check if in cache - cached_user_obj = user_api_key_cache.async_get_cache(key=user_id) + cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id) if cached_user_obj is not None: if isinstance(cached_user_obj, dict): return LiteLLM_UserTable(**cached_user_obj) @@ -262,16 +263,27 @@ async def get_user_object( return cached_user_obj # else, check db try: + response = await prisma_client.db.litellm_usertable.find_unique( where={"user_id": user_id} ) if response is None: - raise Exception + if user_id_upsert: + response = await prisma_client.db.litellm_usertable.create( + data={"user_id": user_id} + ) + else: + raise Exception - return LiteLLM_UserTable(**response.dict()) - except Exception as e: # if end-user not in db - raise Exception( + _response = LiteLLM_UserTable(**dict(response)) + + # save the user object to cache + await user_api_key_cache.async_set_cache(key=user_id, value=_response) + + return _response + except Exception as e: # if user not in db + raise ValueError( f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call." ) @@ -292,7 +304,7 @@ async def get_team_object( ) # check if in cache - cached_team_obj = user_api_key_cache.async_get_cache(key=team_id) + cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id) if cached_team_obj is not None: if isinstance(cached_team_obj, dict): return LiteLLM_TeamTable(**cached_team_obj) @@ -307,10 +319,11 @@ async def get_team_object( if response is None: raise Exception + _response = LiteLLM_TeamTable(**response.dict()) # save the team object to cache - await user_api_key_cache.async_set_cache(key=response.team_id, value=response) + await user_api_key_cache.async_set_cache(key=response.team_id, value=_response) - return LiteLLM_TeamTable(**response.dict()) + return _response except Exception as e: raise Exception( f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call." diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index de357030d2..0a186d7dde 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -89,6 +89,14 @@ class JWTHandler: team_id = default_value return team_id + def is_upsert_user_id(self) -> bool: + """ + Returns: + - True: if 'user_id_upsert' is set + - False: if not + """ + return self.litellm_jwtauth.user_id_upsert + def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: if self.litellm_jwtauth.user_id_jwt_field is not None: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e66f2d6dbc..aa63d92e51 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -484,11 +484,9 @@ async def user_api_key_auth( user_id=user_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, + user_id_upsert=jwt_handler.is_upsert_user_id(), ) - # save the user object to cache - await user_api_key_cache.async_set_cache( - key=user_id, value=user_object - ) + # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable` end_user_object = None end_user_id = jwt_handler.get_end_user_id( diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index dd89f18e9d..45f4616290 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -1,7 +1,7 @@ #### What this tests #### # Unit tests for JWT-Auth -import sys, os, asyncio, time, random +import sys, os, asyncio, time, random, uuid import traceback from dotenv import load_dotenv @@ -369,13 +369,20 @@ async def test_team_token_output(prisma_client, audience): @pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.parametrize( - "team_id_set, default_team_id", [(True, None), (False, "1234")] + "team_id_set, default_team_id", + [(True, False), (False, True)], ) @pytest.mark.parametrize("user_id_upsert", [True, False]) @pytest.mark.asyncio async def test_user_token_output( prisma_client, audience, team_id_set, default_team_id, user_id_upsert ): + import uuid + + args = locals() + print(f"received args - {args}") + if default_team_id: + default_team_id = "team_id_12344_{}".format(uuid.uuid4()) """ - If user required, check if it exists - fail initial request (when user doesn't exist) @@ -388,7 +395,12 @@ async def test_user_token_output( from cryptography.hazmat.backends import default_backend from fastapi import Request from starlette.datastructures import URL - from litellm.proxy.proxy_server import user_api_key_auth, new_team, new_user + from litellm.proxy.proxy_server import ( + user_api_key_auth, + new_team, + new_user, + user_info, + ) from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest import litellm import uuid @@ -439,6 +451,7 @@ async def test_user_token_output( jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub" jwt_handler.litellm_jwtauth.team_id_default = default_team_id + jwt_handler.litellm_jwtauth.user_id_upsert = user_id_upsert if team_id_set: jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id" @@ -522,7 +535,7 @@ async def test_user_token_output( ), user_api_key_dict=result, ) - if default_team_id is not None: + if default_team_id: await new_team( data=NewTeamRequest( team_id=default_team_id, @@ -542,23 +555,35 @@ async def test_user_token_output( team_result: UserAPIKeyAuth = await user_api_key_auth( request=request, api_key=bearer_token ) - pytest.fail(f"User doesn't exist. this should fail") + if user_id_upsert == False: + pytest.fail(f"User doesn't exist. this should fail") except Exception as e: pass ## 4. Create user - try: - bearer_token = "Bearer " + admin_token + if user_id_upsert: + ## check if user already exists + try: + bearer_token = "Bearer " + admin_token - request._url = URL(url="/team/new") - result = await user_api_key_auth(request=request, api_key=bearer_token) - await new_user( - data=NewUserRequest( - user_id=user_id, - ), - ) - except Exception as e: - pytest.fail(f"This should not fail - {str(e)}") + request._url = URL(url="/team/new") + result = await user_api_key_auth(request=request, api_key=bearer_token) + await user_info(user_id=user_id) + except Exception as e: + pytest.fail(f"This should not fail - {str(e)}") + else: + try: + bearer_token = "Bearer " + admin_token + + request._url = URL(url="/team/new") + result = await user_api_key_auth(request=request, api_key=bearer_token) + await new_user( + data=NewUserRequest( + user_id=user_id, + ), + ) + except Exception as e: + pytest.fail(f"This should not fail - {str(e)}") ## 5. 3rd call w/ same team, same user -> call should succeed bearer_token = "Bearer " + token