fix(proxy_server.py): enable jwt-auth for users

allow a user to auth into the proxy via jwt's and call allowed routes
This commit is contained in:
Krrish Dholakia 2024-03-22 17:08:10 -07:00
parent 9bf086386e
commit d06b9a5a47
4 changed files with 86 additions and 12 deletions

View file

@ -602,7 +602,8 @@ general_settings:
"completion_model": "string", "completion_model": "string",
"disable_spend_logs": "boolean", # turn off writing each transaction to the db "disable_spend_logs": "boolean", # turn off writing each transaction to the db
"disable_reset_budget": "boolean", # turn off reset budget scheduled task "disable_reset_budget": "boolean", # turn off reset budget scheduled task
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims "enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
"allowed_routes": "list", # list of allowed proxy API routes - a user can access. (currently JWT-Auth only)
"key_management_system": "google_kms", # either google_kms or azure_kms "key_management_system": "google_kms", # either google_kms or azure_kms
"master_key": "string", "master_key": "string",
"database_url": "string", "database_url": "string",

View file

@ -531,6 +531,9 @@ class ConfigGeneralSettings(LiteLLMBase):
ui_access_mode: Optional[Literal["admin_only", "all"]] = Field( ui_access_mode: Optional[Literal["admin_only", "all"]] = Field(
"all", description="Control access to the Proxy UI" "all", description="Control access to the Proxy UI"
) )
allowed_routes: Optional[List] = Field(
None, description="Proxy API Endpoints you want users to be able to access"
)
class ConfigYAML(LiteLLMBase): class ConfigYAML(LiteLLMBase):

View file

@ -9,7 +9,7 @@ Run checks for:
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
""" """
from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable
from typing import Optional from typing import Optional, Literal
from litellm.proxy.utils import PrismaClient from litellm.proxy.utils import PrismaClient
from litellm.caching import DualCache from litellm.caching import DualCache
@ -19,6 +19,13 @@ def common_checks(
user_object: LiteLLM_UserTable, user_object: LiteLLM_UserTable,
end_user_object: Optional[LiteLLM_EndUserTable], end_user_object: Optional[LiteLLM_EndUserTable],
) -> bool: ) -> bool:
"""
Common checks across jwt + key-based auth.
1. If user can call model
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
_model = request_body.get("model", None) _model = request_body.get("model", None)
# 1. If user can call model # 1. If user can call model
if ( if (
@ -47,6 +54,52 @@ def common_checks(
return True return True
def allowed_routes_check(
user_role: Literal["proxy_admin", "app_owner"],
route: str,
allowed_routes: Optional[list] = None,
) -> bool:
"""
Check if user -> not admin - allowed to access these routes
"""
openai_routes = [
# chat completions
"/openai/deployments/{model}/chat/completions",
"/chat/completions",
"/v1/chat/completions",
# completions
# embeddings
"/openai/deployments/{model}/embeddings",
"/embeddings",
"/v1/embeddings",
# image generation
"/images/generations",
"/v1/images/generations",
# audio transcription
"/audio/transcriptions",
"/v1/audio/transcriptions",
# moderations
"/moderations",
"/v1/moderations",
# models
"/models",
"/v1/models",
]
info_routes = ["/key/info", "/team/info", "/user/info", "/model/info"]
default_routes = openai_routes + info_routes
if user_role == "proxy_admin":
return True
elif user_role == "app_owner":
if allowed_routes is None:
if route in default_routes: # check default routes
return True
elif route in allowed_routes:
return True
else:
return False
return False
async def get_end_user_object( async def get_end_user_object(
end_user_id: Optional[str], end_user_id: Optional[str],
prisma_client: Optional[PrismaClient], prisma_client: Optional[PrismaClient],

View file

@ -110,7 +110,11 @@ from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.hooks.prompt_injection_detection import ( from litellm.proxy.hooks.prompt_injection_detection import (
_OPTIONAL_PromptInjectionDetection, _OPTIONAL_PromptInjectionDetection,
) )
from litellm.proxy.auth.auth_checks import common_checks, get_end_user_object from litellm.proxy.auth.auth_checks import (
common_checks,
get_end_user_object,
allowed_routes_check,
)
try: try:
from litellm._version import version from litellm._version import version
@ -332,7 +336,7 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
async def user_api_key_auth( async def user_api_key_auth(
request: Request, api_key: str = fastapi.Security(api_key_header) request: Request, api_key: str = fastapi.Security(api_key_header)
) -> UserAPIKeyAuth: ) -> UserAPIKeyAuth:
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client, general_settings
try: try:
if isinstance(api_key, str): if isinstance(api_key, str):
passed_in_key = api_key passed_in_key = api_key
@ -354,6 +358,7 @@ async def user_api_key_auth(
enable_jwt_auth: true enable_jwt_auth: true
``` ```
""" """
route: str = request.url.path
if general_settings.get("enable_jwt_auth", False) == True: if general_settings.get("enable_jwt_auth", False) == True:
is_jwt = jwt_handler.is_jwt(token=api_key) is_jwt = jwt_handler.is_jwt(token=api_key)
verbose_proxy_logger.debug(f"is_jwt: {is_jwt}") verbose_proxy_logger.debug(f"is_jwt: {is_jwt}")
@ -407,15 +412,28 @@ async def user_api_key_auth(
user_id=user_id, user_id=user_id,
) )
else: else:
# return UserAPIKeyAuth object is_allowed = allowed_routes_check(
return UserAPIKeyAuth(
api_key=None,
user_id=user_object.user_id,
tpm_limit=user_object.tpm_limit,
rpm_limit=user_object.rpm_limit,
models=user_object.models,
user_role="app_owner", user_role="app_owner",
route=route,
allowed_routes=general_settings.get("allowed_routes", None),
) )
if is_allowed:
# return UserAPIKeyAuth object
return UserAPIKeyAuth(
api_key=None,
user_id=user_object.user_id,
tpm_limit=user_object.tpm_limit,
rpm_limit=user_object.rpm_limit,
models=user_object.models,
user_role="app_owner",
)
else:
raise HTTPException(
status_code=401,
detail={
"error": f"User={user_object.user_id} not allowed to access this route={route}."
},
)
#### ELSE #### #### ELSE ####
if master_key is None: if master_key is None:
if isinstance(api_key, str): if isinstance(api_key, str):
@ -423,7 +441,6 @@ async def user_api_key_auth(
else: else:
return UserAPIKeyAuth() return UserAPIKeyAuth()
route: str = request.url.path
if route == "/user/auth": if route == "/user/auth":
if general_settings.get("allow_user_auth", False) == True: if general_settings.get("allow_user_auth", False) == True:
return UserAPIKeyAuth() return UserAPIKeyAuth()