litellm/tests/proxy_unit_tests/test_jwt.py
Krish Dholakia 27e18358ab
fix(pattern_match_deployments.py): default to user input if unable to map based on wildcards (#6632)
* fix(pattern_match_deployments.py): default to user input if unable to map based on wildcards

* test: fix test

* test: reset test name

* test: update conftest to reload proxy server module between tests

* ci(config.yml): move langfuse out of local_testing

reduce ci/cd time

* ci(config.yml): cleanup langfuse ci/cd tests

* fix: update test to not use global proxy_server app module

* ci: move caching to a separate test pipeline

speed up ci pipeline

* test: update conftest to check if proxy_server attr exists before reloading

* build(conftest.py): don't block on inability to reload proxy_server

* ci(config.yml): update caching unit test filter to work on 'cache' keyword as well

* fix(encrypt_decrypt_utils.py): use function to get salt key

* test: mark flaky test

* test: handle anthropic overloaded errors

* refactor: create separate ci/cd pipeline for proxy unit tests

make ci/cd faster

* ci(config.yml): add litellm_proxy_unit_testing to build_and_test jobs

* ci(config.yml): generate prisma binaries for proxy unit tests

* test: readd vertex_key.json

* ci(config.yml): remove `-s` from proxy_unit_test cmd

speed up test

* ci: remove any 'debug' logging flag

speed up ci pipeline

* test: fix test

* test(test_braintrust.py): rerun

* test: add delay for braintrust test
2024-11-08 00:55:57 +05:30

1028 lines
32 KiB
Python

#### What this tests ####
# Unit tests for JWT-Auth
import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import Request
import litellm
from litellm.caching.caching import DualCache
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.team_endpoints import new_team
from litellm.proxy.proxy_server import chat_completion
# Sample RSA public key in JWK form (key type RSA, algorithm RS256).
# NOTE(review): the tests visible in this section bind their own local
# `public_key`, so this module-level value looks unused here - confirm
# against the rest of the file before removing it.
public_key = {
    "kty": "RSA",
    "e": "AQAB",
    "n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ",
    "alg": "RS256",
}
def test_load_config_with_custom_role_names():
    """A custom `admin_jwt_scope` under `litellm_proxy_roles` ends up on LiteLLM_JWTAuth."""
    role_settings = {"admin_jwt_scope": "litellm-proxy-admin"}
    config = {"general_settings": {"litellm_proxy_roles": role_settings}}

    # Mirror how a loaded config is unpacked: missing sections fall back to {}.
    general_settings = config.get("general_settings", {})
    proxy_roles = LiteLLM_JWTAuth(**general_settings.get("litellm_proxy_roles", {}))

    print(f"proxy_roles: {proxy_roles}")
    assert proxy_roles.admin_jwt_scope == "litellm-proxy-admin"
# test_load_config_with_custom_role_names()
@pytest.mark.asyncio
async def test_token_single_public_key():
    """With a single JWK in the cache, `get_public_key(kid=None)` returns a dict with that key's modulus."""
    import jwt

    expected_modulus = "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"

    jwt_handler = JWTHandler()

    only_jwk = {
        "kty": "RSA",
        "use": "sig",
        "e": "AQAB",
        "n": expected_modulus,
        "alg": "RS256",
    }
    backend_keys = {"keys": [only_jwk]}

    # Seed the cache entry the handler is pointed at below.
    key_cache = DualCache()
    await key_cache.async_set_cache(
        key="litellm_jwt_auth_keys", value=backend_keys["keys"]
    )
    jwt_handler.user_api_key_cache = key_cache

    resolved_key = await jwt_handler.get_public_key(kid=None)

    assert resolved_key is not None
    assert isinstance(resolved_key, dict)
    assert resolved_key["n"] == expected_modulus
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_valid_invalid_token(audience):
    """
    `JWTHandler.auth_jwt` decodes RS256 tokens signed by the cached key.

    Two tokens are checked, with and without `JWT_AUDIENCE` set:
    - one carrying the `litellm-proxy-admin` scope -> must decode to a dict
    - one carrying an unrecognised scope (`litellm-NO-SCOPE`) -> must still
      decode without raising
    """
    import json

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    # JWT_AUDIENCE is cleared first and only set when the parametrized value is truthy.
    os.environ.pop("JWT_AUDIENCE", None)
    if audience:
        os.environ["JWT_AUDIENCE"] = audience

    # Fresh RSA key pair for this test run.
    rsa_keypair = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    private_pem = rsa_keypair.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem = rsa_keypair.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    # PEM -> public key object -> JWK (JSON Web Key) dict.
    loaded_public_key = serialization.load_pem_public_key(
        public_pem, backend=default_backend()
    )
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(loaded_public_key))
    assert isinstance(public_jwk, dict)

    # Publish the JWK through the cache the handler is pointed at.
    key_cache = DualCache()
    await key_cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = key_cache

    def _signed_token(scope):
        # Token expires 10 minutes from now; jwt.encode wants the PEM as str.
        expires_at = int((datetime.now() + timedelta(minutes=10)).timestamp())
        claims = {
            "sub": "user123",
            "exp": expires_at,
            "scope": scope,
            "aud": audience,
        }
        return jwt.encode(claims, private_pem.decode("utf-8"), algorithm="RS256")

    # Token with the admin scope.
    decoded = await jwt_handler.auth_jwt(token=_signed_token("litellm-proxy-admin"))
    assert decoded is not None
    assert isinstance(decoded, dict)
    print(f"response: {decoded}")

    # Token with an unknown scope: decoding itself must not raise.
    unknown_scope_token = _signed_token("litellm-NO-SCOPE")
    try:
        await jwt_handler.auth_jwt(token=unknown_scope_token)
    except Exception as err:
        pytest.fail(f"An exception occurred - {str(err)}")
@pytest.fixture
def prisma_client():
    """
    Fixture: a PrismaClient bound to `DATABASE_URL`.

    Connection-pool query params are appended to `DATABASE_URL` and the
    environment variable is overwritten with the result before the client
    is constructed.
    """
    import litellm
    from litellm.proxy.proxy_cli import append_query_params
    from litellm.proxy.utils import PrismaClient, ProxyLogging

    logging_obj = ProxyLogging(user_api_key_cache=DualCache())

    # connection pool + pool timeout args
    pool_params = {"connection_limit": 100, "pool_timeout": 60}
    os.environ["DATABASE_URL"] = append_query_params(
        os.getenv("DATABASE_URL"), pool_params
    )

    return PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=logging_obj
    )
@pytest.fixture
def team_token_tuple():
    """
    Fixture: a freshly signed team-scoped JWT plus what is needed to verify it.

    Returns:
        tuple: `(team_id, token, public_jwk)` - the random team id placed in
        the `client_id` claim, the RS256-signed token, and the matching
        public key as a JWK dict.
    """
    import json
    import uuid

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm
    from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
    from litellm.proxy.proxy_server import user_api_key_auth

    # Fresh RSA key pair.
    rsa_keypair = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    private_pem = rsa_keypair.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem = rsa_keypair.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    # PEM -> public key object -> JWK (JSON Web Key) dict.
    loaded_public_key = serialization.load_pem_public_key(
        public_pem, backend=default_backend()
    )
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(loaded_public_key))

    # Team token, expiring 10 minutes from now.
    team_id = f"team123_{uuid.uuid4()}"
    claims = {
        "sub": "user123",
        "exp": int((datetime.now() + timedelta(minutes=10)).timestamp()),
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": None,
    }
    # jwt.encode is given the PEM as str, not bytes.
    token = jwt.encode(claims, private_pem.decode("utf-8"), algorithm="RS256")

    return team_id, token, public_jwk
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_team_token_output(prisma_client, audience):
    """
    Walk a team-scoped JWT through `user_api_key_auth`.

    - 1. Initial call should fail -> team doesn't exist
    - 2. Create team via admin token
    - 3. 2nd call w/ same team -> call should succeed
    - 4. Assert the UserAPIKeyAuth object carries the team's tpm/rpm limits
         and models (used for tpm/rpm limiting in parallel_request_limiter.py)
    """
    import json
    import uuid

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm
    from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
    from litellm.proxy.proxy_server import user_api_key_auth

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    # JWT_AUDIENCE is cleared first and only set when the parametrized value is truthy.
    os.environ.pop("JWT_AUDIENCE", None)
    if audience:
        os.environ["JWT_AUDIENCE"] = audience

    # Fresh RSA key pair.
    rsa_keypair = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    private_pem = rsa_keypair.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem = rsa_keypair.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    # PEM -> public key object -> JWK (JSON Web Key) dict.
    loaded_public_key = serialization.load_pem_public_key(
        public_pem, backend=default_backend()
    )
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(loaded_public_key))
    assert isinstance(public_jwk, dict)

    # Handler reads its keys from this cache; the team id lives in `client_id`.
    key_cache = DualCache()
    await key_cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = key_cache
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")

    # Both tokens share one expiry, 10 minutes from now.
    expires_at = int((datetime.now() + timedelta(minutes=10)).timestamp())
    team_id = f"team123_{uuid.uuid4()}"
    signing_key = private_pem.decode("utf-8")  # jwt.encode is given str, not bytes

    team_claims = {
        "sub": "user123",
        "exp": expires_at,
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": audience,
    }
    token = jwt.encode(team_claims, signing_key, algorithm="RS256")

    admin_claims = {
        "sub": "user123",
        "exp": expires_at,
        "scope": "litellm_proxy_admin",
        "aud": audience,
    }
    admin_token = jwt.encode(admin_claims, signing_key, algorithm="RS256")

    # The team token must decode on its own before going through the proxy auth.
    decoded = await jwt_handler.auth_jwt(token=token)

    # Switch the proxy to JWT auth with the handler built above.
    setattr(
        litellm.proxy.proxy_server,
        "general_settings",
        {
            "enable_jwt_auth": True,
        },
    )
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)

    request = Request(scope={"type": "http"})

    ## 1. INITIAL TEAM CALL - should fail
    bearer_token = "Bearer " + token
    request._url = URL(url="/chat/completions")
    try:
        await user_api_key_auth(request=request, api_key=bearer_token)
    except Exception:
        pass
    else:
        pytest.fail("Team doesn't exist. This should fail")

    ## 2. CREATE TEAM W/ ADMIN TOKEN - should succeed
    bearer_token = "Bearer " + admin_token
    request._url = URL(url="/team/new")
    try:
        admin_auth = await user_api_key_auth(request=request, api_key=bearer_token)
        await new_team(
            data=NewTeamRequest(
                team_id=team_id,
                tpm_limit=100,
                rpm_limit=99,
                models=["gpt-3.5-turbo", "gpt-4"],
            ),
            user_api_key_dict=admin_auth,
            http_request=Request(scope={"type": "http"}),
        )
    except Exception as err:
        pytest.fail(f"This should not fail - {str(err)}")

    ## 3. 2nd CALL W/ TEAM TOKEN - should succeed
    bearer_token = "Bearer " + token
    request._url = URL(url="/chat/completions")
    try:
        team_result: UserAPIKeyAuth = await user_api_key_auth(
            request=request, api_key=bearer_token
        )
    except Exception as err:
        pytest.fail(f"Team exists. This should not fail - {err}")

    ## 4. ASSERT USER_API_KEY_AUTH format
    assert team_result.team_tpm_limit == 100
    assert team_result.team_rpm_limit == 99
    assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.parametrize(
    "team_id_set, default_team_id",
    [(True, False), (False, True)],
)
@pytest.mark.parametrize("user_id_upsert", [True, False])
@pytest.mark.asyncio
async def aaaatest_user_token_output(
    prisma_client, audience, team_id_set, default_team_id, user_id_upsert
):
    """
    Walk a user+team JWT through `user_api_key_auth`.

    NOTE: the `aaaa` name prefix keeps pytest from collecting this test; it is
    left as found.

    - 1. Initial call should fail -> team doesn't exist
    - 2. Create team via admin token (plus the default team, when one is used)
    - 3. 2nd call w/ same team -> should fail when the user doesn't exist and
         `user_id_upsert` is off
    - 4. With upsert: check the user now exists. Without: create the user
    - 5. 3rd call w/ same team, same user -> call should succeed
    - 6. Assert user api key auth format

    Args:
        prisma_client: fixture providing the DB client installed on the proxy.
        audience: value for the `aud` claim / `JWT_AUDIENCE` (None = unset).
        team_id_set: when True the team id is read from the `client_id` claim.
        default_team_id: boolean flag from parametrize; when True a random
            default team id is generated and configured on the handler.
        user_id_upsert: value assigned to `LiteLLM_JWTAuth.user_id_upsert`.
    """
    args = locals()
    print(f"received args - {args}")

    import json
    import uuid

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm
    from litellm.proxy._types import NewTeamRequest, NewUserRequest, UserAPIKeyAuth
    from litellm.proxy.management_endpoints.internal_user_endpoints import (
        new_user,
        user_info,
    )
    from litellm.proxy.proxy_server import user_api_key_auth

    # The parametrized flag becomes either a real team id or None, so a bare
    # `False` is never handed to the handler as a team id and the `is not None`
    # check in step 6 is meaningful.
    default_team_id = f"team_id_12344_{uuid.uuid4()}" if default_team_id else None

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    # JWT_AUDIENCE is cleared first and only set when the parametrized value is truthy.
    os.environ.pop("JWT_AUDIENCE", None)
    if audience:
        os.environ["JWT_AUDIENCE"] = audience

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    # Get public key in PEM format
    public_key = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    public_key_obj = serialization.load_pem_public_key(
        public_key, backend=default_backend()
    )
    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
    assert isinstance(public_jwk, dict)

    # set cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
    jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
    jwt_handler.litellm_jwtauth.team_id_default = default_team_id
    jwt_handler.litellm_jwtauth.user_id_upsert = user_id_upsert
    if team_id_set:
        jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"

    # Both tokens share one expiry, 10 minutes from now.
    expiration_time = int((datetime.now() + timedelta(minutes=10)).timestamp())
    team_id = f"team123_{uuid.uuid4()}"
    user_id = f"user123_{uuid.uuid4()}"
    # jwt.encode is given the PEM as str, not bytes
    private_key_str = private_key.decode("utf-8")

    ## team token
    payload = {
        "sub": user_id,
        "exp": expiration_time,
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": audience,
    }
    token = jwt.encode(payload, private_key_str, algorithm="RS256")

    ## admin token
    payload = {
        "sub": user_id,
        "exp": expiration_time,
        "scope": "litellm_proxy_admin",
        "aud": audience,
    }
    admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")

    ## VERIFY IT WORKS
    await jwt_handler.auth_jwt(token=token)

    ## RUN IT THROUGH USER API KEY AUTH
    request = Request(scope={"type": "http"})
    setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)

    ## 1. INITIAL TEAM CALL - should fail
    bearer_token = "Bearer " + token
    request._url = URL(url="/chat/completions")
    try:
        await user_api_key_auth(request=request, api_key=bearer_token)
    except Exception:
        pass
    else:
        pytest.fail("Team doesn't exist. This should fail")

    ## 2. CREATE TEAM W/ ADMIN TOKEN - should succeed
    bearer_token = "Bearer " + admin_token
    request._url = URL(url="/team/new")
    try:
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        await new_team(
            data=NewTeamRequest(
                team_id=team_id,
                tpm_limit=100,
                rpm_limit=99,
                models=["gpt-3.5-turbo", "gpt-4"],
            ),
            user_api_key_dict=result,
            http_request=Request(scope={"type": "http"}),
        )
        if default_team_id:
            await new_team(
                data=NewTeamRequest(
                    team_id=default_team_id,
                    tpm_limit=100,
                    rpm_limit=99,
                    models=["gpt-3.5-turbo", "gpt-4"],
                ),
                user_api_key_dict=result,
                http_request=Request(scope={"type": "http"}),
            )
    except Exception as e:
        pytest.fail(f"This should not fail - {str(e)}")

    ## 3. 2nd CALL W/ TEAM TOKEN - should fail unless the user gets upserted
    bearer_token = "Bearer " + token
    request._url = URL(url="/chat/completions")
    try:
        await user_api_key_auth(request=request, api_key=bearer_token)
    except Exception:
        # As before, a failure here is tolerated in both modes; step 5 is the
        # call that must succeed.
        pass
    else:
        if not user_id_upsert:
            pytest.fail("User doesn't exist. this should fail")

    ## 4. With upsert: the user should already exist. Without: create it.
    bearer_token = "Bearer " + admin_token
    request._url = URL(url="/team/new")
    try:
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        if user_id_upsert:
            await user_info(user_id=user_id)
        else:
            await new_user(
                data=NewUserRequest(
                    user_id=user_id,
                ),
            )
    except Exception as e:
        pytest.fail(f"This should not fail - {str(e)}")

    ## 5. 3rd call w/ same team, same user -> call should succeed
    bearer_token = "Bearer " + token
    request._url = URL(url="/chat/completions")
    try:
        team_result: UserAPIKeyAuth = await user_api_key_auth(
            request=request, api_key=bearer_token
        )
    except Exception as e:
        pytest.fail(f"Team exists. This should not fail - {e}")

    ## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
    if team_id_set or default_team_id is not None:
        assert team_result.team_tpm_limit == 100
        assert team_result.team_rpm_limit == 99
        assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
    assert team_result.user_id == user_id
@pytest.mark.parametrize("admin_allowed_routes", [None, ["ui_routes"]])
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_routes):
    """
    Add a check to make sure jwt proxy admin scope can access all allowed admin routes
    - iterate through allowed endpoints
    - check if admin passes user_api_key_auth for them

    Args:
        prisma_client: fixture providing the DB client installed on the proxy.
        audience: value for the `aud` claim / `JWT_AUDIENCE` (None = unset).
        admin_allowed_routes: route-group names to configure on the handler;
            None leaves `LiteLLM_JWTAuth`'s own default in place.

    Any exception raised by `user_api_key_auth` propagates and fails the test.
    """
    import json
    import uuid

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm
    from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
    from litellm.proxy.proxy_server import user_api_key_auth

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    # JWT_AUDIENCE is cleared first and only set when the parametrized value is truthy.
    os.environ.pop("JWT_AUDIENCE", None)
    if audience:
        os.environ["JWT_AUDIENCE"] = audience

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    # Get public key in PEM format
    public_key = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    public_key_obj = serialization.load_pem_public_key(
        public_key, backend=default_backend()
    )
    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
    assert isinstance(public_jwk, dict)

    # set cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache

    # Only pass admin_allowed_routes when the parametrization supplies one.
    jwt_auth_kwargs = {"team_id_jwt_field": "client_id"}
    if admin_allowed_routes:
        jwt_auth_kwargs["admin_allowed_routes"] = admin_allowed_routes
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(**jwt_auth_kwargs)

    ## admin token, expiring 10 minutes from now
    payload = {
        "sub": "user123",
        "exp": int((datetime.now() + timedelta(minutes=10)).timestamp()),
        "scope": "litellm_proxy_admin",
        "aud": audience,
    }
    # jwt.encode is given the PEM as str, not bytes
    admin_token = jwt.encode(payload, private_key.decode("utf-8"), algorithm="RS256")

    # verify token
    print(f"admin_token: {admin_token}")
    await jwt_handler.auth_jwt(token=admin_token)

    ## RUN IT THROUGH USER API KEY AUTH
    bearer_token = "Bearer " + admin_token

    # Expand the configured route-group names into concrete paths; names that
    # are not LiteLLMRoutes members are skipped.
    actual_routes = [
        path
        for route_group in jwt_handler.litellm_jwtauth.admin_allowed_routes
        if route_group in LiteLLMRoutes.__members__
        for path in LiteLLMRoutes[route_group].value
    ]

    for route in actual_routes:
        request = Request(scope={"type": "http"})
        request._url = URL(url=route)

        # Proxy settings are (re)installed before every route, as before.
        setattr(
            litellm.proxy.proxy_server,
            "general_settings",
            {
                "enable_jwt_auth": True,
            },
        )
        setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)

        # The admin token must authenticate on this route.
        await user_api_key_auth(request=request, api_key=bearer_token)
import pytest
@pytest.mark.asyncio
async def test_team_cache_update_called():
    """
    A team-only `update_cache` call awaits `async_get_cache` exactly once.

    The proxy's module-level `user_api_key_cache` is swapped for a fresh
    DualCache only for the duration of the test; `patch.object` restores the
    previous value afterwards so later tests do not inherit this cache.
    """
    import litellm.proxy.proxy_server

    cache = DualCache()
    with patch.object(
        litellm.proxy.proxy_server, "user_api_key_cache", cache
    ), patch.object(cache, "async_get_cache", new=AsyncMock()) as mock_call_cache:
        # Call the function under test
        await litellm.proxy.proxy_server.update_cache(
            token=None,
            user_id=None,
            end_user_id=None,
            team_id="1234",
            response_cost=20,
            parent_otel_span=None,
        )  # type: ignore

        # NOTE(review): presumably update_cache does its cache work in the
        # background, hence the wait before asserting - confirm.
        await asyncio.sleep(3)
        mock_call_cache.assert_awaited_once()
@pytest.fixture
def public_jwt_key():
    """
    Fixture: a fresh RSA key pair for signing and verifying test JWTs.

    Returns:
        dict: `{"private_key": <PKCS8 PEM bytes>, "public_jwk": <JWK dict>}`
    """
    import json

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    rsa_keypair = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )

    # Private half: unencrypted PKCS8 PEM.
    private_pem = rsa_keypair.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Public half: PEM -> public key object -> JWK (JSON Web Key) dict.
    public_pem = rsa_keypair.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    loaded_public_key = serialization.load_pem_public_key(
        public_pem, backend=default_backend()
    )
    jwk_dict = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(loaded_public_key))

    return {"private_key": private_pem, "public_jwk": jwk_dict}
def mock_user_object(*args, **kwargs):
    """
    Stand-in used as a mock `side_effect`: echo the call and check upsert is on.

    Prints the positional and keyword arguments it was called with, then
    asserts the `user_id_upsert` keyword is exactly True. Returns None.

    Raises:
        KeyError: `user_id_upsert` was not passed as a keyword.
        AssertionError: `user_id_upsert` is anything other than True.
    """
    for template, captured in (("Args: {}", args), ("kwargs: {}", kwargs)):
        print(template.format(captured))

    upsert_flag = kwargs["user_id_upsert"]
    assert upsert_flag is True
@pytest.mark.parametrize(
    "user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)]
)
@pytest.mark.asyncio
async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
    """
    Allow anyone whose email is on the allowed domain to make a request to the proxy.

    The handler is configured with `user_allowed_email_domain="berri.ai"`; a
    token whose `email` claim is on that domain must pass `user_api_key_auth`,
    one from another domain must be rejected.

    Relevant issue: https://github.com/BerriAI/litellm/issues/5605
    """
    import jwt
    from starlette.datastructures import URL

    from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
    from litellm.proxy.proxy_server import user_api_key_auth

    private_pem = public_jwt_key["private_key"]
    jwk_dict = public_jwt_key["public_jwk"]

    # Handler reads its keys from this cache.
    key_cache = DualCache()
    await key_cache.async_set_cache(key="litellm_jwt_auth_keys", value=[jwk_dict])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = key_cache
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
        user_email_jwt_field="email",
        user_allowed_email_domain="berri.ai",
        user_id_upsert=True,
    )

    # Team token carrying the email under test, expiring 10 minutes from now.
    team_id = f"team123_{uuid.uuid4()}"
    claims = {
        "sub": "user123",
        "exp": int((datetime.now() + timedelta(minutes=10)).timestamp()),
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": "litellm-proxy",
        "email": user_email,
    }
    # jwt.encode is given the PEM as str, not bytes.
    token = jwt.encode(claims, private_pem.decode("utf-8"), algorithm="RS256")

    # Decoding succeeds for both emails; the domain check is exercised below.
    decoded = await jwt_handler.auth_jwt(token=token)
    assert decoded is not None

    ## RUN IT THROUGH USER API KEY AUTH
    bearer_token = "Bearer " + token
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    setattr(
        litellm.proxy.proxy_server,
        "general_settings",
        {
            "enable_jwt_auth": True,
        },
    )
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
    setattr(litellm.proxy.proxy_server, "prisma_client", {})

    # get_user_object is replaced by mock_user_object, which asserts it is
    # called with user_id_upsert=True.
    with patch.object(
        litellm.proxy.auth.user_api_key_auth,
        "get_user_object",
        side_effect=mock_user_object,
    ) as mock_client:
        if not should_work:
            # Expect the call to fail
            with pytest.raises(Exception):
                resp = await user_api_key_auth(request=request, api_key=bearer_token)
                print(resp)
        else:
            # Expect the call to succeed
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            assert result is not None
def test_get_public_key_from_jwk_url():
    """`parse_keys` returns the JWK from the key list whose `kid` was requested."""
    import litellm
    from litellm.proxy.auth.handle_jwt import JWTHandler

    jwt_handler = JWTHandler()

    wanted_kid = "RaPJB8QVptWHjHcoHkVlUWO4f0D3BtcY6iSDXgGVBgk"
    only_jwk = {
        "kty": "RSA",
        "alg": "RS256",
        "kid": wanted_kid,
        "use": "sig",
        "e": "AQAB",
        "n": "zgLDu57gLpkzzIkKrTKQVyjK8X40hvu6X_JOeFjmYmI0r3bh7FTOmre5rTEkDOL-1xvQguZAx4hjKmCzBU5Kz84FbsGiqM0ug19df4kwdTS6XOM6YEKUZrbaw4P7xTPsbZj7W2G_kxWNm3Xaxq6UKFdUF7n9snnBKKD6iUA-cE6HfsYmt9OhYZJfy44dbAbuanFmAsWw97SHrPFL3ueh3Ixt19KgpF4iSsXNg3YvoesdFM8psmivgePyyHA8k7pK1Yq7rNQX1Q9nzhvP-F7ocFbP52KYPlaSTu30YwPTVTFKYpDNmHT1fZ7LXZZNLrP_7-NSY76HS2ozSpzjsGVelQ",
    }
    jwk_response = [only_jwk]

    resolved_key = jwt_handler.parse_keys(keys=jwk_response, kid=wanted_kid)

    assert resolved_key is not None
    assert resolved_key == only_jwk