Merge pull request #2687 from BerriAI/litellm_jwt_auth_fixes_2

Litellm jwt auth fixes
This commit is contained in:
Krish Dholakia 2024-03-25 13:27:19 -07:00 committed by GitHub
commit f15ba10170
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 300 additions and 43 deletions

View file

@ -107,4 +107,38 @@ general_settings:
master_key: sk-1234 master_key: sk-1234
enable_jwt_auth: True enable_jwt_auth: True
allowed_routes: ["/chat/completions", "/embeddings"] allowed_routes: ["/chat/completions", "/embeddings"]
```
## Advanced - Set Accepted JWT Scope Names
Change the string in JWT 'scopes', that litellm evaluates to see if a user has admin access.
```yaml
general_settings:
master_key: sk-1234
enable_jwt_auth: True
litellm_proxy_roles:
proxy_admin: "litellm-proxy-admin"
```
### Allowed LiteLLM scopes
```python
class LiteLLMProxyRoles(LiteLLMBase):
proxy_admin: str = "litellm_proxy_admin"
proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth.
```
### JWT Scopes
Here's what scopes on JWT-Auth tokens look like
**Can be a list**
```
scope: ["litellm-proxy-admin",...]
```
**Can be a space-separated string**
```
scope: "litellm-proxy-admin ..."
``` ```

View file

@ -14,11 +14,6 @@ def hash_token(token: str):
return hashed_token return hashed_token
class LiteLLMProxyRoles(enum.Enum):
PROXY_ADMIN = "litellm_proxy_admin"
USER = "litellm_user"
class LiteLLMBase(BaseModel): class LiteLLMBase(BaseModel):
""" """
Implements default functions, all pydantic objects should have. Implements default functions, all pydantic objects should have.
@ -42,6 +37,11 @@ class LiteLLMBase(BaseModel):
protected_namespaces = () protected_namespaces = ()
class LiteLLMProxyRoles(LiteLLMBase):
proxy_admin: str = "litellm_proxy_admin"
proxy_user: str = "litellm_user"
class LiteLLMPromptInjectionParams(LiteLLMBase): class LiteLLMPromptInjectionParams(LiteLLMBase):
heuristics_check: bool = False heuristics_check: bool = False
vector_db_check: bool = False vector_db_check: bool = False

View file

@ -67,17 +67,21 @@ class JWTHandler:
self.http_handler = HTTPHandler() self.http_handler = HTTPHandler()
def update_environment( def update_environment(
self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache self,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
litellm_proxy_roles: LiteLLMProxyRoles,
) -> None: ) -> None:
self.prisma_client = prisma_client self.prisma_client = prisma_client
self.user_api_key_cache = user_api_key_cache self.user_api_key_cache = user_api_key_cache
self.litellm_proxy_roles = litellm_proxy_roles
def is_jwt(self, token: str): def is_jwt(self, token: str):
parts = token.split(".") parts = token.split(".")
return len(parts) == 3 return len(parts) == 3
def is_admin(self, scopes: list) -> bool: def is_admin(self, scopes: list) -> bool:
if LiteLLMProxyRoles.PROXY_ADMIN.value in scopes: if self.litellm_proxy_roles.proxy_admin in scopes:
return True return True
return False return False
@ -90,7 +94,7 @@ class JWTHandler:
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try: try:
team_id = token["azp"] team_id = token["client_id"]
except KeyError: except KeyError:
team_id = default_value team_id = default_value
return team_id return team_id
@ -130,58 +134,94 @@ class JWTHandler:
def get_scopes(self, token: dict) -> list: def get_scopes(self, token: dict) -> list:
try: try:
# Assuming the scopes are stored in 'scope' claim and are space-separated if isinstance(token["scope"], str):
scopes = token["scope"].split() # Assuming the scopes are stored in 'scope' claim and are space-separated
scopes = token["scope"].split()
elif isinstance(token["scope"], list):
scopes = token["scope"]
else:
raise Exception(
f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str."
)
except KeyError: except KeyError:
scopes = [] scopes = []
return scopes return scopes
async def auth_jwt(self, token: str) -> dict: async def get_public_key(self, kid: Optional[str]) -> dict:
from jwt.algorithms import RSAAlgorithm
keys_url = os.getenv("JWT_PUBLIC_KEY_URL") keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
if keys_url is None: if keys_url is None:
raise Exception("Missing JWT Public Key URL from environment.") raise Exception("Missing JWT Public Key URL from environment.")
response = await self.http_handler.get(keys_url) cached_keys = await self.user_api_key_cache.async_get_cache(
"litellm_jwt_auth_keys"
)
if cached_keys is None:
response = await self.http_handler.get(keys_url)
keys = response.json()["keys"] keys = response.json()["keys"]
await self.user_api_key_cache.async_set_cache(
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins
)
else:
keys = cached_keys
public_key: Optional[dict] = None
if len(keys) == 1:
public_key = keys[0]
elif len(keys) > 1:
for key in keys:
if kid is not None and key["kid"] == kid:
public_key = key
if public_key is None:
raise Exception(
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}"
)
return public_key
async def auth_jwt(self, token: str) -> dict:
from jwt.algorithms import RSAAlgorithm
header = jwt.get_unverified_header(token) header = jwt.get_unverified_header(token)
verbose_proxy_logger.debug("header: %s", header) verbose_proxy_logger.debug("header: %s", header)
if "kid" in header: kid = header.get("kid", None)
kid = header["kid"]
else:
raise Exception(f"Expected 'kid' in header. header={header}.")
for key in keys: public_key = await self.get_public_key(kid=kid)
if key["kid"] == kid:
jwk = {
"kty": key["kty"],
"kid": key["kid"],
"n": key["n"],
"e": key["e"],
}
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))
try: if public_key is not None and isinstance(public_key, dict):
# decode the token using the public key jwk = {}
payload = jwt.decode( if "kty" in public_key:
token, jwk["kty"] = public_key["kty"]
public_key, # type: ignore if "kid" in public_key:
algorithms=["RS256"], jwk["kid"] = public_key["kid"]
audience="account", if "n" in public_key:
) jwk["n"] = public_key["n"]
return payload if "e" in public_key:
jwk["e"] = public_key["e"]
except jwt.ExpiredSignatureError: public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))
# the token is expired, do something to refresh it
raise Exception("Token Expired") try:
except Exception as e: # decode the token using the public key
raise Exception(f"Validation fails: {str(e)}") payload = jwt.decode(
token,
public_key_rsa, # type: ignore
algorithms=["RS256"],
options={"verify_aud": False},
)
return payload
except jwt.ExpiredSignatureError:
# the token is expired, do something to refresh it
raise Exception("Token Expired")
except Exception as e:
raise Exception(f"Validation fails: {str(e)}")
raise Exception("Invalid JWT Submitted") raise Exception("Invalid JWT Submitted")

View file

@ -2710,7 +2710,11 @@ async def startup_event():
## JWT AUTH ## ## JWT AUTH ##
jwt_handler.update_environment( jwt_handler.update_environment(
prisma_client=prisma_client, user_api_key_cache=user_api_key_cache prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
litellm_proxy_roles=LiteLLMProxyRoles(
**general_settings.get("litellm_proxy_roles", {})
),
) )
if use_background_health_checks: if use_background_health_checks:

179
litellm/tests/test_jwt.py Normal file
View file

@ -0,0 +1,179 @@
#### What this tests ####
# Unit tests for JWT-Auth
import sys, os, asyncio, time, random
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm.proxy._types import LiteLLMProxyRoles
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.caching import DualCache
from datetime import datetime, timedelta
public_key = {
"kty": "RSA",
"e": "AQAB",
"n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ",
"alg": "RS256",
}
def test_load_config_with_custom_role_names():
config = {
"general_settings": {
"litellm_proxy_roles": {"proxy_admin": "litellm-proxy-admin"}
}
}
proxy_roles = LiteLLMProxyRoles(
**config.get("general_settings", {}).get("litellm_proxy_roles", {})
)
print(f"proxy_roles: {proxy_roles}")
assert proxy_roles.proxy_admin == "litellm-proxy-admin"
# test_load_config_with_custom_role_names()
@pytest.mark.asyncio
async def test_token_single_public_key():
import jwt
jwt_handler = JWTHandler()
backend_keys = {
"keys": [
{
"kty": "RSA",
"use": "sig",
"e": "AQAB",
"n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ",
"alg": "RS256",
}
]
}
# set cache
cache = DualCache()
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"])
jwt_handler.user_api_key_cache = cache
public_key = await jwt_handler.get_public_key(kid=None)
assert public_key is not None
assert isinstance(public_key, dict)
assert (
public_key["n"]
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
)
@pytest.mark.asyncio
async def test_valid_invalid_token():
"""
Tests
- valid token
- invalid token
"""
import jwt, json
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
# Generate a private / public key pair using RSA algorithm
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
)
# Get private key in PEM format
private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
# Get public key in PEM format
public_key = key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
public_key_obj = serialization.load_pem_public_key(
public_key, backend=default_backend()
)
# Convert RSA public key object to JWK (JSON Web Key)
public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
assert isinstance(public_jwk, dict)
# set cache
cache = DualCache()
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
jwt_handler = JWTHandler()
jwt_handler.user_api_key_cache = cache
# VALID TOKEN
## GENERATE A TOKEN
# Assuming the current time is in UTC
expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
payload = {
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-proxy-admin",
}
# Generate the JWT token
# But before, you should convert bytes to string
private_key_str = private_key.decode("utf-8")
token = jwt.encode(payload, private_key_str, algorithm="RS256")
## VERIFY IT WORKS
# verify token
response = await jwt_handler.auth_jwt(token=token)
assert response is not None
assert isinstance(response, dict)
print(f"response: {response}")
# INVALID TOKEN
## GENERATE A TOKEN
# Assuming the current time is in UTC
expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
payload = {
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-NO-SCOPE",
}
# Generate the JWT token
# But before, you should convert bytes to string
private_key_str = private_key.decode("utf-8")
token = jwt.encode(payload, private_key_str, algorithm="RS256")
## VERIFY IT WORKS
# verify token
try:
response = await jwt_handler.auth_jwt(token=token)
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")