Merge pull request #2704 from BerriAI/litellm_jwt_auth_improvements_3

fix(handle_jwt.py): enable team-based jwt-auth access
This commit is contained in:
Krish Dholakia 2024-03-26 16:06:56 -07:00 committed by GitHub
commit 0ab708e6f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 358 additions and 145 deletions

View file

@ -124,7 +124,7 @@ general_settings:
### Allowed LiteLLM scopes
```python
class LiteLLM_JWTAuth(LiteLLMBase):
proxy_admin: str = "litellm_proxy_admin"
proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth.
```

View file

@ -1,4 +1,5 @@
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
from dataclasses import fields
import enum
from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime
@ -37,9 +38,97 @@ class LiteLLMBase(BaseModel):
protected_namespaces = ()
class LiteLLMProxyRoles(LiteLLMBase):
proxy_admin: str = "litellm_proxy_admin"
proxy_user: str = "litellm_user"
class LiteLLMRoutes(enum.Enum):
    """
    Named groups of proxy route paths, used by JWT auth to grant access per role.

    Each member's value is the list of route paths in that group. `allowed_routes`
    settings (see `LiteLLM_JWTAuth`) reference these groups by member name,
    e.g. "openai_routes".
    """

    openai_routes: List = [  # chat completions
        "/openai/deployments/{model}/chat/completions",
        "/chat/completions",
        "/v1/chat/completions",
        # completions
        "/openai/deployments/{model}/completions",
        "/completions",
        "/v1/completions",
        # embeddings
        "/openai/deployments/{model}/embeddings",
        "/embeddings",
        "/v1/embeddings",
        # image generation
        "/images/generations",
        "/v1/images/generations",
        # audio transcription
        "/audio/transcriptions",
        "/v1/audio/transcriptions",
        # moderations
        "/moderations",
        "/v1/moderations",
        # models
        "/models",
        "/v1/models",
    ]

    # read-only informational endpoints
    info_routes: List = ["/key/info", "/team/info", "/user/info", "/model/info"]

    # endpoints that create/update/delete proxy state (admin-level)
    management_routes: List = [  # key
        "/key/generate",
        "/key/update",
        "/key/delete",
        "/key/info",
        # user
        "/user/new",
        "/user/update",
        "/user/delete",
        "/user/info",
        # team
        "/team/new",
        "/team/update",
        "/team/delete",
        "/team/info",
        # model
        "/model/new",
        "/model/update",
        "/model/delete",
        "/model/info",
    ]
class LiteLLM_JWTAuth(LiteLLMBase):
    """
    A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.

    Attributes:
    - admin_jwt_scope: The JWT scope required for proxy admin roles.
    - admin_allowed_routes: list of allowed routes for proxy admin roles.
    - team_jwt_scope: The JWT scope required for proxy team roles.
    - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
    - team_allowed_routes: list of allowed routes for proxy team roles.
    - end_user_id_jwt_field: Default - `sub`. The field in the JWT token that stores the end-user ID. Turn this off by setting to `None`. Enables end-user cost tracking.
    - public_key_ttl: seconds to cache the JWKS public keys.

    See `auth_checks.py` for the specific routes
    """

    admin_jwt_scope: str = "litellm_proxy_admin"
    admin_allowed_routes: List[
        Literal["openai_routes", "info_routes", "management_routes"]
    ] = ["management_routes"]
    team_jwt_scope: str = "litellm_team"
    team_id_jwt_field: str = "client_id"
    team_allowed_routes: List[
        Literal["openai_routes", "info_routes", "management_routes"]
    ] = ["openai_routes", "info_routes"]
    end_user_id_jwt_field: Optional[str] = "sub"
    public_key_ttl: float = 600

    def __init__(self, **kwargs: Any) -> None:
        # Reject unknown config keys up-front with a helpful message, instead of
        # relying on pydantic's default extra-field behavior.
        permitted = self.__annotations__.keys()
        unknown = set(kwargs) - permitted
        if unknown:
            raise ValueError(
                f"Invalid arguments provided: {', '.join(unknown)}. Allowed arguments are: {', '.join(permitted)}."
            )
        super().__init__(**kwargs)
class LiteLLMPromptInjectionParams(LiteLLMBase):

View file

@ -8,15 +8,23 @@ Run checks for:
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable
from litellm.proxy._types import (
LiteLLM_UserTable,
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
LiteLLM_TeamTable,
LiteLLMRoutes,
)
from typing import Optional, Literal
from litellm.proxy.utils import PrismaClient
from litellm.caching import DualCache
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
def common_checks(
request_body: dict,
user_object: LiteLLM_UserTable,
team_object: LiteLLM_TeamTable,
end_user_object: Optional[LiteLLM_EndUserTable],
) -> bool:
"""
@ -30,19 +38,20 @@ def common_checks(
# 1. If user can call model
if (
_model is not None
and len(user_object.models) > 0
and _model not in user_object.models
and len(team_object.models) > 0
and _model not in team_object.models
):
raise Exception(
f"User={user_object.user_id} not allowed to call model={_model}. Allowed user models = {user_object.models}"
f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
)
# 2. If user is in budget
# 2. If team is in budget
if (
user_object.max_budget is not None
and user_object.spend > user_object.max_budget
team_object.max_budget is not None
and team_object.spend is not None
and team_object.spend > team_object.max_budget
):
raise Exception(
f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}"
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
)
# 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
@ -54,52 +63,79 @@ def common_checks(
return True
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
    """
    Return True if `user_route` is permitted by `allowed_routes`.

    Each entry of `allowed_routes` is either:
    - the name of a `LiteLLMRoutes` group (e.g. "openai_routes"), which matches
      any route inside that group, or
    - a literal route string, which must equal `user_route` exactly.
    """
    for allowed_route in allowed_routes:
        # Group name -> match any route inside the group. Looking the name up in
        # LiteLLMRoutes.__members__ replaces the previous copy-pasted branch per
        # group, so newly added groups work without touching this function.
        if (
            allowed_route in LiteLLMRoutes.__members__
            and user_route in LiteLLMRoutes[allowed_route].value
        ):
            return True
        elif allowed_route == user_route:
            # literal route match
            return True
    return False
def allowed_routes_check(
    user_role: Literal["proxy_admin", "team"],
    user_route: str,
    litellm_proxy_roles: LiteLLM_JWTAuth,
) -> bool:
    """
    Check if a JWT-authenticated caller is allowed to access the requested route.

    Args:
        user_role: "proxy_admin" or "team" (derived from the token's scopes).
        user_route: the route being requested.
        litellm_proxy_roles: JWT-auth config holding the per-role allowed routes.

    Returns:
        True if the role's configured (or default) allowed routes permit
        `user_route`; False for any unrecognized role.
    """
    if user_role == "proxy_admin":
        if litellm_proxy_roles.admin_allowed_routes is None:
            # by default, admins are only granted the management routes
            is_allowed = _allowed_routes_check(
                user_route=user_route, allowed_routes=["management_routes"]
            )
            return is_allowed
        else:
            is_allowed = _allowed_routes_check(
                user_route=user_route,
                allowed_routes=litellm_proxy_roles.admin_allowed_routes,
            )
            return is_allowed
    elif user_role == "team":
        if litellm_proxy_roles.team_allowed_routes is None:
            # by default, allow a team to call openai + info routes
            is_allowed = _allowed_routes_check(
                user_route=user_route, allowed_routes=["openai_routes", "info_routes"]
            )
            return is_allowed
        else:
            is_allowed = _allowed_routes_check(
                user_route=user_route,
                allowed_routes=litellm_proxy_roles.team_allowed_routes,
            )
            return is_allowed
    return False
def get_actual_routes(allowed_routes: list) -> list:
    """
    Expand route-group names (e.g. "openai_routes") into their concrete routes.

    Entries that are not `LiteLLMRoutes` group names are kept as-is — they are
    assumed to be literal route strings. Used to build readable error messages.
    """
    actual_routes: list = []
    for route_name in allowed_routes:
        try:
            # group name -> append every route in that group
            actual_routes.extend(LiteLLMRoutes[route_name].value)
        except KeyError:
            # literal route string
            actual_routes.append(route_name)
    return actual_routes
async def get_end_user_object(
end_user_id: Optional[str],
prisma_client: Optional[PrismaClient],
@ -135,3 +171,75 @@ async def get_end_user_object(
return LiteLLM_EndUserTable(**response.dict())
except Exception as e: # if end-user not in db
return None
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
    """
    Look up a user in the proxy User Table.

    - Check the in-memory cache first, then fall back to the DB.
    - if valid, return LiteLLM_UserTable object with defined limits
    - if not, then raise an error

    Raises:
        Exception: if no DB is connected, or the user id is not found.
    """
    if self.prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )
    # check if in cache. DualCache.async_get_cache is a coroutine and must be
    # awaited — without `await` the isinstance checks below never match and
    # every call falls through to the DB.
    cached_user_obj = await self.user_api_key_cache.async_get_cache(key=user_id)
    if cached_user_obj is not None:
        if isinstance(cached_user_obj, dict):
            return LiteLLM_UserTable(**cached_user_obj)
        elif isinstance(cached_user_obj, LiteLLM_UserTable):
            return cached_user_obj
    # else, check db
    try:
        response = await self.prisma_client.db.litellm_usertable.find_unique(
            where={"user_id": user_id}
        )
        if response is None:
            raise Exception
        return LiteLLM_UserTable(**response.dict())
    except Exception:
        # normalize any lookup failure into the generic error callers expect
        raise Exception(
            f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
        )
async def get_team_object(
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
) -> LiteLLM_TeamTable:
    """
    Look up a team in the proxy Team Table.

    - Check the in-memory cache first, then fall back to the DB.
    - if valid, return LiteLLM_TeamTable object with defined limits
    - if not, then raise an error

    Raises:
        Exception: if no DB is connected, or the team id is not found.
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )
    # check if in cache. DualCache.async_get_cache is a coroutine and must be
    # awaited — without `await` the isinstance checks below never match and
    # every call falls through to the DB.
    cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
    if cached_team_obj is not None:
        if isinstance(cached_team_obj, dict):
            return LiteLLM_TeamTable(**cached_team_obj)
        elif isinstance(cached_team_obj, LiteLLM_TeamTable):
            return cached_team_obj
    # else, check db
    try:
        response = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id}
        )
        if response is None:
            raise Exception
        return LiteLLM_TeamTable(**response.dict())
    except Exception:
        # normalize any lookup failure into the generic error callers expect
        raise Exception(
            f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
        )

View file

@ -12,7 +12,7 @@ import json
import os
from litellm.caching import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
from litellm.proxy.utils import PrismaClient
from typing import Optional
@ -70,68 +70,43 @@ class JWTHandler:
self,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
litellm_proxy_roles: LiteLLMProxyRoles,
litellm_jwtauth: LiteLLM_JWTAuth,
) -> None:
self.prisma_client = prisma_client
self.user_api_key_cache = user_api_key_cache
self.litellm_proxy_roles = litellm_proxy_roles
self.litellm_jwtauth = litellm_jwtauth
def is_jwt(self, token: str):
    """A JWT is exactly three dot-separated segments (header.payload.signature)."""
    return token.count(".") == 2
def is_admin(self, scopes: list) -> bool:
    """Return True if the configured admin JWT scope appears in `scopes`."""
    # only the current config attribute is consulted; the stale
    # `litellm_proxy_roles.proxy_admin` check was removed with that attribute
    if self.litellm_jwtauth.admin_jwt_scope in scopes:
        return True
    return False
def get_user_id(self, token: dict, default_value: str) -> str:
def is_team(self, scopes: list) -> bool:
    """Return True if the configured team JWT scope appears in `scopes`."""
    return self.litellm_jwtauth.team_jwt_scope in scopes
def get_end_user_id(
    self, token: dict, default_value: Optional[str]
) -> Optional[str]:
    """
    Extract the end-user id from the JWT payload.

    Reads the field named by `litellm_jwtauth.end_user_id_jwt_field`
    (default `sub`). Returns None when end-user tracking is disabled
    (field set to None), and `default_value` when the field is missing
    from the token. Return annotation fixed to Optional[str] to match
    those None paths.
    """
    try:
        if self.litellm_jwtauth.end_user_id_jwt_field is not None:
            user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
        else:
            user_id = None
    except KeyError:
        user_id = default_value
    return user_id
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
    """
    Extract the team id from the JWT payload.

    Reads the field named by `litellm_jwtauth.team_id_jwt_field` (default
    `client_id`); returns `default_value` when the field is absent. The stale
    hard-coded `token["client_id"]` lookup was removed in favor of the
    configurable field.
    """
    try:
        team_id = token[self.litellm_jwtauth.team_id_jwt_field]
    except KeyError:
        team_id = default_value
    return team_id
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
"""
- Check if user id in proxy User Table
- if valid, return LiteLLM_UserTable object with defined limits
- if not, then raise an error
"""
if self.prisma_client is None:
raise Exception(
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
)
# check if in cache
cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
if cached_user_obj is not None:
if isinstance(cached_user_obj, dict):
return LiteLLM_UserTable(**cached_user_obj)
elif isinstance(cached_user_obj, LiteLLM_UserTable):
return cached_user_obj
# else, check db
try:
response = await self.prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}
)
if response is None:
raise Exception
return LiteLLM_UserTable(**response.dict())
except Exception as e:
raise Exception(
f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
)
def get_scopes(self, token: dict) -> list:
try:
if isinstance(token["scope"], str):
@ -162,7 +137,9 @@ class JWTHandler:
keys = response.json()["keys"]
await self.user_api_key_cache.async_set_cache(
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins
key="litellm_jwt_auth_keys",
value=keys,
ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins
)
else:
keys = cached_keys

View file

@ -113,7 +113,10 @@ from litellm.proxy.hooks.prompt_injection_detection import (
from litellm.proxy.auth.auth_checks import (
common_checks,
get_end_user_object,
get_team_object,
get_user_object,
allowed_routes_check,
get_actual_routes,
)
try:
@ -359,71 +362,99 @@ async def user_api_key_auth(
scopes = jwt_handler.get_scopes(token=valid_token)
# check if admin
is_admin = jwt_handler.is_admin(scopes=scopes)
# get user id
user_id = jwt_handler.get_user_id(
token=valid_token, default_value=litellm_proxy_admin_name
# if admin return
if is_admin:
# check allowed admin routes
is_allowed = allowed_routes_check(
user_role="proxy_admin",
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed:
return UserAPIKeyAuth()
else:
allowed_routes = (
jwt_handler.litellm_jwtauth.admin_allowed_routes
)
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# check if team in scopes
is_team = jwt_handler.is_team(scopes=scopes)
if is_team == False:
raise Exception(
f"Missing both Admin and Team scopes from token. Either is required. Admin Scope={jwt_handler.litellm_jwtauth.admin_jwt_scope}, Team Scope={jwt_handler.litellm_jwtauth.team_jwt_scope}"
)
# get team id
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
if team_id is None:
raise Exception(
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
)
# check allowed team routes
is_allowed = allowed_routes_check(
user_role="team",
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed == False:
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# check if team in db
team_object = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
end_user_object = None
# common checks
# allow request
# get the request body
request_data = await _read_request_body(request=request)
# get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
user_object = await jwt_handler.get_user_object(user_id=user_id)
if (
request_data.get("user", None)
and request_data["user"] != user_object.user_id
):
end_user_object = None
end_user_id = jwt_handler.get_end_user_id(
token=valid_token, default_value=None
)
if end_user_id is not None:
# get the end-user object
end_user_object = await get_end_user_object(
end_user_id=request_data["user"],
end_user_id=end_user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# save the end-user object to cache
await user_api_key_cache.async_set_cache(
key=request_data["user"], value=end_user_object
key=end_user_id, value=end_user_object
)
# run through common checks
_ = common_checks(
request_body=request_data,
user_object=user_object,
team_object=team_object,
end_user_object=end_user_object,
)
# save user object in cache
await user_api_key_cache.async_set_cache(
key=user_object.user_id, value=user_object
key=team_object.team_id, value=team_object
)
# return UserAPIKeyAuth object
return UserAPIKeyAuth(
api_key=None,
team_id=team_object.team_id,
tpm_limit=team_object.tpm_limit,
rpm_limit=team_object.rpm_limit,
models=team_object.models,
user_role="app_owner",
)
# if admin return
if is_admin:
return UserAPIKeyAuth(
api_key=api_key,
user_role="proxy_admin",
user_id=user_id,
)
else:
is_allowed = allowed_routes_check(
user_role="app_owner",
route=route,
allowed_routes=general_settings.get("allowed_routes", None),
)
if is_allowed:
# return UserAPIKeyAuth object
return UserAPIKeyAuth(
api_key=None,
user_id=user_object.user_id,
tpm_limit=user_object.tpm_limit,
rpm_limit=user_object.rpm_limit,
models=user_object.models,
user_role="app_owner",
)
else:
raise HTTPException(
status_code=401,
detail={
"error": f"User={user_object.user_id} not allowed to access this route={route}."
},
)
#### ELSE ####
if master_key is None:
if isinstance(api_key, str):
@ -2698,12 +2729,14 @@ async def startup_event():
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
## JWT AUTH ##
if general_settings.get("litellm_jwtauth", None) is not None:
litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
else:
litellm_jwtauth = LiteLLM_JWTAuth()
jwt_handler.update_environment(
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
litellm_proxy_roles=LiteLLMProxyRoles(
**general_settings.get("litellm_proxy_roles", {})
),
litellm_jwtauth=litellm_jwtauth,
)
if use_background_health_checks:

View file

@ -12,7 +12,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm.proxy._types import LiteLLMProxyRoles
from litellm.proxy._types import LiteLLM_JWTAuth
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.caching import DualCache
from datetime import datetime, timedelta
@ -28,17 +28,17 @@ public_key = {
def test_load_config_with_custom_role_names():
    """Config under `litellm_proxy_roles` should map onto LiteLLM_JWTAuth fields."""
    config = {
        "general_settings": {
            "litellm_proxy_roles": {"admin_jwt_scope": "litellm-proxy-admin"}
        }
    }
    proxy_roles = LiteLLM_JWTAuth(
        **config.get("general_settings", {}).get("litellm_proxy_roles", {})
    )

    print(f"proxy_roles: {proxy_roles}")

    # custom scope name from config overrides the default "litellm_proxy_admin"
    assert proxy_roles.admin_jwt_scope == "litellm-proxy-admin"
# test_load_config_with_custom_role_names()

View file

@ -500,7 +500,10 @@ class ModelResponse(OpenAIObject):
if choices is not None and isinstance(choices, list):
new_choices = []
for choice in choices:
_new_choice = StreamingChoices(**choice)
if isinstance(choice, StreamingChoices):
_new_choice = choice
elif isinstance(choice, dict):
_new_choice = StreamingChoices(**choice)
new_choices.append(_new_choice)
choices = new_choices
else:
@ -513,7 +516,10 @@ class ModelResponse(OpenAIObject):
if choices is not None and isinstance(choices, list):
new_choices = []
for choice in choices:
_new_choice = Choices(**choice)
if isinstance(choice, Choices):
_new_choice = choice
elif isinstance(choice, dict):
_new_choice = Choices(**choice)
new_choices.append(_new_choice)
choices = new_choices
else:
@ -7231,7 +7237,7 @@ def exception_type(
exception_mapping_worked = True
raise APIError(
status_code=original_exception.status_code,
message=f"AnthropicException - {original_exception.message}",
message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.",
llm_provider="anthropic",
model=model,
request=original_exception.request,