forked from phoenix/litellm-mirror
Merge pull request #2687 from BerriAI/litellm_jwt_auth_fixes_2
Litellm jwt auth fixes
This commit is contained in:
commit
f15ba10170
5 changed files with 300 additions and 43 deletions
|
@ -107,4 +107,38 @@ general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
enable_jwt_auth: True
|
enable_jwt_auth: True
|
||||||
allowed_routes: ["/chat/completions", "/embeddings"]
|
allowed_routes: ["/chat/completions", "/embeddings"]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced - Set Accepted JWT Scope Names
|
||||||
|
|
||||||
|
Change the string in JWT 'scopes', that litellm evaluates to see if a user has admin access.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
general_settings:
|
||||||
|
master_key: sk-1234
|
||||||
|
enable_jwt_auth: True
|
||||||
|
litellm_proxy_roles:
|
||||||
|
proxy_admin: "litellm-proxy-admin"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Allowed LiteLLM scopes
|
||||||
|
|
||||||
|
```python
|
||||||
|
class LiteLLMProxyRoles(LiteLLMBase):
|
||||||
|
proxy_admin: str = "litellm_proxy_admin"
|
||||||
|
proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth.
|
||||||
|
```
|
||||||
|
|
||||||
|
### JWT Scopes
|
||||||
|
|
||||||
|
Here's what scopes on JWT-Auth tokens look like
|
||||||
|
|
||||||
|
**Can be a list**
|
||||||
|
```
|
||||||
|
scope: ["litellm-proxy-admin",...]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Can be a space-separated string**
|
||||||
|
```
|
||||||
|
scope: "litellm-proxy-admin ..."
|
||||||
```
|
```
|
|
@ -14,11 +14,6 @@ def hash_token(token: str):
|
||||||
return hashed_token
|
return hashed_token
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProxyRoles(enum.Enum):
|
|
||||||
PROXY_ADMIN = "litellm_proxy_admin"
|
|
||||||
USER = "litellm_user"
|
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMBase(BaseModel):
|
class LiteLLMBase(BaseModel):
|
||||||
"""
|
"""
|
||||||
Implements default functions, all pydantic objects should have.
|
Implements default functions, all pydantic objects should have.
|
||||||
|
@ -42,6 +37,11 @@ class LiteLLMBase(BaseModel):
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLMProxyRoles(LiteLLMBase):
|
||||||
|
proxy_admin: str = "litellm_proxy_admin"
|
||||||
|
proxy_user: str = "litellm_user"
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMPromptInjectionParams(LiteLLMBase):
|
class LiteLLMPromptInjectionParams(LiteLLMBase):
|
||||||
heuristics_check: bool = False
|
heuristics_check: bool = False
|
||||||
vector_db_check: bool = False
|
vector_db_check: bool = False
|
||||||
|
|
|
@ -67,17 +67,21 @@ class JWTHandler:
|
||||||
self.http_handler = HTTPHandler()
|
self.http_handler = HTTPHandler()
|
||||||
|
|
||||||
def update_environment(
|
def update_environment(
|
||||||
self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache
|
self,
|
||||||
|
prisma_client: Optional[PrismaClient],
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
litellm_proxy_roles: LiteLLMProxyRoles,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.prisma_client = prisma_client
|
self.prisma_client = prisma_client
|
||||||
self.user_api_key_cache = user_api_key_cache
|
self.user_api_key_cache = user_api_key_cache
|
||||||
|
self.litellm_proxy_roles = litellm_proxy_roles
|
||||||
|
|
||||||
def is_jwt(self, token: str):
|
def is_jwt(self, token: str):
|
||||||
parts = token.split(".")
|
parts = token.split(".")
|
||||||
return len(parts) == 3
|
return len(parts) == 3
|
||||||
|
|
||||||
def is_admin(self, scopes: list) -> bool:
|
def is_admin(self, scopes: list) -> bool:
|
||||||
if LiteLLMProxyRoles.PROXY_ADMIN.value in scopes:
|
if self.litellm_proxy_roles.proxy_admin in scopes:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -90,7 +94,7 @@ class JWTHandler:
|
||||||
|
|
||||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
team_id = token["azp"]
|
team_id = token["client_id"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
team_id = default_value
|
team_id = default_value
|
||||||
return team_id
|
return team_id
|
||||||
|
@ -130,58 +134,94 @@ class JWTHandler:
|
||||||
|
|
||||||
def get_scopes(self, token: dict) -> list:
|
def get_scopes(self, token: dict) -> list:
|
||||||
try:
|
try:
|
||||||
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
if isinstance(token["scope"], str):
|
||||||
scopes = token["scope"].split()
|
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
||||||
|
scopes = token["scope"].split()
|
||||||
|
elif isinstance(token["scope"], list):
|
||||||
|
scopes = token["scope"]
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str."
|
||||||
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
scopes = []
|
scopes = []
|
||||||
return scopes
|
return scopes
|
||||||
|
|
||||||
async def auth_jwt(self, token: str) -> dict:
|
async def get_public_key(self, kid: Optional[str]) -> dict:
|
||||||
from jwt.algorithms import RSAAlgorithm
|
|
||||||
|
|
||||||
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
|
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
|
||||||
|
|
||||||
if keys_url is None:
|
if keys_url is None:
|
||||||
raise Exception("Missing JWT Public Key URL from environment.")
|
raise Exception("Missing JWT Public Key URL from environment.")
|
||||||
|
|
||||||
response = await self.http_handler.get(keys_url)
|
cached_keys = await self.user_api_key_cache.async_get_cache(
|
||||||
|
"litellm_jwt_auth_keys"
|
||||||
|
)
|
||||||
|
if cached_keys is None:
|
||||||
|
response = await self.http_handler.get(keys_url)
|
||||||
|
|
||||||
keys = response.json()["keys"]
|
keys = response.json()["keys"]
|
||||||
|
|
||||||
|
await self.user_api_key_cache.async_set_cache(
|
||||||
|
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
keys = cached_keys
|
||||||
|
|
||||||
|
public_key: Optional[dict] = None
|
||||||
|
|
||||||
|
if len(keys) == 1:
|
||||||
|
public_key = keys[0]
|
||||||
|
elif len(keys) > 1:
|
||||||
|
for key in keys:
|
||||||
|
if kid is not None and key["kid"] == kid:
|
||||||
|
public_key = key
|
||||||
|
|
||||||
|
if public_key is None:
|
||||||
|
raise Exception(
|
||||||
|
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return public_key
|
||||||
|
|
||||||
|
async def auth_jwt(self, token: str) -> dict:
|
||||||
|
from jwt.algorithms import RSAAlgorithm
|
||||||
|
|
||||||
header = jwt.get_unverified_header(token)
|
header = jwt.get_unverified_header(token)
|
||||||
|
|
||||||
verbose_proxy_logger.debug("header: %s", header)
|
verbose_proxy_logger.debug("header: %s", header)
|
||||||
|
|
||||||
if "kid" in header:
|
kid = header.get("kid", None)
|
||||||
kid = header["kid"]
|
|
||||||
else:
|
|
||||||
raise Exception(f"Expected 'kid' in header. header={header}.")
|
|
||||||
|
|
||||||
for key in keys:
|
public_key = await self.get_public_key(kid=kid)
|
||||||
if key["kid"] == kid:
|
|
||||||
jwk = {
|
|
||||||
"kty": key["kty"],
|
|
||||||
"kid": key["kid"],
|
|
||||||
"n": key["n"],
|
|
||||||
"e": key["e"],
|
|
||||||
}
|
|
||||||
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))
|
|
||||||
|
|
||||||
try:
|
if public_key is not None and isinstance(public_key, dict):
|
||||||
# decode the token using the public key
|
jwk = {}
|
||||||
payload = jwt.decode(
|
if "kty" in public_key:
|
||||||
token,
|
jwk["kty"] = public_key["kty"]
|
||||||
public_key, # type: ignore
|
if "kid" in public_key:
|
||||||
algorithms=["RS256"],
|
jwk["kid"] = public_key["kid"]
|
||||||
audience="account",
|
if "n" in public_key:
|
||||||
)
|
jwk["n"] = public_key["n"]
|
||||||
return payload
|
if "e" in public_key:
|
||||||
|
jwk["e"] = public_key["e"]
|
||||||
|
|
||||||
except jwt.ExpiredSignatureError:
|
public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))
|
||||||
# the token is expired, do something to refresh it
|
|
||||||
raise Exception("Token Expired")
|
try:
|
||||||
except Exception as e:
|
# decode the token using the public key
|
||||||
raise Exception(f"Validation fails: {str(e)}")
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
public_key_rsa, # type: ignore
|
||||||
|
algorithms=["RS256"],
|
||||||
|
options={"verify_aud": False},
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
# the token is expired, do something to refresh it
|
||||||
|
raise Exception("Token Expired")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Validation fails: {str(e)}")
|
||||||
|
|
||||||
raise Exception("Invalid JWT Submitted")
|
raise Exception("Invalid JWT Submitted")
|
||||||
|
|
||||||
|
|
|
@ -2710,7 +2710,11 @@ async def startup_event():
|
||||||
|
|
||||||
## JWT AUTH ##
|
## JWT AUTH ##
|
||||||
jwt_handler.update_environment(
|
jwt_handler.update_environment(
|
||||||
prisma_client=prisma_client, user_api_key_cache=user_api_key_cache
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
litellm_proxy_roles=LiteLLMProxyRoles(
|
||||||
|
**general_settings.get("litellm_proxy_roles", {})
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_background_health_checks:
|
if use_background_health_checks:
|
||||||
|
|
179
litellm/tests/test_jwt.py
Normal file
179
litellm/tests/test_jwt.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
#### What this tests ####
|
||||||
|
# Unit tests for JWT-Auth
|
||||||
|
|
||||||
|
import sys, os, asyncio, time, random
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import pytest
|
||||||
|
from litellm.proxy._types import LiteLLMProxyRoles
|
||||||
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
public_key = {
|
||||||
|
"kty": "RSA",
|
||||||
|
"e": "AQAB",
|
||||||
|
"n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ",
|
||||||
|
"alg": "RS256",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_config_with_custom_role_names():
|
||||||
|
config = {
|
||||||
|
"general_settings": {
|
||||||
|
"litellm_proxy_roles": {"proxy_admin": "litellm-proxy-admin"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy_roles = LiteLLMProxyRoles(
|
||||||
|
**config.get("general_settings", {}).get("litellm_proxy_roles", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"proxy_roles: {proxy_roles}")
|
||||||
|
|
||||||
|
assert proxy_roles.proxy_admin == "litellm-proxy-admin"
|
||||||
|
|
||||||
|
|
||||||
|
# test_load_config_with_custom_role_names()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_token_single_public_key():
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
jwt_handler = JWTHandler()
|
||||||
|
|
||||||
|
backend_keys = {
|
||||||
|
"keys": [
|
||||||
|
{
|
||||||
|
"kty": "RSA",
|
||||||
|
"use": "sig",
|
||||||
|
"e": "AQAB",
|
||||||
|
"n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ",
|
||||||
|
"alg": "RS256",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# set cache
|
||||||
|
cache = DualCache()
|
||||||
|
|
||||||
|
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"])
|
||||||
|
|
||||||
|
jwt_handler.user_api_key_cache = cache
|
||||||
|
|
||||||
|
public_key = await jwt_handler.get_public_key(kid=None)
|
||||||
|
|
||||||
|
assert public_key is not None
|
||||||
|
assert isinstance(public_key, dict)
|
||||||
|
assert (
|
||||||
|
public_key["n"]
|
||||||
|
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_invalid_token():
|
||||||
|
"""
|
||||||
|
Tests
|
||||||
|
- valid token
|
||||||
|
- invalid token
|
||||||
|
"""
|
||||||
|
import jwt, json
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
|
||||||
|
# Generate a private / public key pair using RSA algorithm
|
||||||
|
key = rsa.generate_private_key(
|
||||||
|
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||||
|
)
|
||||||
|
# Get private key in PEM format
|
||||||
|
private_key = key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get public key in PEM format
|
||||||
|
public_key = key.public_key().public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
public_key_obj = serialization.load_pem_public_key(
|
||||||
|
public_key, backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert RSA public key object to JWK (JSON Web Key)
|
||||||
|
public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
|
||||||
|
|
||||||
|
assert isinstance(public_jwk, dict)
|
||||||
|
|
||||||
|
# set cache
|
||||||
|
cache = DualCache()
|
||||||
|
|
||||||
|
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||||
|
|
||||||
|
jwt_handler = JWTHandler()
|
||||||
|
|
||||||
|
jwt_handler.user_api_key_cache = cache
|
||||||
|
|
||||||
|
# VALID TOKEN
|
||||||
|
## GENERATE A TOKEN
|
||||||
|
# Assuming the current time is in UTC
|
||||||
|
expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"sub": "user123",
|
||||||
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
|
"scope": "litellm-proxy-admin",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate the JWT token
|
||||||
|
# But before, you should convert bytes to string
|
||||||
|
private_key_str = private_key.decode("utf-8")
|
||||||
|
token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
|
||||||
|
## VERIFY IT WORKS
|
||||||
|
|
||||||
|
# verify token
|
||||||
|
|
||||||
|
response = await jwt_handler.auth_jwt(token=token)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert isinstance(response, dict)
|
||||||
|
|
||||||
|
print(f"response: {response}")
|
||||||
|
|
||||||
|
# INVALID TOKEN
|
||||||
|
## GENERATE A TOKEN
|
||||||
|
# Assuming the current time is in UTC
|
||||||
|
expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"sub": "user123",
|
||||||
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
|
"scope": "litellm-NO-SCOPE",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate the JWT token
|
||||||
|
# But before, you should convert bytes to string
|
||||||
|
private_key_str = private_key.decode("utf-8")
|
||||||
|
token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
|
||||||
|
## VERIFY IT WORKS
|
||||||
|
|
||||||
|
# verify token
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await jwt_handler.auth_jwt(token=token)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An exception occurred - {str(e)}")
|
Loading…
Add table
Add a link
Reference in a new issue