forked from phoenix/litellm-mirror
fix(handle_jwt.py): support public key caching ttl param
This commit is contained in:
parent
d90f44fe8e
commit
752516df1b
6 changed files with 26 additions and 25 deletions
|
@ -124,7 +124,7 @@ general_settings:
|
||||||
### Allowed LiteLLM scopes
|
### Allowed LiteLLM scopes
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class LiteLLMProxyRoles(LiteLLMBase):
|
class LiteLLM_JWTAuth(LiteLLMBase):
|
||||||
proxy_admin: str = "litellm_proxy_admin"
|
proxy_admin: str = "litellm_proxy_admin"
|
||||||
proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth.
|
proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth.
|
||||||
```
|
```
|
||||||
|
|
|
@ -90,7 +90,7 @@ class LiteLLMRoutes(enum.Enum):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProxyRoles(LiteLLMBase):
|
class LiteLLM_JWTAuth(LiteLLMBase):
|
||||||
"""
|
"""
|
||||||
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
|
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
|
||||||
|
|
||||||
|
@ -115,6 +115,7 @@ class LiteLLMProxyRoles(LiteLLMBase):
|
||||||
Literal["openai_routes", "info_routes", "management_routes"]
|
Literal["openai_routes", "info_routes", "management_routes"]
|
||||||
] = ["openai_routes", "info_routes"]
|
] = ["openai_routes", "info_routes"]
|
||||||
end_user_id_jwt_field: Optional[str] = "sub"
|
end_user_id_jwt_field: Optional[str] = "sub"
|
||||||
|
public_key_ttl: float = 600
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
# get the attribute names for this Pydantic model
|
# get the attribute names for this Pydantic model
|
||||||
|
|
|
@ -11,7 +11,7 @@ Run checks for:
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
LiteLLM_UserTable,
|
LiteLLM_UserTable,
|
||||||
LiteLLM_EndUserTable,
|
LiteLLM_EndUserTable,
|
||||||
LiteLLMProxyRoles,
|
LiteLLM_JWTAuth,
|
||||||
LiteLLM_TeamTable,
|
LiteLLM_TeamTable,
|
||||||
LiteLLMRoutes,
|
LiteLLMRoutes,
|
||||||
)
|
)
|
||||||
|
@ -88,7 +88,7 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
||||||
def allowed_routes_check(
|
def allowed_routes_check(
|
||||||
user_role: Literal["proxy_admin", "team"],
|
user_role: Literal["proxy_admin", "team"],
|
||||||
user_route: str,
|
user_route: str,
|
||||||
litellm_proxy_roles: LiteLLMProxyRoles,
|
litellm_proxy_roles: LiteLLM_JWTAuth,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if user -> not admin - allowed to access these routes
|
Check if user -> not admin - allowed to access these routes
|
||||||
|
|
|
@ -12,7 +12,7 @@ import json
|
||||||
import os
|
import os
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable
|
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -70,30 +70,30 @@ class JWTHandler:
|
||||||
self,
|
self,
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
user_api_key_cache: DualCache,
|
user_api_key_cache: DualCache,
|
||||||
litellm_proxy_roles: LiteLLMProxyRoles,
|
litellm_jwtauth: LiteLLM_JWTAuth,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.prisma_client = prisma_client
|
self.prisma_client = prisma_client
|
||||||
self.user_api_key_cache = user_api_key_cache
|
self.user_api_key_cache = user_api_key_cache
|
||||||
self.litellm_proxy_roles = litellm_proxy_roles
|
self.litellm_jwtauth = litellm_jwtauth
|
||||||
|
|
||||||
def is_jwt(self, token: str):
|
def is_jwt(self, token: str):
|
||||||
parts = token.split(".")
|
parts = token.split(".")
|
||||||
return len(parts) == 3
|
return len(parts) == 3
|
||||||
|
|
||||||
def is_admin(self, scopes: list) -> bool:
|
def is_admin(self, scopes: list) -> bool:
|
||||||
if self.litellm_proxy_roles.admin_jwt_scope in scopes:
|
if self.litellm_jwtauth.admin_jwt_scope in scopes:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def is_team(self, scopes: list) -> bool:
|
def is_team(self, scopes: list) -> bool:
|
||||||
if self.litellm_proxy_roles.team_jwt_scope in scopes:
|
if self.litellm_jwtauth.team_jwt_scope in scopes:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
|
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
|
||||||
try:
|
try:
|
||||||
if self.litellm_proxy_roles.team_id_jwt_field is not None:
|
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
|
||||||
user_id = token[self.litellm_proxy_roles.team_id_jwt_field]
|
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
|
||||||
else:
|
else:
|
||||||
user_id = None
|
user_id = None
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -102,7 +102,7 @@ class JWTHandler:
|
||||||
|
|
||||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
team_id = token[self.litellm_proxy_roles.team_id_jwt_field]
|
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
team_id = default_value
|
team_id = default_value
|
||||||
return team_id
|
return team_id
|
||||||
|
@ -137,7 +137,9 @@ class JWTHandler:
|
||||||
keys = response.json()["keys"]
|
keys = response.json()["keys"]
|
||||||
|
|
||||||
await self.user_api_key_cache.async_set_cache(
|
await self.user_api_key_cache.async_set_cache(
|
||||||
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins
|
key="litellm_jwt_auth_keys",
|
||||||
|
value=keys,
|
||||||
|
ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
keys = cached_keys
|
keys = cached_keys
|
||||||
|
|
|
@ -378,13 +378,13 @@ async def user_api_key_auth(
|
||||||
is_allowed = allowed_routes_check(
|
is_allowed = allowed_routes_check(
|
||||||
user_role="proxy_admin",
|
user_role="proxy_admin",
|
||||||
user_route=route,
|
user_route=route,
|
||||||
litellm_proxy_roles=jwt_handler.litellm_proxy_roles,
|
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||||
)
|
)
|
||||||
if is_allowed:
|
if is_allowed:
|
||||||
return UserAPIKeyAuth()
|
return UserAPIKeyAuth()
|
||||||
else:
|
else:
|
||||||
allowed_routes = (
|
allowed_routes = (
|
||||||
jwt_handler.litellm_proxy_roles.admin_allowed_routes
|
jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||||
)
|
)
|
||||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -394,23 +394,23 @@ async def user_api_key_auth(
|
||||||
is_team = jwt_handler.is_team(scopes=scopes)
|
is_team = jwt_handler.is_team(scopes=scopes)
|
||||||
if is_team == False:
|
if is_team == False:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Missing both Admin and Team scopes from token. Either is required. Admin Scope={jwt_handler.litellm_proxy_roles.admin_jwt_scope}, Team Scope={jwt_handler.litellm_proxy_roles.team_jwt_scope}"
|
f"Missing both Admin and Team scopes from token. Either is required. Admin Scope={jwt_handler.litellm_jwtauth.admin_jwt_scope}, Team Scope={jwt_handler.litellm_jwtauth.team_jwt_scope}"
|
||||||
)
|
)
|
||||||
# get team id
|
# get team id
|
||||||
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
|
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
|
||||||
|
|
||||||
if team_id is None:
|
if team_id is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_proxy_roles.team_id_jwt_field}'"
|
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||||
)
|
)
|
||||||
# check allowed team routes
|
# check allowed team routes
|
||||||
is_allowed = allowed_routes_check(
|
is_allowed = allowed_routes_check(
|
||||||
user_role="team",
|
user_role="team",
|
||||||
user_route=route,
|
user_route=route,
|
||||||
litellm_proxy_roles=jwt_handler.litellm_proxy_roles,
|
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||||
)
|
)
|
||||||
if is_allowed == False:
|
if is_allowed == False:
|
||||||
allowed_routes = jwt_handler.litellm_proxy_roles.team_allowed_routes
|
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes
|
||||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||||
|
@ -2741,11 +2741,9 @@ async def startup_event():
|
||||||
|
|
||||||
## JWT AUTH ##
|
## JWT AUTH ##
|
||||||
if general_settings.get("litellm_proxy_roles", None) is not None:
|
if general_settings.get("litellm_proxy_roles", None) is not None:
|
||||||
litellm_proxy_roles = LiteLLMProxyRoles(
|
litellm_proxy_roles = LiteLLM_JWTAuth(**general_settings["litellm_proxy_roles"])
|
||||||
**general_settings["litellm_proxy_roles"]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
litellm_proxy_roles = LiteLLMProxyRoles()
|
litellm_proxy_roles = LiteLLM_JWTAuth()
|
||||||
jwt_handler.update_environment(
|
jwt_handler.update_environment(
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
|
|
@ -12,7 +12,7 @@ sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest
|
import pytest
|
||||||
from litellm.proxy._types import LiteLLMProxyRoles
|
from litellm.proxy._types import LiteLLM_JWTAuth
|
||||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
@ -32,7 +32,7 @@ def test_load_config_with_custom_role_names():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy_roles = LiteLLMProxyRoles(
|
proxy_roles = LiteLLM_JWTAuth(
|
||||||
**config.get("general_settings", {}).get("litellm_proxy_roles", {})
|
**config.get("general_settings", {}).get("litellm_proxy_roles", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue