diff --git a/litellm/proxy/auth/oauth2_proxy_hook.py b/litellm/proxy/auth/oauth2_proxy_hook.py new file mode 100644 index 0000000000..a1db5d842c --- /dev/null +++ b/litellm/proxy/auth/oauth2_proxy_hook.py @@ -0,0 +1,45 @@ +from typing import Any, Dict + +from fastapi import Request + +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import UserAPIKeyAuth + + +async def handle_oauth2_proxy_request(request: Request) -> UserAPIKeyAuth: + """ + Handle request from oauth2 proxy. + """ + from litellm.proxy.proxy_server import general_settings + + verbose_proxy_logger.debug("Handling oauth2 proxy request") + # Define the OAuth2 config mappings + oauth2_config_mappings: Dict[str, str] = general_settings.get( + "oauth2_config_mappings", None + ) + verbose_proxy_logger.debug(f"Oauth2 config mappings: {oauth2_config_mappings}") + + if not oauth2_config_mappings: + raise ValueError("Oauth2 config mappings not found in general_settings") + # Initialize a dictionary to store the mapped values + auth_data: Dict[str, Any] = {} + + # Extract values from headers based on the mappings + for key, header in oauth2_config_mappings.items(): + value = request.headers.get(header) + if value: + # Convert max_budget to float if present + if key == "max_budget": + auth_data[key] = float(value) + # Convert models to list if present + elif key == "models": + auth_data[key] = [model.strip() for model in value.split(",")] + else: + auth_data[key] = value + verbose_proxy_logger.debug( + f"Auth data before creating UserAPIKeyAuth object: {auth_data}" + ) + user_api_key_auth = UserAPIKeyAuth(**auth_data) + verbose_proxy_logger.debug(f"UserAPIKeyAuth object created: {user_api_key_auth}") + # Create and return UserAPIKeyAuth object + return user_api_key_auth diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 7b1eb56786..433480bc4e 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -64,6 +64,7 @@ from litellm.proxy.auth.auth_utils import ( route_in_additonal_public_routes, ) from litellm.proxy.auth.oauth2_check import check_oauth2_token +from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request from litellm.proxy.common_utils.http_parsing_utils import _read_request_body from litellm.proxy.utils import _to_ns @@ -217,6 +218,9 @@ async def user_api_key_auth( return await check_oauth2_token(token=api_key) + if general_settings.get("enable_oauth2_proxy_auth", False) is True: + return await handle_oauth2_proxy_request(request=request) + if general_settings.get("enable_jwt_auth", False) is True: is_jwt = jwt_handler.is_jwt(token=api_key) verbose_proxy_logger.debug("is_jwt: %s", is_jwt) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 536f6e2e57..ac17a1d84c 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -13,6 +13,15 @@ model_list: model: cohere/rerank-english-v3.0 api_key: os.environ/COHERE_API_KEY +general_settings: + enable_oauth2_proxy_auth: True + oauth2_config_mappings: + token: X-Auth-Token + user_id: X-Auth-Client-ID + team_id: X-Auth-Team-ID + max_budget: X-Auth-Max-Budget + models: X-Auth-Allowed-Models + # default off mode litellm_settings: set_verbose: True \ No newline at end of file