forked from phoenix/litellm-mirror
fix(proxy_server.py): fix admin allowed routes
This commit is contained in:
parent
595a2a5b1b
commit
a33b9221da
5 changed files with 144 additions and 26 deletions
|
@ -23,3 +23,5 @@ general_settings:
|
||||||
alerting: ["slack"]
|
alerting: ["slack"]
|
||||||
alerting_args:
|
alerting_args:
|
||||||
report_check_interval: 10
|
report_check_interval: 10
|
||||||
|
enable_jwt_auth: True
|
||||||
|
|
||||||
|
|
|
@ -123,18 +123,8 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
||||||
"""
|
"""
|
||||||
for allowed_route in allowed_routes:
|
for allowed_route in allowed_routes:
|
||||||
if (
|
if (
|
||||||
allowed_route == LiteLLMRoutes.openai_routes.name
|
allowed_route in LiteLLMRoutes.__members__
|
||||||
and user_route in LiteLLMRoutes.openai_routes.value
|
and user_route in LiteLLMRoutes[allowed_route].value
|
||||||
):
|
|
||||||
return True
|
|
||||||
elif (
|
|
||||||
allowed_route == LiteLLMRoutes.info_routes.name
|
|
||||||
and user_route in LiteLLMRoutes.info_routes.value
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
elif (
|
|
||||||
allowed_route == LiteLLMRoutes.management_routes.name
|
|
||||||
and user_route in LiteLLMRoutes.management_routes.value
|
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
elif allowed_route == user_route:
|
elif allowed_route == user_route:
|
||||||
|
@ -152,17 +142,11 @@ def allowed_routes_check(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if user_role == "proxy_admin":
|
if user_role == "proxy_admin":
|
||||||
if litellm_proxy_roles.admin_allowed_routes is None:
|
is_allowed = _allowed_routes_check(
|
||||||
is_allowed = _allowed_routes_check(
|
user_route=user_route,
|
||||||
user_route=user_route, allowed_routes=["management_routes"]
|
allowed_routes=litellm_proxy_roles.admin_allowed_routes,
|
||||||
)
|
)
|
||||||
return is_allowed
|
return is_allowed
|
||||||
elif litellm_proxy_roles.admin_allowed_routes is not None:
|
|
||||||
is_allowed = _allowed_routes_check(
|
|
||||||
user_route=user_route,
|
|
||||||
allowed_routes=litellm_proxy_roles.admin_allowed_routes,
|
|
||||||
)
|
|
||||||
return is_allowed
|
|
||||||
|
|
||||||
elif user_role == "team":
|
elif user_role == "team":
|
||||||
if litellm_proxy_roles.team_allowed_routes is None:
|
if litellm_proxy_roles.team_allowed_routes is None:
|
||||||
|
|
|
@ -167,10 +167,17 @@ class JWTHandler:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if kid is not None and key == kid:
|
if kid is not None and key == kid:
|
||||||
public_key = keys[key]
|
public_key = keys[key]
|
||||||
|
elif (
|
||||||
|
kid is not None
|
||||||
|
and isinstance(key, dict)
|
||||||
|
and key.get("kid", None) is not None
|
||||||
|
and key["kid"] == kid
|
||||||
|
):
|
||||||
|
public_key = key
|
||||||
|
|
||||||
if public_key is None:
|
if public_key is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}"
|
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}, len(keys)={len(keys)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return public_key
|
return public_key
|
||||||
|
|
|
@ -497,7 +497,7 @@ async def user_api_key_auth(
|
||||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||||
)
|
)
|
||||||
if is_allowed:
|
if is_allowed:
|
||||||
return UserAPIKeyAuth()
|
return UserAPIKeyAuth(user_role="proxy_admin")
|
||||||
else:
|
else:
|
||||||
allowed_routes = (
|
allowed_routes = (
|
||||||
jwt_handler.litellm_jwtauth.admin_allowed_routes
|
jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||||
|
|
|
@ -12,7 +12,7 @@ sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest
|
import pytest
|
||||||
from litellm.proxy._types import LiteLLM_JWTAuth
|
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
|
||||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
@ -602,3 +602,128 @@ async def test_user_token_output(
|
||||||
assert team_result.team_rpm_limit == 99
|
assert team_result.team_rpm_limit == 99
|
||||||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||||
assert team_result.user_id == user_id
|
assert team_result.user_id == user_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allowed_routes_admin(prisma_client, audience):
|
||||||
|
"""
|
||||||
|
Add a check to make sure jwt proxy admin scope can access all allowed admin routes
|
||||||
|
|
||||||
|
- iterate through allowed endpoints
|
||||||
|
- check if admin passes user_api_key_auth for them
|
||||||
|
"""
|
||||||
|
import jwt, json
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from fastapi import Request
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
from litellm.proxy.proxy_server import user_api_key_auth, new_team
|
||||||
|
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
|
||||||
|
import litellm
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
os.environ.pop("JWT_AUDIENCE", None)
|
||||||
|
if audience:
|
||||||
|
os.environ["JWT_AUDIENCE"] = audience
|
||||||
|
|
||||||
|
# Generate a private / public key pair using RSA algorithm
|
||||||
|
key = rsa.generate_private_key(
|
||||||
|
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||||
|
)
|
||||||
|
# Get private key in PEM format
|
||||||
|
private_key = key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get public key in PEM format
|
||||||
|
public_key = key.public_key().public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
public_key_obj = serialization.load_pem_public_key(
|
||||||
|
public_key, backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert RSA public key object to JWK (JSON Web Key)
|
||||||
|
public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
|
||||||
|
|
||||||
|
assert isinstance(public_jwk, dict)
|
||||||
|
|
||||||
|
# set cache
|
||||||
|
cache = DualCache()
|
||||||
|
|
||||||
|
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||||
|
|
||||||
|
jwt_handler = JWTHandler()
|
||||||
|
|
||||||
|
jwt_handler.user_api_key_cache = cache
|
||||||
|
|
||||||
|
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")
|
||||||
|
|
||||||
|
# VALID TOKEN
|
||||||
|
## GENERATE A TOKEN
|
||||||
|
# Assuming the current time is in UTC
|
||||||
|
expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
|
||||||
|
|
||||||
|
# Generate the JWT token
|
||||||
|
# But before, you should convert bytes to string
|
||||||
|
private_key_str = private_key.decode("utf-8")
|
||||||
|
|
||||||
|
## admin token
|
||||||
|
payload = {
|
||||||
|
"sub": "user123",
|
||||||
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
|
"scope": "litellm_proxy_admin",
|
||||||
|
"aud": audience,
|
||||||
|
}
|
||||||
|
|
||||||
|
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
|
||||||
|
# verify token
|
||||||
|
|
||||||
|
response = await jwt_handler.auth_jwt(token=admin_token)
|
||||||
|
|
||||||
|
## RUN IT THROUGH USER API KEY AUTH
|
||||||
|
|
||||||
|
"""
|
||||||
|
- 1. Initial call should fail -> team doesn't exist
|
||||||
|
- 2. Create team via admin token
|
||||||
|
- 3. 2nd call w/ same team -> call should succeed -> assert UserAPIKeyAuth object correctly formatted
|
||||||
|
"""
|
||||||
|
|
||||||
|
bearer_token = "Bearer " + admin_token
|
||||||
|
|
||||||
|
pseudo_routes = jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||||
|
|
||||||
|
actual_routes = []
|
||||||
|
for route in pseudo_routes:
|
||||||
|
if route in LiteLLMRoutes.__members__:
|
||||||
|
actual_routes.extend(LiteLLMRoutes[route].value)
|
||||||
|
|
||||||
|
for route in actual_routes:
|
||||||
|
request = Request(scope={"type": "http"})
|
||||||
|
|
||||||
|
request._url = URL(url=route)
|
||||||
|
|
||||||
|
## 1. INITIAL TEAM CALL - should fail
|
||||||
|
# use generated key to auth in
|
||||||
|
setattr(
|
||||||
|
litellm.proxy.proxy_server,
|
||||||
|
"general_settings",
|
||||||
|
{
|
||||||
|
"enable_jwt_auth": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
|
||||||
|
try:
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue