allow using oauth2 checks for logging into proxy

This commit is contained in:
Ishaan Jaff 2024-08-16 13:36:29 -07:00
parent 0c0b835c3f
commit d4b33cf87c

View file

@ -3,15 +3,11 @@ from typing import Literal
import httpx import httpx
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client
AsyncHTTPHandler, from litellm.proxy._types import UserAPIKeyAuth
HTTPHandler,
_get_async_httpx_client,
_get_httpx_client,
)
async def check_oauth2_token(token: str) -> Literal[True]: async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
""" """
Makes a request to the token info endpoint to validate the OAuth2 token. Makes a request to the token info endpoint to validate the OAuth2 token.
@ -26,6 +22,9 @@ async def check_oauth2_token(token: str) -> Literal[True]:
""" """
# Get the token info endpoint from environment variable # Get the token info endpoint from environment variable
token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT") token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
user_team_id_field_name = os.environ.get("OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id")
if not token_info_endpoint: if not token_info_endpoint:
raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set") raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")
@ -45,11 +44,19 @@ async def check_oauth2_token(token: str) -> Literal[True]:
# You might want to add additional checks here based on the response # You might want to add additional checks here based on the response
# For example, checking if the token is expired or has the correct scope # For example, checking if the token is expired or has the correct scope
user_id = data.get(user_id_field_name)
user_team_id = data.get(user_team_id_field_name)
user_role = data.get(user_role_field_name)
return True return UserAPIKeyAuth(
api_key=token,
team_id=user_team_id,
user_id=user_id,
user_role=user_role,
)
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
# This will catch any 4xx or 5xx errors # This will catch any 4xx or 5xx errors
raise ValueError(f"Token validation failed: {e}") raise ValueError(f"Oauth 2.0 Token validation failed: {e}")
except Exception as e: except Exception as e:
# This will catch any other errors (like network issues) # This will catch any other errors (like network issues)
raise ValueError(f"An error occurred during token validation: {e}") raise ValueError(f"An error occurred during token validation: {e}")