feat(proxy_server.py): support 'user_id_upsert' flag for jwt_auth

This commit is contained in:
Krrish Dholakia 2024-05-15 22:19:59 -07:00
parent ed4315af38
commit 93cb65dfee
5 changed files with 74 additions and 29 deletions

View file

@ -38,7 +38,8 @@ general_settings:
enable_jwt_auth: True enable_jwt_auth: True
litellm_jwtauth: litellm_jwtauth:
team_id_default: "1234" team_id_default: "1234"
upsert_users: True user_id_jwt_field:
user_id_upsert: True
disable_reset_budget: True disable_reset_budget: True
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds) proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle" routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"

View file

@ -241,6 +241,7 @@ async def get_user_object(
user_id: str, user_id: str,
prisma_client: Optional[PrismaClient], prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache, user_api_key_cache: DualCache,
user_id_upsert: bool,
) -> Optional[LiteLLM_UserTable]: ) -> Optional[LiteLLM_UserTable]:
""" """
- Check if user id in proxy User Table - Check if user id in proxy User Table
@ -254,7 +255,7 @@ async def get_user_object(
return None return None
# check if in cache # check if in cache
cached_user_obj = user_api_key_cache.async_get_cache(key=user_id) cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
if cached_user_obj is not None: if cached_user_obj is not None:
if isinstance(cached_user_obj, dict): if isinstance(cached_user_obj, dict):
return LiteLLM_UserTable(**cached_user_obj) return LiteLLM_UserTable(**cached_user_obj)
@ -262,16 +263,27 @@ async def get_user_object(
return cached_user_obj return cached_user_obj
# else, check db # else, check db
try: try:
response = await prisma_client.db.litellm_usertable.find_unique( response = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id} where={"user_id": user_id}
) )
if response is None: if response is None:
raise Exception if user_id_upsert:
response = await prisma_client.db.litellm_usertable.create(
data={"user_id": user_id}
)
else:
raise Exception
return LiteLLM_UserTable(**response.dict()) _response = LiteLLM_UserTable(**dict(response))
except Exception as e: # if end-user not in db
raise Exception( # save the user object to cache
await user_api_key_cache.async_set_cache(key=user_id, value=_response)
return _response
except Exception as e: # if user not in db
raise ValueError(
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call." f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
) )
@ -292,7 +304,7 @@ async def get_team_object(
) )
# check if in cache # check if in cache
cached_team_obj = user_api_key_cache.async_get_cache(key=team_id) cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
if cached_team_obj is not None: if cached_team_obj is not None:
if isinstance(cached_team_obj, dict): if isinstance(cached_team_obj, dict):
return LiteLLM_TeamTable(**cached_team_obj) return LiteLLM_TeamTable(**cached_team_obj)
@ -307,10 +319,11 @@ async def get_team_object(
if response is None: if response is None:
raise Exception raise Exception
_response = LiteLLM_TeamTable(**response.dict())
# save the team object to cache # save the team object to cache
await user_api_key_cache.async_set_cache(key=response.team_id, value=response) await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
return LiteLLM_TeamTable(**response.dict()) return _response
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call." f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."

View file

@ -89,6 +89,14 @@ class JWTHandler:
team_id = default_value team_id = default_value
return team_id return team_id
def is_upsert_user_id(self) -> bool:
"""
Returns:
- True: if 'user_id_upsert' is set
- False: if not
"""
return self.litellm_jwtauth.user_id_upsert
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try: try:
if self.litellm_jwtauth.user_id_jwt_field is not None: if self.litellm_jwtauth.user_id_jwt_field is not None:

View file

@ -484,11 +484,9 @@ async def user_api_key_auth(
user_id=user_id, user_id=user_id,
prisma_client=prisma_client, prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache, user_api_key_cache=user_api_key_cache,
user_id_upsert=jwt_handler.is_upsert_user_id(),
) )
# save the user object to cache
await user_api_key_cache.async_set_cache(
key=user_id, value=user_object
)
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable` # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
end_user_object = None end_user_object = None
end_user_id = jwt_handler.get_end_user_id( end_user_id = jwt_handler.get_end_user_id(

View file

@ -1,7 +1,7 @@
#### What this tests #### #### What this tests ####
# Unit tests for JWT-Auth # Unit tests for JWT-Auth
import sys, os, asyncio, time, random import sys, os, asyncio, time, random, uuid
import traceback import traceback
from dotenv import load_dotenv from dotenv import load_dotenv
@ -369,13 +369,20 @@ async def test_team_token_output(prisma_client, audience):
@pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"team_id_set, default_team_id", [(True, None), (False, "1234")] "team_id_set, default_team_id",
[(True, False), (False, True)],
) )
@pytest.mark.parametrize("user_id_upsert", [True, False]) @pytest.mark.parametrize("user_id_upsert", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_token_output( async def test_user_token_output(
prisma_client, audience, team_id_set, default_team_id, user_id_upsert prisma_client, audience, team_id_set, default_team_id, user_id_upsert
): ):
import uuid
args = locals()
print(f"received args - {args}")
if default_team_id:
default_team_id = "team_id_12344_{}".format(uuid.uuid4())
""" """
- If user required, check if it exists - If user required, check if it exists
- fail initial request (when user doesn't exist) - fail initial request (when user doesn't exist)
@ -388,7 +395,12 @@ async def test_user_token_output(
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from fastapi import Request from fastapi import Request
from starlette.datastructures import URL from starlette.datastructures import URL
from litellm.proxy.proxy_server import user_api_key_auth, new_team, new_user from litellm.proxy.proxy_server import (
user_api_key_auth,
new_team,
new_user,
user_info,
)
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest
import litellm import litellm
import uuid import uuid
@ -439,6 +451,7 @@ async def test_user_token_output(
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub" jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
jwt_handler.litellm_jwtauth.team_id_default = default_team_id jwt_handler.litellm_jwtauth.team_id_default = default_team_id
jwt_handler.litellm_jwtauth.user_id_upsert = user_id_upsert
if team_id_set: if team_id_set:
jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id" jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"
@ -522,7 +535,7 @@ async def test_user_token_output(
), ),
user_api_key_dict=result, user_api_key_dict=result,
) )
if default_team_id is not None: if default_team_id:
await new_team( await new_team(
data=NewTeamRequest( data=NewTeamRequest(
team_id=default_team_id, team_id=default_team_id,
@ -542,23 +555,35 @@ async def test_user_token_output(
team_result: UserAPIKeyAuth = await user_api_key_auth( team_result: UserAPIKeyAuth = await user_api_key_auth(
request=request, api_key=bearer_token request=request, api_key=bearer_token
) )
pytest.fail(f"User doesn't exist. this should fail") if user_id_upsert == False:
pytest.fail(f"User doesn't exist. this should fail")
except Exception as e: except Exception as e:
pass pass
## 4. Create user ## 4. Create user
try: if user_id_upsert:
bearer_token = "Bearer " + admin_token ## check if user already exists
try:
bearer_token = "Bearer " + admin_token
request._url = URL(url="/team/new") request._url = URL(url="/team/new")
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
await new_user( await user_info(user_id=user_id)
data=NewUserRequest( except Exception as e:
user_id=user_id, pytest.fail(f"This should not fail - {str(e)}")
), else:
) try:
except Exception as e: bearer_token = "Bearer " + admin_token
pytest.fail(f"This should not fail - {str(e)}")
request._url = URL(url="/team/new")
result = await user_api_key_auth(request=request, api_key=bearer_token)
await new_user(
data=NewUserRequest(
user_id=user_id,
),
)
except Exception as e:
pytest.fail(f"This should not fail - {str(e)}")
## 5. 3rd call w/ same team, same user -> call should succeed ## 5. 3rd call w/ same team, same user -> call should succeed
bearer_token = "Bearer " + token bearer_token = "Bearer " + token