feat(proxy_server.py): make team_id optional for JWT token auth (only enforced if set)

Allows users to use JWT auth for internal chat apps.
This commit is contained in:
Krrish Dholakia 2024-05-15 21:05:14 -07:00
parent d9ad7c6218
commit f48cd87cf3
5 changed files with 89 additions and 54 deletions

View file

@ -228,7 +228,7 @@ class LiteLLM_JWTAuth(LiteLLMBase):
"info_routes", "info_routes",
] ]
team_jwt_scope: str = "litellm_team" team_jwt_scope: str = "litellm_team"
team_id_jwt_field: str = "client_id" team_id_jwt_field: Optional[str] = None
team_allowed_routes: List[ team_allowed_routes: List[
Literal["openai_routes", "info_routes", "management_routes"] Literal["openai_routes", "info_routes", "management_routes"]
] = ["openai_routes", "info_routes"] ] = ["openai_routes", "info_routes"]

View file

@ -26,7 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
def common_checks( def common_checks(
request_body: dict, request_body: dict,
team_object: LiteLLM_TeamTable, team_object: Optional[LiteLLM_TeamTable],
user_object: Optional[LiteLLM_UserTable], user_object: Optional[LiteLLM_UserTable],
end_user_object: Optional[LiteLLM_EndUserTable], end_user_object: Optional[LiteLLM_EndUserTable],
global_proxy_spend: Optional[float], global_proxy_spend: Optional[float],
@ -45,13 +45,14 @@ def common_checks(
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
""" """
_model = request_body.get("model", None) _model = request_body.get("model", None)
if team_object.blocked == True: if team_object is not None and team_object.blocked == True:
raise Exception( raise Exception(
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin." f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
) )
# 2. If user can call model # 2. If user can call model
if ( if (
_model is not None _model is not None
and team_object is not None
and len(team_object.models) > 0 and len(team_object.models) > 0
and _model not in team_object.models and _model not in team_object.models
): ):
@ -65,7 +66,8 @@ def common_checks(
) )
# 3. If team is in budget # 3. If team is in budget
if ( if (
team_object.max_budget is not None team_object is not None
and team_object.max_budget is not None
and team_object.spend is not None and team_object.spend is not None
and team_object.spend > team_object.max_budget and team_object.spend > team_object.max_budget
): ):
@ -305,6 +307,9 @@ async def get_team_object(
if response is None: if response is None:
raise Exception raise Exception
# save the team object to cache
await user_api_key_cache.async_set_cache(key=response.team_id, value=response)
return LiteLLM_TeamTable(**response.dict()) return LiteLLM_TeamTable(**response.dict())
except Exception as e: except Exception as e:
raise Exception( raise Exception(

View file

@ -60,7 +60,9 @@ class JWTHandler:
return True return True
return False return False
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str: def get_end_user_id(
self, token: dict, default_value: Optional[str]
) -> Optional[str]:
try: try:
if self.litellm_jwtauth.end_user_id_jwt_field is not None: if self.litellm_jwtauth.end_user_id_jwt_field is not None:
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field] user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
@ -70,6 +72,16 @@ class JWTHandler:
user_id = default_value user_id = default_value
return user_id return user_id
def is_required_team_id(self) -> bool:
    """
    Report whether a team id is required for JWT auth.

    The team id is treated as required only when 'team_id_jwt_field'
    has been configured on the JWT auth settings.

    Returns:
    - True: if 'team_id_jwt_field' is set (i.e. is not None)
    - False: if 'team_id_jwt_field' is None
    """
    configured_field = self.litellm_jwtauth.team_id_jwt_field
    # Identity check on purpose: any non-None value (even "") counts as set.
    return configured_field is not None
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try: try:
team_id = token[self.litellm_jwtauth.team_id_jwt_field] team_id = token[self.litellm_jwtauth.team_id_jwt_field]
@ -207,12 +219,14 @@ class JWTHandler:
raise Exception(f"Validation fails: {str(e)}") raise Exception(f"Validation fails: {str(e)}")
elif public_key is not None and isinstance(public_key, str): elif public_key is not None and isinstance(public_key, str):
try: try:
cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend()) cert = x509.load_pem_x509_certificate(
public_key.encode(), default_backend()
)
# Extract public key # Extract public key
key = cert.public_key().public_bytes( key = cert.public_key().public_bytes(
serialization.Encoding.PEM, serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo serialization.PublicFormat.SubjectPublicKeyInfo,
) )
# decode the token using the public key # decode the token using the public key
@ -221,7 +235,7 @@ class JWTHandler:
key, key,
algorithms=algorithms, algorithms=algorithms,
audience=audience, audience=audience,
options=decode_options options=decode_options,
) )
return payload return payload

View file

@ -440,29 +440,32 @@ async def user_api_key_auth(
# get team id # get team id
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None) team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
if team_id is None: if team_id is None and jwt_handler.is_required_team_id() == True:
raise Exception( raise Exception(
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'" f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
) )
# check allowed team routes
is_allowed = allowed_routes_check(
user_role="team",
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed == False:
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# check if team in db team_object: Optional[LiteLLM_TeamTable] = None
team_object = await get_team_object( if team_id is not None:
team_id=team_id, # check allowed team routes
prisma_client=prisma_client, is_allowed = allowed_routes_check(
user_api_key_cache=user_api_key_cache, user_role="team",
) user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed == False:
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# check if team in db
team_object = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable` # [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
org_id = jwt_handler.get_org_id(token=valid_token, default_value=None) org_id = jwt_handler.get_org_id(token=valid_token, default_value=None)
@ -547,18 +550,18 @@ async def user_api_key_auth(
global_proxy_spend=global_proxy_spend, global_proxy_spend=global_proxy_spend,
route=route, route=route,
) )
# save team object in cache
await user_api_key_cache.async_set_cache(
key=team_object.team_id, value=team_object
)
# return UserAPIKeyAuth object # return UserAPIKeyAuth object
return UserAPIKeyAuth( return UserAPIKeyAuth(
api_key=None, api_key=None,
team_id=team_object.team_id, team_id=team_object.team_id if team_object is not None else None,
team_tpm_limit=team_object.tpm_limit, team_tpm_limit=(
team_rpm_limit=team_object.rpm_limit, team_object.tpm_limit if team_object is not None else None
team_models=team_object.models, ),
team_rpm_limit=(
team_object.rpm_limit if team_object is not None else None
),
team_models=team_object.models if team_object is not None else [],
user_role="app_owner", user_role="app_owner",
user_id=user_id, user_id=user_id,
org_id=org_id, org_id=org_id,

View file

@ -24,6 +24,7 @@ public_key = {
"alg": "RS256", "alg": "RS256",
} }
def test_load_config_with_custom_role_names(): def test_load_config_with_custom_role_names():
config = { config = {
"general_settings": { "general_settings": {
@ -77,7 +78,8 @@ async def test_token_single_public_key():
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ" == "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
) )
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_valid_invalid_token(audience): async def test_valid_invalid_token(audience):
""" """
@ -90,7 +92,7 @@ async def test_valid_invalid_token(audience):
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
os.environ.pop('JWT_AUDIENCE', None) os.environ.pop("JWT_AUDIENCE", None)
if audience: if audience:
os.environ["JWT_AUDIENCE"] = audience os.environ["JWT_AUDIENCE"] = audience
@ -138,7 +140,7 @@ async def test_valid_invalid_token(audience):
"sub": "user123", "sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes "exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-proxy-admin", "scope": "litellm-proxy-admin",
"aud": audience "aud": audience,
} }
# Generate the JWT token # Generate the JWT token
@ -166,7 +168,7 @@ async def test_valid_invalid_token(audience):
"sub": "user123", "sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes "exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-NO-SCOPE", "scope": "litellm-NO-SCOPE",
"aud": audience "aud": audience,
} }
# Generate the JWT token # Generate the JWT token
@ -183,6 +185,7 @@ async def test_valid_invalid_token(audience):
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
@pytest.fixture @pytest.fixture
def prisma_client(): def prisma_client():
import litellm import litellm
@ -205,7 +208,7 @@ def prisma_client():
return prisma_client return prisma_client
@pytest.mark.parametrize('audience', [None, "litellm-proxy"]) @pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_team_token_output(prisma_client, audience): async def test_team_token_output(prisma_client, audience):
import jwt, json import jwt, json
@ -222,7 +225,7 @@ async def test_team_token_output(prisma_client, audience):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
os.environ.pop('JWT_AUDIENCE', None) os.environ.pop("JWT_AUDIENCE", None)
if audience: if audience:
os.environ["JWT_AUDIENCE"] = audience os.environ["JWT_AUDIENCE"] = audience
@ -261,7 +264,7 @@ async def test_team_token_output(prisma_client, audience):
jwt_handler.user_api_key_cache = cache jwt_handler.user_api_key_cache = cache
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth() jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")
# VALID TOKEN # VALID TOKEN
## GENERATE A TOKEN ## GENERATE A TOKEN
@ -274,7 +277,7 @@ async def test_team_token_output(prisma_client, audience):
"exp": expiration_time, # set the token to expire in 10 minutes "exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_team", "scope": "litellm_team",
"client_id": team_id, "client_id": team_id,
"aud": audience "aud": audience,
} }
# Generate the JWT token # Generate the JWT token
@ -289,7 +292,7 @@ async def test_team_token_output(prisma_client, audience):
"sub": "user123", "sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes "exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_proxy_admin", "scope": "litellm_proxy_admin",
"aud": audience "aud": audience,
} }
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
@ -315,7 +318,13 @@ async def test_team_token_output(prisma_client, audience):
## 1. INITIAL TEAM CALL - should fail ## 1. INITIAL TEAM CALL - should fail
# use generated key to auth in # use generated key to auth in
setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True}) setattr(
litellm.proxy.proxy_server,
"general_settings",
{
"enable_jwt_auth": True,
},
)
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler) setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
try: try:
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
@ -358,9 +367,10 @@ async def test_team_token_output(prisma_client, audience):
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"] assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
@pytest.mark.parametrize('audience', [None, "litellm-proxy"]) @pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.parametrize("team_id_set", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_token_output(prisma_client, audience): async def test_user_token_output(prisma_client, audience, team_id_set):
""" """
- If user required, check if it exists - If user required, check if it exists
- fail initial request (when user doesn't exist) - fail initial request (when user doesn't exist)
@ -381,7 +391,7 @@ async def test_user_token_output(prisma_client, audience):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
os.environ.pop('JWT_AUDIENCE', None) os.environ.pop("JWT_AUDIENCE", None)
if audience: if audience:
os.environ["JWT_AUDIENCE"] = audience os.environ["JWT_AUDIENCE"] = audience
@ -423,6 +433,8 @@ async def test_user_token_output(prisma_client, audience):
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth() jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub" jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
if team_id_set:
jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"
# VALID TOKEN # VALID TOKEN
## GENERATE A TOKEN ## GENERATE A TOKEN
@ -436,7 +448,7 @@ async def test_user_token_output(prisma_client, audience):
"exp": expiration_time, # set the token to expire in 10 minutes "exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_team", "scope": "litellm_team",
"client_id": team_id, "client_id": team_id,
"aud": audience "aud": audience,
} }
# Generate the JWT token # Generate the JWT token
@ -451,7 +463,7 @@ async def test_user_token_output(prisma_client, audience):
"sub": user_id, "sub": user_id,
"exp": expiration_time, # set the token to expire in 10 minutes "exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_proxy_admin", "scope": "litellm_proxy_admin",
"aud": audience "aud": audience,
} }
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
@ -543,7 +555,8 @@ async def test_user_token_output(prisma_client, audience):
## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking) ## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
assert team_result.team_tpm_limit == 100 if team_id_set:
assert team_result.team_rpm_limit == 99 assert team_result.team_tpm_limit == 100
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"] assert team_result.team_rpm_limit == 99
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
assert team_result.user_id == user_id assert team_result.user_id == user_id