diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index dbb8979f9..e3de37881 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -602,7 +602,8 @@ general_settings: "completion_model": "string", "disable_spend_logs": "boolean", # turn off writing each transaction to the db "disable_reset_budget": "boolean", # turn off reset budget scheduled task - "enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims + "enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims + "allowed_routes": "list", # list of allowed proxy API routes - a user can access. (currently JWT-Auth only) "key_management_system": "google_kms", # either google_kms or azure_kms "master_key": "string", "database_url": "string", diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index b5c50b143..d23049056 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -531,6 +531,9 @@ class ConfigGeneralSettings(LiteLLMBase): ui_access_mode: Optional[Literal["admin_only", "all"]] = Field( "all", description="Control access to the Proxy UI" ) + allowed_routes: Optional[List] = Field( + None, description="Proxy API Endpoints you want users to be able to access" + ) class ConfigYAML(LiteLLMBase): diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index cd326cc6d..1c16381ad 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -9,7 +9,7 @@ Run checks for: 3. 
def _route_matches(route: str, pattern: object) -> bool:
    """Return True if the concrete request path *route* matches *pattern*.

    *pattern* is either a literal path (compared for exact equality) or a
    path template whose ``{name}`` segments each stand for exactly one
    non-empty path segment, e.g. ``/openai/deployments/{model}/embeddings``.
    Non-string patterns never match.
    """
    if not isinstance(pattern, str):
        return False
    if route == pattern:
        return True
    if "{" not in pattern:
        return False

    route_parts = route.split("/")
    pattern_parts = pattern.split("/")
    if len(route_parts) != len(pattern_parts):
        return False

    for expected, actual in zip(pattern_parts, route_parts):
        is_placeholder = (
            len(expected) > 2
            and expected.startswith("{")
            and expected.endswith("}")
        )
        if is_placeholder:
            # A placeholder must be filled with a real (non-empty) segment.
            if not actual:
                return False
        elif expected != actual:
            return False
    return True


def allowed_routes_check(
    user_role: Literal["proxy_admin", "app_owner"],
    route: str,
    allowed_routes: Optional[list] = None,
) -> bool:
    """
    Check whether a user with the given role may access the given route.

    - "proxy_admin" may access every route.
    - "app_owner" may access the routes in `allowed_routes`; when
      `allowed_routes` is None, the built-in default routes (OpenAI-compatible
      endpoints + info endpoints) are used instead. An empty list allows
      nothing.
    - Any other role is denied.

    Routes may be literal paths or templates such as
    "/openai/deployments/{model}/chat/completions"; a "{name}" segment matches
    any single non-empty path segment. This matters because the caller passes
    the concrete request path, which never contains the literal "{model}".

    Args:
        user_role: Role of the authenticated user.
        route: Concrete request path being accessed.
        allowed_routes: Optional admin-configured list of permitted routes.

    Returns:
        True if access is allowed, False otherwise.
    """
    openai_routes = [
        # chat completions
        "/openai/deployments/{model}/chat/completions",
        "/chat/completions",
        "/v1/chat/completions",
        # completions
        # embeddings
        "/openai/deployments/{model}/embeddings",
        "/embeddings",
        "/v1/embeddings",
        # image generation
        "/images/generations",
        "/v1/images/generations",
        # audio transcription
        "/audio/transcriptions",
        "/v1/audio/transcriptions",
        # moderations
        "/moderations",
        "/v1/moderations",
        # models
        "/models",
        "/v1/models",
    ]
    info_routes = ["/key/info", "/team/info", "/user/info", "/model/info"]
    default_routes = openai_routes + info_routes

    if user_role == "proxy_admin":
        return True
    if user_role != "app_owner":
        return False

    # An explicit list (even an empty one) fully replaces the defaults.
    candidates = default_routes if allowed_routes is None else allowed_routes
    return any(_route_matches(route, pattern) for pattern in candidates)
a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 738fd2d63..3b15d6ba1 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -110,7 +110,11 @@ from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.hooks.prompt_injection_detection import ( _OPTIONAL_PromptInjectionDetection, ) -from litellm.proxy.auth.auth_checks import common_checks, get_end_user_object +from litellm.proxy.auth.auth_checks import ( + common_checks, + get_end_user_object, + allowed_routes_check, +) try: from litellm._version import version @@ -332,7 +336,7 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict: async def user_api_key_auth( request: Request, api_key: str = fastapi.Security(api_key_header) ) -> UserAPIKeyAuth: - global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client + global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client, general_settings try: if isinstance(api_key, str): passed_in_key = api_key @@ -354,6 +358,7 @@ async def user_api_key_auth( enable_jwt_auth: true ``` """ + route: str = request.url.path if general_settings.get("enable_jwt_auth", False) == True: is_jwt = jwt_handler.is_jwt(token=api_key) verbose_proxy_logger.debug(f"is_jwt: {is_jwt}") @@ -407,15 +412,28 @@ async def user_api_key_auth( user_id=user_id, ) else: - # return UserAPIKeyAuth object - return UserAPIKeyAuth( - api_key=None, - user_id=user_object.user_id, - tpm_limit=user_object.tpm_limit, - rpm_limit=user_object.rpm_limit, - models=user_object.models, + is_allowed = allowed_routes_check( user_role="app_owner", + route=route, + allowed_routes=general_settings.get("allowed_routes", None), ) + if is_allowed: + # return UserAPIKeyAuth object + return UserAPIKeyAuth( + api_key=None, + user_id=user_object.user_id, + tpm_limit=user_object.tpm_limit, + rpm_limit=user_object.rpm_limit, + models=user_object.models, + user_role="app_owner", + ) + else: + raise HTTPException( + 
status_code=401, + detail={ + "error": f"User={user_object.user_id} not allowed to access this route={route}." + }, + ) #### ELSE #### if master_key is None: if isinstance(api_key, str): @@ -423,7 +441,6 @@ async def user_api_key_auth( else: return UserAPIKeyAuth() - route: str = request.url.path if route == "/user/auth": if general_settings.get("allow_user_auth", False) == True: return UserAPIKeyAuth()