mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
feat(proxy_server.py): support 'user_id_upsert' flag for jwt_auth
This commit is contained in:
parent
ed4315af38
commit
93cb65dfee
5 changed files with 74 additions and 29 deletions
|
@ -38,7 +38,8 @@ general_settings:
|
||||||
enable_jwt_auth: True
|
enable_jwt_auth: True
|
||||||
litellm_jwtauth:
|
litellm_jwtauth:
|
||||||
team_id_default: "1234"
|
team_id_default: "1234"
|
||||||
upsert_users: True
|
user_id_jwt_field:
|
||||||
|
user_id_upsert: True
|
||||||
disable_reset_budget: True
|
disable_reset_budget: True
|
||||||
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
||||||
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
|
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
|
||||||
|
|
|
@ -241,6 +241,7 @@ async def get_user_object(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
user_api_key_cache: DualCache,
|
user_api_key_cache: DualCache,
|
||||||
|
user_id_upsert: bool,
|
||||||
) -> Optional[LiteLLM_UserTable]:
|
) -> Optional[LiteLLM_UserTable]:
|
||||||
"""
|
"""
|
||||||
- Check if user id in proxy User Table
|
- Check if user id in proxy User Table
|
||||||
|
@ -254,7 +255,7 @@ async def get_user_object(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# check if in cache
|
# check if in cache
|
||||||
cached_user_obj = user_api_key_cache.async_get_cache(key=user_id)
|
cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||||
if cached_user_obj is not None:
|
if cached_user_obj is not None:
|
||||||
if isinstance(cached_user_obj, dict):
|
if isinstance(cached_user_obj, dict):
|
||||||
return LiteLLM_UserTable(**cached_user_obj)
|
return LiteLLM_UserTable(**cached_user_obj)
|
||||||
|
@ -262,16 +263,27 @@ async def get_user_object(
|
||||||
return cached_user_obj
|
return cached_user_obj
|
||||||
# else, check db
|
# else, check db
|
||||||
try:
|
try:
|
||||||
|
|
||||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||||
where={"user_id": user_id}
|
where={"user_id": user_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
if response is None:
|
if response is None:
|
||||||
raise Exception
|
if user_id_upsert:
|
||||||
|
response = await prisma_client.db.litellm_usertable.create(
|
||||||
|
data={"user_id": user_id}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
return LiteLLM_UserTable(**response.dict())
|
_response = LiteLLM_UserTable(**dict(response))
|
||||||
except Exception as e: # if end-user not in db
|
|
||||||
raise Exception(
|
# save the user object to cache
|
||||||
|
await user_api_key_cache.async_set_cache(key=user_id, value=_response)
|
||||||
|
|
||||||
|
return _response
|
||||||
|
except Exception as e: # if user not in db
|
||||||
|
raise ValueError(
|
||||||
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
|
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -292,7 +304,7 @@ async def get_team_object(
|
||||||
)
|
)
|
||||||
|
|
||||||
# check if in cache
|
# check if in cache
|
||||||
cached_team_obj = user_api_key_cache.async_get_cache(key=team_id)
|
cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
|
||||||
if cached_team_obj is not None:
|
if cached_team_obj is not None:
|
||||||
if isinstance(cached_team_obj, dict):
|
if isinstance(cached_team_obj, dict):
|
||||||
return LiteLLM_TeamTable(**cached_team_obj)
|
return LiteLLM_TeamTable(**cached_team_obj)
|
||||||
|
@ -307,10 +319,11 @@ async def get_team_object(
|
||||||
if response is None:
|
if response is None:
|
||||||
raise Exception
|
raise Exception
|
||||||
|
|
||||||
|
_response = LiteLLM_TeamTable(**response.dict())
|
||||||
# save the team object to cache
|
# save the team object to cache
|
||||||
await user_api_key_cache.async_set_cache(key=response.team_id, value=response)
|
await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
|
||||||
|
|
||||||
return LiteLLM_TeamTable(**response.dict())
|
return _response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
||||||
|
|
|
@ -89,6 +89,14 @@ class JWTHandler:
|
||||||
team_id = default_value
|
team_id = default_value
|
||||||
return team_id
|
return team_id
|
||||||
|
|
||||||
|
def is_upsert_user_id(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
- True: if 'user_id_upsert' is set
|
||||||
|
- False: if not
|
||||||
|
"""
|
||||||
|
return self.litellm_jwtauth.user_id_upsert
|
||||||
|
|
||||||
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
if self.litellm_jwtauth.user_id_jwt_field is not None:
|
if self.litellm_jwtauth.user_id_jwt_field is not None:
|
||||||
|
|
|
@ -484,11 +484,9 @@ async def user_api_key_auth(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
user_id_upsert=jwt_handler.is_upsert_user_id(),
|
||||||
)
|
)
|
||||||
# save the user object to cache
|
|
||||||
await user_api_key_cache.async_set_cache(
|
|
||||||
key=user_id, value=user_object
|
|
||||||
)
|
|
||||||
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
||||||
end_user_object = None
|
end_user_object = None
|
||||||
end_user_id = jwt_handler.get_end_user_id(
|
end_user_id = jwt_handler.get_end_user_id(
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#### What this tests ####
|
#### What this tests ####
|
||||||
# Unit tests for JWT-Auth
|
# Unit tests for JWT-Auth
|
||||||
|
|
||||||
import sys, os, asyncio, time, random
|
import sys, os, asyncio, time, random, uuid
|
||||||
import traceback
|
import traceback
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
@ -369,13 +369,20 @@ async def test_team_token_output(prisma_client, audience):
|
||||||
|
|
||||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"team_id_set, default_team_id", [(True, None), (False, "1234")]
|
"team_id_set, default_team_id",
|
||||||
|
[(True, False), (False, True)],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("user_id_upsert", [True, False])
|
@pytest.mark.parametrize("user_id_upsert", [True, False])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_token_output(
|
async def test_user_token_output(
|
||||||
prisma_client, audience, team_id_set, default_team_id, user_id_upsert
|
prisma_client, audience, team_id_set, default_team_id, user_id_upsert
|
||||||
):
|
):
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
args = locals()
|
||||||
|
print(f"received args - {args}")
|
||||||
|
if default_team_id:
|
||||||
|
default_team_id = "team_id_12344_{}".format(uuid.uuid4())
|
||||||
"""
|
"""
|
||||||
- If user required, check if it exists
|
- If user required, check if it exists
|
||||||
- fail initial request (when user doesn't exist)
|
- fail initial request (when user doesn't exist)
|
||||||
|
@ -388,7 +395,12 @@ async def test_user_token_output(
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from starlette.datastructures import URL
|
from starlette.datastructures import URL
|
||||||
from litellm.proxy.proxy_server import user_api_key_auth, new_team, new_user
|
from litellm.proxy.proxy_server import (
|
||||||
|
user_api_key_auth,
|
||||||
|
new_team,
|
||||||
|
new_user,
|
||||||
|
user_info,
|
||||||
|
)
|
||||||
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest
|
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest
|
||||||
import litellm
|
import litellm
|
||||||
import uuid
|
import uuid
|
||||||
|
@ -439,6 +451,7 @@ async def test_user_token_output(
|
||||||
|
|
||||||
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
|
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
|
||||||
jwt_handler.litellm_jwtauth.team_id_default = default_team_id
|
jwt_handler.litellm_jwtauth.team_id_default = default_team_id
|
||||||
|
jwt_handler.litellm_jwtauth.user_id_upsert = user_id_upsert
|
||||||
|
|
||||||
if team_id_set:
|
if team_id_set:
|
||||||
jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"
|
jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"
|
||||||
|
@ -522,7 +535,7 @@ async def test_user_token_output(
|
||||||
),
|
),
|
||||||
user_api_key_dict=result,
|
user_api_key_dict=result,
|
||||||
)
|
)
|
||||||
if default_team_id is not None:
|
if default_team_id:
|
||||||
await new_team(
|
await new_team(
|
||||||
data=NewTeamRequest(
|
data=NewTeamRequest(
|
||||||
team_id=default_team_id,
|
team_id=default_team_id,
|
||||||
|
@ -542,23 +555,35 @@ async def test_user_token_output(
|
||||||
team_result: UserAPIKeyAuth = await user_api_key_auth(
|
team_result: UserAPIKeyAuth = await user_api_key_auth(
|
||||||
request=request, api_key=bearer_token
|
request=request, api_key=bearer_token
|
||||||
)
|
)
|
||||||
pytest.fail(f"User doesn't exist. this should fail")
|
if user_id_upsert == False:
|
||||||
|
pytest.fail(f"User doesn't exist. this should fail")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
## 4. Create user
|
## 4. Create user
|
||||||
try:
|
if user_id_upsert:
|
||||||
bearer_token = "Bearer " + admin_token
|
## check if user already exists
|
||||||
|
try:
|
||||||
|
bearer_token = "Bearer " + admin_token
|
||||||
|
|
||||||
request._url = URL(url="/team/new")
|
request._url = URL(url="/team/new")
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
await new_user(
|
await user_info(user_id=user_id)
|
||||||
data=NewUserRequest(
|
except Exception as e:
|
||||||
user_id=user_id,
|
pytest.fail(f"This should not fail - {str(e)}")
|
||||||
),
|
else:
|
||||||
)
|
try:
|
||||||
except Exception as e:
|
bearer_token = "Bearer " + admin_token
|
||||||
pytest.fail(f"This should not fail - {str(e)}")
|
|
||||||
|
request._url = URL(url="/team/new")
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
await new_user(
|
||||||
|
data=NewUserRequest(
|
||||||
|
user_id=user_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"This should not fail - {str(e)}")
|
||||||
|
|
||||||
## 5. 3rd call w/ same team, same user -> call should succeed
|
## 5. 3rd call w/ same team, same user -> call should succeed
|
||||||
bearer_token = "Bearer " + token
|
bearer_token = "Bearer " + token
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue