litellm/tests/local_testing/test_jwt.py
Krish Dholakia 7cc12bd5c6
LiteLLM Minor Fixes & Improvements (10/18/2024) (#6320)
* fix(converse_transformation.py): handle cross region model name when getting openai param support

Fixes https://github.com/BerriAI/litellm/issues/6291

* LiteLLM Minor Fixes & Improvements (10/17/2024)  (#6293)

* fix(ui_sso.py): fix faulty admin only check

Fixes https://github.com/BerriAI/litellm/issues/6286

* refactor(sso_helper_utils.py): refactor /sso/callback to use helper utils, covered by unit testing

Prevent future regressions

* feat(prompt_factory): support 'ensure_alternating_roles' param

Closes https://github.com/BerriAI/litellm/issues/6257

* fix(proxy/utils.py): add dailytagspend to expected views

* feat(auth_utils.py): support setting regex for clientside auth credentials

Fixes https://github.com/BerriAI/litellm/issues/6203

* build(cookbook): add tutorial for mlflow + langchain + litellm proxy tracing

* feat(argilla.py): add argilla logging integration

Closes https://github.com/BerriAI/litellm/issues/6201

* fix: fix linting errors

* fix: fix ruff error

* test: fix test

* fix: update vertex ai assumption - parts not always guaranteed (#6296)

* docs(configs.md): add argilla env var to docs

* docs(user_keys.md): add regex doc for clientside auth params

* docs(argilla.md): add doc on argilla logging

* docs(argilla.md): add sampling rate to argilla calls

* bump: version 1.49.6 → 1.49.7

* add gpt-4o-audio models to model cost map (#6306)

* (code quality) add ruff check PLR0915 for `too-many-statements`  (#6309)

* ruff add PLR0915

* add noqa for PLR0915

* fix noqa

* add # noqa: PLR0915

* # noqa: PLR0915

* # noqa: PLR0915

* # noqa: PLR0915

* add # noqa: PLR0915

* # noqa: PLR0915

* # noqa: PLR0915

* # noqa: PLR0915

* # noqa: PLR0915

* doc fix Turn on / off caching per Key. (#6297)

* (feat) Support `audio`,  `modalities` params (#6304)

* add audio, modalities param

* add test for gpt audio models

* add get_supported_openai_params for GPT audio models

* add supported params for audio

* test_audio_output_from_model

* bump openai to openai==1.52.0

* bump openai on pyproject

* fix audio test

* fix test mock_chat_response

* handle audio for Message

* fix handling audio for OAI compatible API endpoints

* fix linting

* fix mock dbrx test

* (feat) Support audio param in responses streaming (#6312)

* add audio, modalities param

* add test for gpt audio models

* add get_supported_openai_params for GPT audio models

* add supported params for audio

* test_audio_output_from_model

* bump openai to openai==1.52.0

* bump openai on pyproject

* fix audio test

* fix test mock_chat_response

* handle audio for Message

* fix handling audio for OAI compatible API endpoints

* fix linting

* fix mock dbrx test

* add audio to Delta

* handle model_response.choices.delta.audio

* fix linting

* build(model_prices_and_context_window.json): add gpt-4o-audio audio token cost tracking

* refactor(model_prices_and_context_window.json): refactor 'supports_audio' to be 'supports_audio_input' and 'supports_audio_output'

Allows for flag to be used for openai + gemini models (both support audio input)

* feat(cost_calculation.py): support cost calc for audio model

Closes https://github.com/BerriAI/litellm/issues/6302

* feat(utils.py): expose new `supports_audio_input` and `supports_audio_output` functions

Closes https://github.com/BerriAI/litellm/issues/6303

* feat(handle_jwt.py): support single dict list

* fix(cost_calculator.py): fix linting errors

* fix: fix linting error

* fix(cost_calculator): move to using standard openai usage cached tokens value

* test: fix test

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
2024-10-19 22:23:27 -07:00

1021 lines
32 KiB
Python

#### What this tests ####
# Unit tests for JWT-Auth
import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import Request
import litellm
from litellm.caching.caching import DualCache
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.team_endpoints import new_team
from litellm.proxy.proxy_server import chat_completion
# Sample RSA public key in JWK form (RS256). The same "n" value is used for the
# cached key in test_token_single_public_key.
public_key = {
    "kty": "RSA",
    "e": "AQAB",
    "n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ",
    "alg": "RS256",
}
def test_load_config_with_custom_role_names():
    """
    Build `LiteLLM_JWTAuth` from the `litellm_proxy_roles` section of a config
    dict and check that the custom admin scope name is kept.
    """
    config = {
        "general_settings": {
            "litellm_proxy_roles": {"admin_jwt_scope": "litellm-proxy-admin"}
        }
    }

    # Walk the nested config one level at a time, defaulting to empty dicts.
    general_settings = config.get("general_settings", {})
    role_settings = general_settings.get("litellm_proxy_roles", {})

    proxy_roles = LiteLLM_JWTAuth(**role_settings)
    print(f"proxy_roles: {proxy_roles}")

    assert proxy_roles.admin_jwt_scope == "litellm-proxy-admin"
# test_load_config_with_custom_role_names()
@pytest.mark.asyncio
async def test_token_single_public_key():
    """
    With a single JWK stored under "litellm_jwt_auth_keys", the test expects
    `get_public_key(kid=None)` to return that key as a dict.
    """
    import jwt

    jwt_handler = JWTHandler()

    expected_modulus = "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
    signing_key = {
        "kty": "RSA",
        "use": "sig",
        "e": "AQAB",
        "n": expected_modulus,
        "alg": "RS256",
    }
    backend_keys = {"keys": [signing_key]}

    # set cache
    cache = DualCache()
    await cache.async_set_cache(
        key="litellm_jwt_auth_keys", value=backend_keys["keys"]
    )
    jwt_handler.user_api_key_cache = cache

    resolved_key = await jwt_handler.get_public_key(kid=None)

    assert resolved_key is not None
    assert isinstance(resolved_key, dict)
    assert resolved_key["n"] == expected_modulus
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_valid_invalid_token(audience):
    """
    Tests
    - valid token
    - invalid token

    Both tokens are signed with a freshly generated RSA key whose public JWK is
    stored under "litellm_jwt_auth_keys" in the handler's cache. The test
    expects `auth_jwt` to decode both without raising, including the one that
    carries the "litellm-NO-SCOPE" scope.
    """
    import json

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    # Reset the audience env var so an earlier parametrization cannot leak in.
    os.environ.pop("JWT_AUDIENCE", None)
    if audience:
        os.environ["JWT_AUDIENCE"] = audience

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    # Get public key in PEM format
    public_key = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    public_key_obj = serialization.load_pem_public_key(
        public_key, backend=default_backend()
    )
    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
    assert isinstance(public_jwk, dict)

    # set cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache

    # Convert the PEM bytes to a string before signing.
    private_key_str = private_key.decode("utf-8")

    def _encode_token(scope):
        """Sign an RS256 token for `scope` that expires in 10 minutes."""
        # "exp" must be a true UTC epoch. `datetime.utcnow().timestamp()` reads
        # the naive value as local time and is skewed by the host's UTC offset,
        # so `time.time()` is used instead.
        payload = {
            "sub": "user123",
            "exp": int(time.time()) + 10 * 60,
            "scope": scope,
            "aud": audience,
        }
        return jwt.encode(payload, private_key_str, algorithm="RS256")

    # VALID TOKEN
    token = _encode_token("litellm-proxy-admin")
    response = await jwt_handler.auth_jwt(token=token)
    assert response is not None
    assert isinstance(response, dict)
    print(f"response: {response}")

    # INVALID TOKEN (unknown scope) - decoding is still expected not to raise
    token = _encode_token("litellm-NO-SCOPE")
    try:
        response = await jwt_handler.auth_jwt(token=token)
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
@pytest.fixture
def prisma_client():
    """
    Fixture returning a `PrismaClient` for `DATABASE_URL`.

    Connection-pool query params are added to the URL and the result is written
    back to the `DATABASE_URL` environment variable before the client is built.
    """
    import litellm
    from litellm.proxy.proxy_cli import append_query_params
    from litellm.proxy.utils import PrismaClient, ProxyLogging

    logging_obj = ProxyLogging(user_api_key_cache=DualCache())

    ### add connection pool + pool timeout args
    pool_params = {"connection_limit": 100, "pool_timeout": 60}
    os.environ["DATABASE_URL"] = append_query_params(
        os.getenv("DATABASE_URL"), pool_params
    )

    return PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=logging_obj
    )
@pytest.fixture
def team_token_tuple():
    """
    Fixture producing a signed team-scoped JWT.

    Returns:
        tuple: `(team_id, token, public_jwk)` - a random team id, an RS256 token
        whose "client_id" claim is that team id, and the public JWK (dict)
        matching the key that signed the token.
    """
    import json

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    # Get public key in PEM format
    public_key = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    public_key_obj = serialization.load_pem_public_key(
        public_key, backend=default_backend()
    )
    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))

    # "exp" must be a true UTC epoch. `datetime.utcnow().timestamp()` reads the
    # naive value as local time and is skewed by the host's UTC offset.
    expiration_time = int(time.time()) + 10 * 60

    team_id = f"team123_{uuid.uuid4()}"
    payload = {
        "sub": "user123",
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": None,
    }

    # Convert the PEM bytes to a string before signing.
    private_key_str = private_key.decode("utf-8")

    ## team token
    token = jwt.encode(payload, private_key_str, algorithm="RS256")

    return team_id, token, public_jwk
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_team_token_output(prisma_client, audience):
    """
    Run a team-scoped JWT through `user_api_key_auth`.

    - 1. Initial call should fail -> team doesn't exist
    - 2. Create team via admin token
    - 3. 2nd call w/ same team -> call should succeed -> assert UserAPIKeyAuth object correctly formatted
    """
    import json
    import uuid

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm
    from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
    from litellm.proxy.proxy_server import user_api_key_auth

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    # Reset the audience env var so an earlier parametrization cannot leak in.
    os.environ.pop("JWT_AUDIENCE", None)
    if audience:
        os.environ["JWT_AUDIENCE"] = audience

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    # Get public key in PEM format
    public_key = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    public_key_obj = serialization.load_pem_public_key(
        public_key, backend=default_backend()
    )
    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
    assert isinstance(public_jwk, dict)

    # set cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")

    # "exp" must be a true UTC epoch. `datetime.utcnow().timestamp()` reads the
    # naive value as local time and is skewed by the host's UTC offset.
    expiration_time = int(time.time()) + 10 * 60

    team_id = f"team123_{uuid.uuid4()}"
    # Convert the PEM bytes to a string before signing.
    private_key_str = private_key.decode("utf-8")

    ## team token
    payload = {
        "sub": "user123",
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": audience,
    }
    token = jwt.encode(payload, private_key_str, algorithm="RS256")

    ## admin token
    payload = {
        "sub": "user123",
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_proxy_admin",
        "aud": audience,
    }
    admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")

    ## VERIFY IT WORKS
    # the team token must at least decode before it is sent through the auth flow
    response = await jwt_handler.auth_jwt(token=token)

    ## RUN IT THROUGH USER API KEY AUTH
    setattr(
        litellm.proxy.proxy_server,
        "general_settings",
        {
            "enable_jwt_auth": True,
        },
    )
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)

    bearer_token = "Bearer " + token
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    ## 1. INITIAL TEAM CALL - should fail
    try:
        result = await user_api_key_auth(request=request, api_key=bearer_token)
    except Exception:
        pass  # expected: the team has not been created yet
    else:
        pytest.fail("Team doesn't exist. This should fail")

    ## 2. CREATE TEAM W/ ADMIN TOKEN - should succeed
    try:
        bearer_token = "Bearer " + admin_token
        request._url = URL(url="/team/new")
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        await new_team(
            data=NewTeamRequest(
                team_id=team_id,
                tpm_limit=100,
                rpm_limit=99,
                models=["gpt-3.5-turbo", "gpt-4"],
            ),
            user_api_key_dict=result,
            http_request=Request(scope={"type": "http"}),
        )
    except Exception as e:
        pytest.fail(f"This should not fail - {str(e)}")

    ## 3. 2nd CALL W/ TEAM TOKEN - should succeed
    bearer_token = "Bearer " + token
    request._url = URL(url="/chat/completions")
    try:
        team_result: UserAPIKeyAuth = await user_api_key_auth(
            request=request, api_key=bearer_token
        )
    except Exception as e:
        pytest.fail(f"Team exists. This should not fail - {e}")

    ## 4. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py)
    assert team_result.team_tpm_limit == 100
    assert team_result.team_rpm_limit == 99
    assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.parametrize(
    "team_id_set, default_team_id",
    [(True, False), (False, True)],
)
@pytest.mark.parametrize("user_id_upsert", [True, False])
@pytest.mark.asyncio
async def aaaatest_user_token_output(
    prisma_client, audience, team_id_set, default_team_id, user_id_upsert
):
    """
    Run a user + team JWT through `user_api_key_auth`.

    The `aaaa` name prefix keeps pytest from collecting this function.

    - 1. Initial call should fail -> team doesn't exist
    - 2. Create team via admin token
    - 3. 2nd call w/ same team -> call should fail -> user doesn't exist
         (only required to fail when `user_id_upsert` is off)
    - 4. Create user via admin token (or, with upsert on, look the user up)
    - 5. 3rd call w/ same team, same user -> call should succeed
    - 6. assert user api key auth format
    """
    import uuid

    args = locals()
    print(f"received args - {args}")

    # `default_team_id` arrives as a bool flag. Turn it into a concrete team id,
    # or None when no default team is wanted, so the `is not None` check in
    # step 6 is meaningful and `False` is never stored as a team id.
    default_team_id = (
        "team_id_12344_{}".format(uuid.uuid4()) if default_team_id else None
    )

    import json

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm
    from litellm.proxy._types import NewTeamRequest, NewUserRequest, UserAPIKeyAuth
    from litellm.proxy.management_endpoints.internal_user_endpoints import (
        new_user,
        user_info,
    )
    from litellm.proxy.proxy_server import user_api_key_auth

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    # Reset the audience env var so an earlier parametrization cannot leak in.
    os.environ.pop("JWT_AUDIENCE", None)
    if audience:
        os.environ["JWT_AUDIENCE"] = audience

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    # Get public key in PEM format
    public_key = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    public_key_obj = serialization.load_pem_public_key(
        public_key, backend=default_backend()
    )
    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
    assert isinstance(public_jwk, dict)

    # set cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
    jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
    jwt_handler.litellm_jwtauth.team_id_default = default_team_id
    jwt_handler.litellm_jwtauth.user_id_upsert = user_id_upsert
    if team_id_set:
        jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"

    # "exp" must be a true UTC epoch. `datetime.utcnow().timestamp()` reads the
    # naive value as local time and is skewed by the host's UTC offset.
    expiration_time = int(time.time()) + 10 * 60

    team_id = f"team123_{uuid.uuid4()}"
    user_id = f"user123_{uuid.uuid4()}"
    # Convert the PEM bytes to a string before signing.
    private_key_str = private_key.decode("utf-8")

    ## team token
    payload = {
        "sub": user_id,
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": audience,
    }
    token = jwt.encode(payload, private_key_str, algorithm="RS256")

    ## admin token
    payload = {
        "sub": user_id,
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_proxy_admin",
        "aud": audience,
    }
    admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")

    ## VERIFY IT WORKS
    # the team token must at least decode before it is sent through the auth flow
    response = await jwt_handler.auth_jwt(token=token)

    ## RUN IT THROUGH USER API KEY AUTH
    setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)

    bearer_token = "Bearer " + token
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    ## 1. INITIAL TEAM CALL - should fail
    try:
        result = await user_api_key_auth(request=request, api_key=bearer_token)
    except Exception:
        pass  # expected: the team has not been created yet
    else:
        pytest.fail("Team doesn't exist. This should fail")

    ## 2. CREATE TEAM W/ ADMIN TOKEN - should succeed
    try:
        bearer_token = "Bearer " + admin_token
        request._url = URL(url="/team/new")
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        await new_team(
            data=NewTeamRequest(
                team_id=team_id,
                tpm_limit=100,
                rpm_limit=99,
                models=["gpt-3.5-turbo", "gpt-4"],
            ),
            user_api_key_dict=result,
            http_request=Request(scope={"type": "http"}),
        )
        if default_team_id:
            await new_team(
                data=NewTeamRequest(
                    team_id=default_team_id,
                    tpm_limit=100,
                    rpm_limit=99,
                    models=["gpt-3.5-turbo", "gpt-4"],
                ),
                user_api_key_dict=result,
                http_request=Request(scope={"type": "http"}),
            )
    except Exception as e:
        pytest.fail(f"This should not fail - {str(e)}")

    ## 3. 2nd CALL W/ TEAM TOKEN - should fail
    bearer_token = "Bearer " + token
    request._url = URL(url="/chat/completions")
    try:
        team_result: UserAPIKeyAuth = await user_api_key_auth(
            request=request, api_key=bearer_token
        )
    except Exception:
        pass  # a failure is accepted whether or not upsert is enabled
    else:
        # Without upsert the unknown user must be rejected.
        if not user_id_upsert:
            pytest.fail("User doesn't exist. this should fail")

    ## 4. Create user
    try:
        bearer_token = "Bearer " + admin_token
        request._url = URL(url="/team/new")
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        if user_id_upsert:
            ## check if user already exists
            await user_info(user_id=user_id)
        else:
            await new_user(
                data=NewUserRequest(
                    user_id=user_id,
                ),
            )
    except Exception as e:
        pytest.fail(f"This should not fail - {str(e)}")

    ## 5. 3rd call w/ same team, same user -> call should succeed
    bearer_token = "Bearer " + token
    request._url = URL(url="/chat/completions")
    try:
        team_result: UserAPIKeyAuth = await user_api_key_auth(
            request=request, api_key=bearer_token
        )
    except Exception as e:
        pytest.fail(f"Team exists. This should not fail - {e}")

    ## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
    if team_id_set or default_team_id is not None:
        assert team_result.team_tpm_limit == 100
        assert team_result.team_rpm_limit == 99
        assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
    assert team_result.user_id == user_id
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_allowed_routes_admin(prisma_client, audience):
    """
    Add a check to make sure jwt proxy admin scope can access all allowed admin routes
    - iterate through allowed endpoints
    - check if admin passes user_api_key_auth for them
    """
    import json

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm
    from litellm.proxy.proxy_server import user_api_key_auth

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    # Reset the audience env var so an earlier parametrization cannot leak in.
    os.environ.pop("JWT_AUDIENCE", None)
    if audience:
        os.environ["JWT_AUDIENCE"] = audience

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    # Get public key in PEM format
    public_key = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    public_key_obj = serialization.load_pem_public_key(
        public_key, backend=default_backend()
    )
    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
    assert isinstance(public_jwk, dict)

    # set cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")

    # "exp" must be a true UTC epoch. `datetime.utcnow().timestamp()` reads the
    # naive value as local time and is skewed by the host's UTC offset.
    expiration_time = int(time.time()) + 10 * 60

    # Convert the PEM bytes to a string before signing.
    private_key_str = private_key.decode("utf-8")

    ## admin token
    payload = {
        "sub": "user123",
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_proxy_admin",
        "aud": audience,
    }
    admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")

    # verify token
    response = await jwt_handler.auth_jwt(token=admin_token)

    ## RUN IT THROUGH USER API KEY AUTH
    bearer_token = "Bearer " + admin_token

    # `admin_allowed_routes` holds LiteLLMRoutes member names; expand each known
    # name into the concrete route paths it stands for.
    pseudo_routes = jwt_handler.litellm_jwtauth.admin_allowed_routes
    actual_routes = []
    for route in pseudo_routes:
        if route in LiteLLMRoutes.__members__:
            actual_routes.extend(LiteLLMRoutes[route].value)

    for route in actual_routes:
        request = Request(scope={"type": "http"})
        request._url = URL(url=route)

        # use generated key to auth in
        setattr(
            litellm.proxy.proxy_server,
            "general_settings",
            {
                "enable_jwt_auth": True,
            },
        )
        setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)

        # Any exception raised here fails the test for this route.
        result = await user_api_key_auth(request=request, api_key=bearer_token)
import pytest
@pytest.mark.asyncio
async def test_team_cache_update_called():
    """
    Calling `update_cache` with only a `team_id` is expected to await the
    proxy cache's `async_get_cache` exactly once.
    """
    import litellm

    cache = DualCache()

    # patch.object (instead of a bare setattr) puts the proxy's original
    # `user_api_key_cache` back when the block exits, so this test does not
    # leak a replaced global into later tests.
    with patch.object(
        litellm.proxy.proxy_server, "user_api_key_cache", cache
    ), patch.object(cache, "async_get_cache", new=AsyncMock()) as mock_call_cache:
        # Call the function under test
        await litellm.proxy.proxy_server.update_cache(
            token=None,
            user_id=None,
            end_user_id=None,
            team_id="1234",
            response_cost=20,
            parent_otel_span=None,
        )  # type: ignore

        # presumably the cache update runs in a background task - wait for it
        await asyncio.sleep(3)
        mock_call_cache.assert_awaited_once()
@pytest.fixture
def public_jwt_key():
    """
    Fixture generating a throwaway 2048-bit RSA key pair.

    Returns:
        dict: `{"private_key": <PEM bytes>, "public_jwk": <JWK dict>}`.
    """
    import json

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    rsa_key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )

    # Private half: unencrypted PKCS8 PEM bytes.
    private_pem = rsa_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Public half: serialise to PEM, load it back, then convert to a JWK dict.
    public_pem = rsa_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    loaded_public_key = serialization.load_pem_public_key(
        public_pem, backend=default_backend()
    )
    jwk_json = jwt.algorithms.RSAAlgorithm.to_jwk(loaded_public_key)

    return {"private_key": private_pem, "public_jwk": json.loads(jwk_json)}
def mock_user_object(*args, **kwargs):
    """
    Replacement used as the `side_effect` for the patched `get_user_object`:
    prints the call arguments and asserts that `user_id_upsert=True` was passed.
    """
    print(f"Args: {args}")
    print(f"kwargs: {kwargs}")
    assert kwargs["user_id_upsert"] is True
@pytest.mark.parametrize(
    "user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)]
)
@pytest.mark.asyncio
async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
    """
    Allow anyone with an `@xyz.com` email to make a request to the proxy.

    The handler is configured with `user_allowed_email_domain="berri.ai"`; a
    token carrying an email in that domain should pass `user_api_key_auth`,
    any other email should make it raise.

    Relevant issue: https://github.com/BerriAI/litellm/issues/5605
    """
    import jwt
    from starlette.datastructures import URL

    from litellm.proxy.proxy_server import user_api_key_auth

    public_jwk = public_jwt_key["public_jwk"]
    private_key = public_jwt_key["private_key"]

    # set cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
        user_email_jwt_field="email",
        user_allowed_email_domain="berri.ai",
        user_id_upsert=True,
    )

    # "exp" must be a true UTC epoch. `datetime.utcnow().timestamp()` reads the
    # naive value as local time and is skewed by the host's UTC offset.
    expiration_time = int(time.time()) + 10 * 60

    team_id = f"team123_{uuid.uuid4()}"
    payload = {
        "sub": "user123",
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": "litellm-proxy",
        "email": user_email,
    }

    # Convert the PEM bytes to a string before signing.
    private_key_str = private_key.decode("utf-8")

    ## team token
    token = jwt.encode(payload, private_key_str, algorithm="RS256")

    ## VERIFY IT WORKS
    # The token itself must decode for both emails; the domain check is
    # exercised through user_api_key_auth below.
    response = await jwt_handler.auth_jwt(token=token)
    assert response is not None  # Adjust this based on your actual response check

    ## RUN IT THROUGH USER API KEY AUTH
    bearer_token = "Bearer " + token
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    # use generated key to auth in
    setattr(
        litellm.proxy.proxy_server,
        "general_settings",
        {
            "enable_jwt_auth": True,
        },
    )
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
    # No database is used; get_user_object is patched below.
    setattr(litellm.proxy.proxy_server, "prisma_client", {})

    with patch.object(
        litellm.proxy.auth.user_api_key_auth,
        "get_user_object",
        side_effect=mock_user_object,
    ) as mock_client:
        if should_work:
            # Expect the call to succeed
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            assert result is not None  # Adjust this based on your actual response check
        else:
            # Expect the call to fail
            with pytest.raises(
                Exception
            ):  # Replace with the actual exception raised on failure
                resp = await user_api_key_auth(request=request, api_key=bearer_token)
                # Only reached if the call wrongly succeeds; shows what came back.
                print(resp)
def test_get_public_key_from_jwk_url():
    """
    Given a one-element JWK list, `parse_keys` is expected to return the entry
    whose `kid` matches the requested one.
    """
    import litellm
    from litellm.proxy.auth.handle_jwt import JWTHandler

    jwt_handler = JWTHandler()

    target_kid = "RaPJB8QVptWHjHcoHkVlUWO4f0D3BtcY6iSDXgGVBgk"
    signing_jwk = {
        "kty": "RSA",
        "alg": "RS256",
        "kid": target_kid,
        "use": "sig",
        "e": "AQAB",
        "n": "zgLDu57gLpkzzIkKrTKQVyjK8X40hvu6X_JOeFjmYmI0r3bh7FTOmre5rTEkDOL-1xvQguZAx4hjKmCzBU5Kz84FbsGiqM0ug19df4kwdTS6XOM6YEKUZrbaw4P7xTPsbZj7W2G_kxWNm3Xaxq6UKFdUF7n9snnBKKD6iUA-cE6HfsYmt9OhYZJfy44dbAbuanFmAsWw97SHrPFL3ueh3Ixt19KgpF4iSsXNg3YvoesdFM8psmivgePyyHA8k7pK1Yq7rNQX1Q9nzhvP-F7ocFbP52KYPlaSTu30YwPTVTFKYpDNmHT1fZ7LXZZNLrP_7-NSY76HS2ozSpzjsGVelQ",
    }
    jwk_response = [signing_jwk]

    public_key = jwt_handler.parse_keys(keys=jwk_response, kid=target_kid)

    assert public_key is not None
    assert public_key == jwk_response[0]