Merge pull request #2687 from BerriAI/litellm_jwt_auth_fixes_2

Litellm jwt auth fixes
This commit is contained in:
Krish Dholakia 2024-03-25 13:27:19 -07:00 committed by GitHub
commit f15ba10170
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 300 additions and 43 deletions

View file

@ -67,17 +67,21 @@ class JWTHandler:
self.http_handler = HTTPHandler()
def update_environment(
self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache
self,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
litellm_proxy_roles: LiteLLMProxyRoles,
) -> None:
self.prisma_client = prisma_client
self.user_api_key_cache = user_api_key_cache
self.litellm_proxy_roles = litellm_proxy_roles
def is_jwt(self, token: str):
parts = token.split(".")
return len(parts) == 3
def is_admin(self, scopes: list) -> bool:
if LiteLLMProxyRoles.PROXY_ADMIN.value in scopes:
if self.litellm_proxy_roles.proxy_admin in scopes:
return True
return False
@ -90,7 +94,7 @@ class JWTHandler:
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
team_id = token["azp"]
team_id = token["client_id"]
except KeyError:
team_id = default_value
return team_id
@ -130,58 +134,94 @@ class JWTHandler:
def get_scopes(self, token: dict) -> list:
try:
# Assuming the scopes are stored in 'scope' claim and are space-separated
scopes = token["scope"].split()
if isinstance(token["scope"], str):
# Assuming the scopes are stored in 'scope' claim and are space-separated
scopes = token["scope"].split()
elif isinstance(token["scope"], list):
scopes = token["scope"]
else:
raise Exception(
f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str."
)
except KeyError:
scopes = []
return scopes
async def auth_jwt(self, token: str) -> dict:
from jwt.algorithms import RSAAlgorithm
async def get_public_key(self, kid: Optional[str]) -> dict:
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
if keys_url is None:
raise Exception("Missing JWT Public Key URL from environment.")
response = await self.http_handler.get(keys_url)
cached_keys = await self.user_api_key_cache.async_get_cache(
"litellm_jwt_auth_keys"
)
if cached_keys is None:
response = await self.http_handler.get(keys_url)
keys = response.json()["keys"]
keys = response.json()["keys"]
await self.user_api_key_cache.async_set_cache(
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins
)
else:
keys = cached_keys
public_key: Optional[dict] = None
if len(keys) == 1:
public_key = keys[0]
elif len(keys) > 1:
for key in keys:
if kid is not None and key["kid"] == kid:
public_key = key
if public_key is None:
raise Exception(
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}"
)
return public_key
async def auth_jwt(self, token: str) -> dict:
from jwt.algorithms import RSAAlgorithm
header = jwt.get_unverified_header(token)
verbose_proxy_logger.debug("header: %s", header)
if "kid" in header:
kid = header["kid"]
else:
raise Exception(f"Expected 'kid' in header. header={header}.")
kid = header.get("kid", None)
for key in keys:
if key["kid"] == kid:
jwk = {
"kty": key["kty"],
"kid": key["kid"],
"n": key["n"],
"e": key["e"],
}
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))
public_key = await self.get_public_key(kid=kid)
try:
# decode the token using the public key
payload = jwt.decode(
token,
public_key, # type: ignore
algorithms=["RS256"],
audience="account",
)
return payload
if public_key is not None and isinstance(public_key, dict):
jwk = {}
if "kty" in public_key:
jwk["kty"] = public_key["kty"]
if "kid" in public_key:
jwk["kid"] = public_key["kid"]
if "n" in public_key:
jwk["n"] = public_key["n"]
if "e" in public_key:
jwk["e"] = public_key["e"]
except jwt.ExpiredSignatureError:
# the token is expired, do something to refresh it
raise Exception("Token Expired")
except Exception as e:
raise Exception(f"Validation fails: {str(e)}")
public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))
try:
# decode the token using the public key
payload = jwt.decode(
token,
public_key_rsa, # type: ignore
algorithms=["RS256"],
options={"verify_aud": False},
)
return payload
except jwt.ExpiredSignatureError:
# the token is expired, do something to refresh it
raise Exception("Token Expired")
except Exception as e:
raise Exception(f"Validation fails: {str(e)}")
raise Exception("Invalid JWT Submitted")