diff --git a/litellm/proxy/auth/oauth2_check.py b/litellm/proxy/auth/oauth2_check.py index 37973e5237..92c18ad5f4 100644 --- a/litellm/proxy/auth/oauth2_check.py +++ b/litellm/proxy/auth/oauth2_check.py @@ -1,9 +1,3 @@ -import os -from typing import Literal - -import httpx - -from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client from litellm.proxy._types import UserAPIKeyAuth @@ -20,6 +14,15 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth: Raises: ValueError: If the token is invalid, the request fails, or the token info endpoint is not set. """ + import os + from typing import Literal + + import httpx + + from litellm._logging import verbose_proxy_logger + from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client + + verbose_proxy_logger.debug("Oauth2 token validation for token=%s", token) # Get the token info endpoint from environment variable token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT") user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub") @@ -30,7 +33,6 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth: raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set") client = _get_async_httpx_client() - headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} try: @@ -42,6 +44,12 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth: # If we get here, the request was successful data = response.json() + verbose_proxy_logger.debug( + "Oauth2 token validation for token=%s, response from /token/info=%s", + token, + data, + ) + # You might want to add additional checks here based on the response # For example, checking if the token is expired or has the correct scope user_id = data.get(user_id_field_name) diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 00e78f64e6..f947b6fb71 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -62,6 +62,7 @@ from litellm.proxy.auth.auth_utils import ( is_llm_api_route, route_in_additonal_public_routes, ) +from litellm.proxy.auth.oauth2_check import check_oauth2_token from litellm.proxy.common_utils.http_parsing_utils import _read_request_body from litellm.proxy.utils import _to_ns @@ -197,6 +198,11 @@ async def user_api_key_auth( # check if public endpoint return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY) + if general_settings.get("enable_oauth2_auth", False) is True: + # return UserAPIKeyAuth object + # helper to check if the api_key is a valid oauth2 token + return await check_oauth2_token(token=api_key) + if general_settings.get("enable_jwt_auth", False) is True: is_jwt = jwt_handler.is_jwt(token=api_key) verbose_proxy_logger.debug("is_jwt: %s", is_jwt)