feat(proxy_server.py): make team_id optional for jwt token auth (only enforced, if set)

Allows users to use jwt auth for internal chat apps
This commit is contained in:
Krrish Dholakia 2024-05-15 21:05:14 -07:00
parent d9ad7c6218
commit f48cd87cf3
5 changed files with 89 additions and 54 deletions

View file

@ -60,7 +60,9 @@ class JWTHandler:
return True
return False
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
def get_end_user_id(
self, token: dict, default_value: Optional[str]
) -> Optional[str]:
try:
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
@ -70,6 +72,16 @@ class JWTHandler:
user_id = default_value
return user_id
def is_required_team_id(self) -> bool:
"""
Returns:
- True: if 'team_id_jwt_field' is set
- False: if not
"""
if self.litellm_jwtauth.team_id_jwt_field is None:
return False
return True
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
@ -165,7 +177,7 @@ class JWTHandler:
decode_options = None
if audience is None:
decode_options = {"verify_aud": False}
from jwt.algorithms import RSAAlgorithm
header = jwt.get_unverified_header(token)
@ -207,12 +219,14 @@ class JWTHandler:
raise Exception(f"Validation fails: {str(e)}")
elif public_key is not None and isinstance(public_key, str):
try:
cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend())
cert = x509.load_pem_x509_certificate(
public_key.encode(), default_backend()
)
# Extract public key
key = cert.public_key().public_bytes(
serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo
serialization.PublicFormat.SubjectPublicKeyInfo,
)
# decode the token using the public key
@ -221,7 +235,7 @@ class JWTHandler:
key,
algorithms=algorithms,
audience=audience,
options=decode_options
options=decode_options,
)
return payload