diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 606ff68281..9c846e8f66 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -156,6 +156,11 @@ class JWTHandler: return public_key async def auth_jwt(self, token: str) -> dict: + # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html + # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret + # the key in different ways (e.g. HS* and RS*)." + algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"], + audience = os.getenv("JWT_AUDIENCE") decode_options = None if audience is None: @@ -189,7 +194,7 @@ class JWTHandler: payload = jwt.decode( token, public_key_rsa, # type: ignore - algorithms=["RS256"], + algorithms=algorithms, options=decode_options, audience=audience, ) @@ -214,7 +219,7 @@ class JWTHandler: payload = jwt.decode( token, key, - algorithms=["RS256"], + algorithms=algorithms, audience=audience, options=decode_options )