mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat(proxy_server.py): support 'user_id_upsert' flag for jwt_auth
This commit is contained in:
parent
ed4315af38
commit
93cb65dfee
5 changed files with 74 additions and 29 deletions
|
@ -241,6 +241,7 @@ async def get_user_object(
|
|||
user_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
user_id_upsert: bool,
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
"""
|
||||
- Check if user id in proxy User Table
|
||||
|
@ -254,7 +255,7 @@ async def get_user_object(
|
|||
return None
|
||||
|
||||
# check if in cache
|
||||
cached_user_obj = user_api_key_cache.async_get_cache(key=user_id)
|
||||
cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||
if cached_user_obj is not None:
|
||||
if isinstance(cached_user_obj, dict):
|
||||
return LiteLLM_UserTable(**cached_user_obj)
|
||||
|
@ -262,16 +263,27 @@ async def get_user_object(
|
|||
return cached_user_obj
|
||||
# else, check db
|
||||
try:
|
||||
|
||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise Exception
|
||||
if user_id_upsert:
|
||||
response = await prisma_client.db.litellm_usertable.create(
|
||||
data={"user_id": user_id}
|
||||
)
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
return LiteLLM_UserTable(**response.dict())
|
||||
except Exception as e: # if end-user not in db
|
||||
raise Exception(
|
||||
_response = LiteLLM_UserTable(**dict(response))
|
||||
|
||||
# save the user object to cache
|
||||
await user_api_key_cache.async_set_cache(key=user_id, value=_response)
|
||||
|
||||
return _response
|
||||
except Exception as e: # if user not in db
|
||||
raise ValueError(
|
||||
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
|
||||
)
|
||||
|
||||
|
@ -292,7 +304,7 @@ async def get_team_object(
|
|||
)
|
||||
|
||||
# check if in cache
|
||||
cached_team_obj = user_api_key_cache.async_get_cache(key=team_id)
|
||||
cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
|
||||
if cached_team_obj is not None:
|
||||
if isinstance(cached_team_obj, dict):
|
||||
return LiteLLM_TeamTable(**cached_team_obj)
|
||||
|
@ -307,10 +319,11 @@ async def get_team_object(
|
|||
if response is None:
|
||||
raise Exception
|
||||
|
||||
_response = LiteLLM_TeamTable(**response.dict())
|
||||
# save the team object to cache
|
||||
await user_api_key_cache.async_set_cache(key=response.team_id, value=response)
|
||||
await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
|
||||
|
||||
return LiteLLM_TeamTable(**response.dict())
|
||||
return _response
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue