mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
allow using oauth2 checks for logging into proxy
This commit is contained in:
parent
0c0b835c3f
commit
d4b33cf87c
1 changed files with 16 additions and 9 deletions
|
@ -3,15 +3,11 @@ from typing import Literal
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from litellm.llms.custom_httpx.http_handler import (
|
from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client
|
||||||
AsyncHTTPHandler,
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
HTTPHandler,
|
|
||||||
_get_async_httpx_client,
|
|
||||||
_get_httpx_client,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def check_oauth2_token(token: str) -> Literal[True]:
|
async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
|
||||||
"""
|
"""
|
||||||
Makes a request to the token info endpoint to validate the OAuth2 token.
|
Makes a request to the token info endpoint to validate the OAuth2 token.
|
||||||
|
|
||||||
|
@ -26,6 +22,9 @@ async def check_oauth2_token(token: str) -> Literal[True]:
|
||||||
"""
|
"""
|
||||||
# Get the token info endpoint from environment variable
|
# Get the token info endpoint from environment variable
|
||||||
token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
|
token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
|
||||||
|
user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
|
||||||
|
user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
|
||||||
|
user_team_id_field_name = os.environ.get("OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id")
|
||||||
|
|
||||||
if not token_info_endpoint:
|
if not token_info_endpoint:
|
||||||
raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")
|
raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")
|
||||||
|
@ -45,11 +44,19 @@ async def check_oauth2_token(token: str) -> Literal[True]:
|
||||||
|
|
||||||
# You might want to add additional checks here based on the response
|
# You might want to add additional checks here based on the response
|
||||||
# For example, checking if the token is expired or has the correct scope
|
# For example, checking if the token is expired or has the correct scope
|
||||||
|
user_id = data.get(user_id_field_name)
|
||||||
|
user_team_id = data.get(user_team_id_field_name)
|
||||||
|
user_role = data.get(user_role_field_name)
|
||||||
|
|
||||||
return True
|
return UserAPIKeyAuth(
|
||||||
|
api_key=token,
|
||||||
|
team_id=user_team_id,
|
||||||
|
user_id=user_id,
|
||||||
|
user_role=user_role,
|
||||||
|
)
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
# This will catch any 4xx or 5xx errors
|
# This will catch any 4xx or 5xx errors
|
||||||
raise ValueError(f"Token validation failed: {e}")
|
raise ValueError(f"Oauth 2.0 Token validation failed: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# This will catch any other errors (like network issues)
|
# This will catch any other errors (like network issues)
|
||||||
raise ValueError(f"An error occurred during token validation: {e}")
|
raise ValueError(f"An error occurred during token validation: {e}")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue