mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
[Feat SSO] Debug route - allow admins to debug SSO JWT fields (#9835)
* refactor SSO handler * render sso JWT on ui * docs debug sso * fix sso login flow use await * fix ui sso debug JWT * test ui sso * remove redis vl * fix redisvl==0.5.1 * fix ml dtypes * fix redisvl * fix redis vl * fix debug_sso_callback * fix linting error * fix redis semantic caching dep
This commit is contained in:
parent
08a3620414
commit
6f7e9b9728
5 changed files with 900 additions and 175 deletions
|
@ -198,6 +198,7 @@ This budget does not apply to keys created under non-default teams.
|
|||
|
||||
### Auto-add SSO users to teams
|
||||
|
||||
|
||||
1. Specify the JWT field that contains the team ids, that the user belongs to.
|
||||
|
||||
```yaml
|
||||
|
@ -207,7 +208,8 @@ general_settings:
|
|||
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
|
||||
```
|
||||
|
||||
This is assuming your SSO token looks like this:
|
||||
This is assuming your SSO token looks like this. **If you need to inspect the JWT fields received from your SSO provider by LiteLLM, follow these instructions [here](#debugging-sso-jwt-fields)**
|
||||
|
||||
```
|
||||
{
|
||||
...,
|
||||
|
@ -231,6 +233,39 @@ curl -X POST '<PROXY_BASE_URL>/team/new' \
|
|||
|
||||
Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
|
||||
|
||||
### Debugging SSO JWT fields
|
||||
|
||||
If you need to inspect the JWT fields received from your SSO provider by LiteLLM, follow these instructions. This guide walks you through setting up a debug callback to view the JWT data during the SSO process.
|
||||
|
||||
|
||||
<Image img={require('../../img/debug_sso.png')} style={{ width: '500px', height: 'auto' }} />
|
||||
<br />
|
||||
|
||||
1. Add `/sso/debug/callback` as a redirect URL in your SSO provider
|
||||
|
||||
In your SSO provider's settings, add the following URL as a new redirect (callback) URL:
|
||||
|
||||
```bash showLineNumbers title="Redirect URL"
|
||||
http://<proxy_base_url>/sso/debug/callback
|
||||
```
|
||||
|
||||
|
||||
2. Navigate to the debug login page on your browser
|
||||
|
||||
Navigate to the following URL on your browser:
|
||||
|
||||
```bash showLineNumbers title="URL to navigate to"
|
||||
https://<proxy_base_url>/sso/debug/login
|
||||
```
|
||||
|
||||
This will initiate the standard SSO flow. You will be redirected to your SSO provider's login screen, and after successful authentication, you will be redirected back to LiteLLM's debug callback route.
|
||||
|
||||
|
||||
3. View the JWT fields
|
||||
|
||||
Once redirected, you should see a page called "SSO Debug Information". This page displays the JWT fields received from your SSO provider (as shown in the image above)
|
||||
|
||||
|
||||
### Restrict Users from creating personal keys
|
||||
|
||||
This is useful if you only want users to create keys under a specific team.
|
||||
|
|
BIN
docs/my-website/img/debug_sso.png
Normal file
BIN
docs/my-website/img/debug_sso.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 167 KiB |
284
litellm/proxy/common_utils/html_forms/jwt_display_template.py
Normal file
284
litellm/proxy/common_utils/html_forms/jwt_display_template.py
Normal file
|
@ -0,0 +1,284 @@
|
|||
# JWT display template for SSO debug callback
|
||||
jwt_display_template = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>LiteLLM SSO Debug - JWT Information</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
|
||||
background-color: #f8fafc;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
min-height: 100vh;
|
||||
color: #333;
|
||||
}
|
||||
|
||||
.container {
|
||||
background-color: #fff;
|
||||
padding: 40px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
|
||||
width: 800px;
|
||||
max-width: 100%;
|
||||
}
|
||||
|
||||
.logo-container {
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
.logo {
|
||||
font-size: 24px;
|
||||
font-weight: 600;
|
||||
color: #1e293b;
|
||||
}
|
||||
|
||||
h2 {
|
||||
margin: 0 0 10px;
|
||||
color: #1e293b;
|
||||
font-size: 28px;
|
||||
font-weight: 600;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.subtitle {
|
||||
color: #64748b;
|
||||
margin: 0 0 20px;
|
||||
font-size: 16px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.info-box {
|
||||
background-color: #f1f5f9;
|
||||
border-radius: 6px;
|
||||
padding: 20px;
|
||||
margin-bottom: 30px;
|
||||
border-left: 4px solid #2563eb;
|
||||
}
|
||||
|
||||
.success-box {
|
||||
background-color: #f0fdf4;
|
||||
border-radius: 6px;
|
||||
padding: 20px;
|
||||
margin-bottom: 30px;
|
||||
border-left: 4px solid #16a34a;
|
||||
}
|
||||
|
||||
.info-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: 12px;
|
||||
color: #1e40af;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.success-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: 12px;
|
||||
color: #166534;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.info-header svg, .success-header svg {
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
.data-container {
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.data-row {
|
||||
display: flex;
|
||||
border-bottom: 1px solid #e2e8f0;
|
||||
padding: 12px 0;
|
||||
}
|
||||
|
||||
.data-row:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
.data-label {
|
||||
font-weight: 500;
|
||||
color: #334155;
|
||||
width: 180px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.data-value {
|
||||
color: #475569;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.jwt-container {
|
||||
background-color: #f8fafc;
|
||||
border-radius: 6px;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
overflow-x: auto;
|
||||
border: 1px solid #e2e8f0;
|
||||
}
|
||||
|
||||
.jwt-text {
|
||||
font-family: monospace;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
margin: 0;
|
||||
color: #334155;
|
||||
}
|
||||
|
||||
.back-button {
|
||||
display: inline-block;
|
||||
background-color: #6466E9;
|
||||
color: #fff;
|
||||
text-decoration: none;
|
||||
padding: 10px 16px;
|
||||
border-radius: 6px;
|
||||
font-weight: 500;
|
||||
margin-top: 20px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.back-button:hover {
|
||||
background-color: #4138C2;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.buttons {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.copy-button {
|
||||
background-color: #e2e8f0;
|
||||
color: #334155;
|
||||
border: none;
|
||||
padding: 8px 12px;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.copy-button:hover {
|
||||
background-color: #cbd5e1;
|
||||
}
|
||||
|
||||
.copy-button svg {
|
||||
margin-right: 6px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="logo-container">
|
||||
<div class="logo">
|
||||
🚅 LiteLLM
|
||||
</div>
|
||||
</div>
|
||||
<h2>SSO Debug Information</h2>
|
||||
<p class="subtitle">Results from the SSO authentication process.</p>
|
||||
|
||||
<div class="success-box">
|
||||
<div class="success-header">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M22 11.08V12a10 10 0 1 1-5.93-9.14"></path>
|
||||
<polyline points="22 4 12 14.01 9 11.01"></polyline>
|
||||
</svg>
|
||||
Authentication Successful
|
||||
</div>
|
||||
<p>The SSO authentication completed successfully. Below is the information returned by the provider.</p>
|
||||
</div>
|
||||
|
||||
<div class="data-container" id="userData">
|
||||
<!-- Data will be inserted here by JavaScript -->
|
||||
</div>
|
||||
|
||||
<div class="info-box">
|
||||
<div class="info-header">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<circle cx="12" cy="12" r="10"></circle>
|
||||
<line x1="12" y1="16" x2="12" y2="12"></line>
|
||||
<line x1="12" y1="8" x2="12.01" y2="8"></line>
|
||||
</svg>
|
||||
JSON Representation
|
||||
</div>
|
||||
<div class="jwt-container">
|
||||
<pre class="jwt-text" id="jsonData">Loading...</pre>
|
||||
</div>
|
||||
<div class="buttons">
|
||||
<button class="copy-button" onclick="copyToClipboard('jsonData')">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
|
||||
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
|
||||
</svg>
|
||||
Copy to Clipboard
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<a href="/sso/debug/login" class="back-button">
|
||||
Try Another SSO Login
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// This will be populated with the actual data from the server
|
||||
const userData = SSO_DATA;
|
||||
|
||||
function renderUserData() {
|
||||
const container = document.getElementById('userData');
|
||||
const jsonDisplay = document.getElementById('jsonData');
|
||||
|
||||
// Format JSON with indentation for display
|
||||
jsonDisplay.textContent = JSON.stringify(userData, null, 2);
|
||||
|
||||
// Clear container
|
||||
container.innerHTML = '';
|
||||
|
||||
// Add each key-value pair to the UI
|
||||
for (const [key, value] of Object.entries(userData)) {
|
||||
if (typeof value !== 'object' || value === null) {
|
||||
const row = document.createElement('div');
|
||||
row.className = 'data-row';
|
||||
|
||||
const label = document.createElement('div');
|
||||
label.className = 'data-label';
|
||||
label.textContent = key;
|
||||
|
||||
const dataValue = document.createElement('div');
|
||||
dataValue.className = 'data-value';
|
||||
dataValue.textContent = value !== null ? value : 'null';
|
||||
|
||||
row.appendChild(label);
|
||||
row.appendChild(dataValue);
|
||||
container.appendChild(row);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function copyToClipboard(elementId) {
|
||||
const text = document.getElementById(elementId).textContent;
|
||||
navigator.clipboard.writeText(text).then(() => {
|
||||
alert('Copied to clipboard!');
|
||||
}).catch(err => {
|
||||
console.error('Could not copy text: ', err);
|
||||
});
|
||||
}
|
||||
|
||||
// Render the data when the page loads
|
||||
document.addEventListener('DOMContentLoaded', renderUserData);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
|
@ -3,6 +3,9 @@ Has all /sso/* routes
|
|||
|
||||
/sso/key/generate - handles user signing in with SSO and redirects to /sso/callback
|
||||
/sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI
|
||||
|
||||
/sso/debug/login - handles user signing in with SSO and redirects to /sso/debug/callback
|
||||
/sso/debug/callback - returns the OpenID object returned by the SSO provider
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
@ -36,6 +39,9 @@ from litellm.proxy.common_utils.admin_ui_utils import (
|
|||
admin_ui_disabled,
|
||||
show_missing_vars_in_env,
|
||||
)
|
||||
from litellm.proxy.common_utils.html_forms.jwt_display_template import (
|
||||
jwt_display_template,
|
||||
)
|
||||
from litellm.proxy.common_utils.html_forms.ui_login import html_form
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
|
||||
from litellm.proxy.management_endpoints.sso_helper_utils import (
|
||||
|
@ -92,131 +98,29 @@ async def google_login(request: Request): # noqa: PLR0915
|
|||
missing_env_vars = show_missing_vars_in_env()
|
||||
if missing_env_vars is not None:
|
||||
return missing_env_vars
|
||||
ui_username = os.getenv("UI_USERNAME")
|
||||
|
||||
# get url from request
|
||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||
ui_username = os.getenv("UI_USERNAME")
|
||||
if redirect_url.endswith("/"):
|
||||
redirect_url += "sso/callback"
|
||||
else:
|
||||
redirect_url += "/sso/callback"
|
||||
# Google SSO Auth
|
||||
if google_client_id is not None:
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
|
||||
request=request,
|
||||
sso_callback_route="sso/callback",
|
||||
)
|
||||
|
||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||
if google_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GOOGLE_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
google_sso = GoogleSSO(
|
||||
client_id=google_client_id,
|
||||
client_secret=google_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
# Check if we should use SSO handler
|
||||
if (
|
||||
SSOAuthenticationHandler.should_use_sso_handler(
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
google_client_id=google_client_id,
|
||||
generic_client_id=generic_client_id,
|
||||
)
|
||||
verbose_proxy_logger.info(
|
||||
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
||||
is True
|
||||
):
|
||||
return await SSOAuthenticationHandler.get_sso_login_redirect(
|
||||
redirect_url=redirect_url,
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
google_client_id=google_client_id,
|
||||
generic_client_id=generic_client_id,
|
||||
)
|
||||
with google_sso:
|
||||
return await google_sso.get_login_redirect()
|
||||
# Microsoft SSO Auth
|
||||
elif microsoft_client_id is not None:
|
||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||
|
||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||
if microsoft_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
microsoft_sso = MicrosoftSSO(
|
||||
client_id=microsoft_client_id,
|
||||
client_secret=microsoft_client_secret,
|
||||
tenant=microsoft_tenant,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
)
|
||||
with microsoft_sso:
|
||||
return await microsoft_sso.get_login_redirect()
|
||||
elif generic_client_id is not None:
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
generic_authorization_endpoint = os.getenv(
|
||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||
)
|
||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||
if generic_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_authorization_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_token_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_TOKEN_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_userinfo_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_USERINFO_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||
)
|
||||
discovery = DiscoveryDocument(
|
||||
authorization_endpoint=generic_authorization_endpoint,
|
||||
token_endpoint=generic_token_endpoint,
|
||||
userinfo_endpoint=generic_userinfo_endpoint,
|
||||
)
|
||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||
generic_sso = SSOProvider(
|
||||
client_id=generic_client_id,
|
||||
client_secret=generic_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
scope=generic_scope,
|
||||
)
|
||||
with generic_sso:
|
||||
# TODO: state should be a random string and added to the user session with cookie
|
||||
# or a cryptographicly signed state that we can verify stateless
|
||||
# For simplification we are using a static state, this is not perfect but some
|
||||
# SSO providers do not allow stateless verification
|
||||
redirect_params = {}
|
||||
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
||||
|
||||
if state:
|
||||
redirect_params["state"] = state
|
||||
elif "okta" in generic_authorization_endpoint:
|
||||
redirect_params[
|
||||
"state"
|
||||
] = uuid.uuid4().hex # set state param for okta - required
|
||||
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
||||
elif ui_username is not None:
|
||||
# No Google, Microsoft SSO
|
||||
# Use UI Credentials set in .env
|
||||
|
@ -271,7 +175,7 @@ async def get_generic_sso_response(
|
|||
jwt_handler: JWTHandler,
|
||||
generic_client_id: str,
|
||||
redirect_url: str,
|
||||
) -> Optional[OpenID]:
|
||||
) -> Union[OpenID, dict]:
|
||||
# make generic sso provider
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
@ -348,7 +252,7 @@ async def get_generic_sso_response(
|
|||
request, params={"include_client_id": generic_include_client_id}
|
||||
)
|
||||
verbose_proxy_logger.debug("generic result: %s", result)
|
||||
return result
|
||||
return result or {}
|
||||
|
||||
|
||||
async def create_team_member_add_task(team_id, user_info):
|
||||
|
@ -443,54 +347,16 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
|
||||
result = None
|
||||
if google_client_id is not None:
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
|
||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||
if google_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GOOGLE_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
google_sso = GoogleSSO(
|
||||
client_id=google_client_id,
|
||||
redirect_uri=redirect_url,
|
||||
client_secret=google_client_secret,
|
||||
)
|
||||
result = await google_sso.verify_and_process(request)
|
||||
elif microsoft_client_id is not None:
|
||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||
|
||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||
if microsoft_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if microsoft_tenant is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_TENANT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_TENANT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
microsoft_sso = MicrosoftSSO(
|
||||
client_id=microsoft_client_id,
|
||||
client_secret=microsoft_client_secret,
|
||||
tenant=microsoft_tenant,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
)
|
||||
original_msft_result = await microsoft_sso.verify_and_process(
|
||||
result = await GoogleSSOHandler.get_google_callback_response(
|
||||
request=request,
|
||||
convert_response=False,
|
||||
google_client_id=google_client_id,
|
||||
redirect_url=redirect_url,
|
||||
)
|
||||
result = MicrosoftSSOHandler.openid_from_response(
|
||||
response=original_msft_result,
|
||||
elif microsoft_client_id is not None:
|
||||
result = await MicrosoftSSOHandler.get_microsoft_callback_response(
|
||||
request=request,
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
redirect_url=redirect_url,
|
||||
jwt_handler=jwt_handler,
|
||||
)
|
||||
elif generic_client_id is not None:
|
||||
|
@ -705,7 +571,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
|||
|
||||
|
||||
async def insert_sso_user(
|
||||
result_openid: Optional[OpenID],
|
||||
result_openid: Optional[Union[OpenID, dict]],
|
||||
user_defined_values: Optional[SSOUserDefinedValues] = None,
|
||||
) -> NewUserResponse:
|
||||
"""
|
||||
|
@ -721,6 +587,10 @@ async def insert_sso_user(
|
|||
verbose_proxy_logger.debug(
|
||||
f"Inserting SSO user into DB. User values: {user_defined_values}"
|
||||
)
|
||||
if result_openid is None:
|
||||
raise ValueError("result_openid is None")
|
||||
if isinstance(result_openid, dict):
|
||||
result_openid = OpenID(**result_openid)
|
||||
|
||||
if user_defined_values is None:
|
||||
raise ValueError("user_defined_values is None")
|
||||
|
@ -733,9 +603,9 @@ async def insert_sso_user(
|
|||
if user_defined_values.get("max_budget") is None:
|
||||
user_defined_values["max_budget"] = litellm.max_internal_user_budget
|
||||
if user_defined_values.get("budget_duration") is None:
|
||||
user_defined_values[
|
||||
"budget_duration"
|
||||
] = litellm.internal_user_budget_duration
|
||||
user_defined_values["budget_duration"] = (
|
||||
litellm.internal_user_budget_duration
|
||||
)
|
||||
|
||||
if user_defined_values["user_role"] is None:
|
||||
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
|
@ -789,11 +659,242 @@ async def get_ui_settings(request: Request):
|
|||
}
|
||||
|
||||
|
||||
class SSOAuthenticationHandler:
|
||||
"""
|
||||
Handler for SSO Authentication across all SSO providers
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def get_sso_login_redirect(
|
||||
redirect_url: str,
|
||||
google_client_id: Optional[str] = None,
|
||||
microsoft_client_id: Optional[str] = None,
|
||||
generic_client_id: Optional[str] = None,
|
||||
) -> Optional[RedirectResponse]:
|
||||
"""
|
||||
Step 1. Call Get Login Redirect for the SSO provider. Send the redirect response to `redirect_url`
|
||||
|
||||
Args:
|
||||
redirect_url (str): The URL to redirect the user to after login
|
||||
google_client_id (Optional[str], optional): The Google Client ID. Defaults to None.
|
||||
microsoft_client_id (Optional[str], optional): The Microsoft Client ID. Defaults to None.
|
||||
generic_client_id (Optional[str], optional): The Generic Client ID. Defaults to None.
|
||||
|
||||
Returns:
|
||||
RedirectResponse: The redirect response from the SSO provider
|
||||
"""
|
||||
# Google SSO Auth
|
||||
if google_client_id is not None:
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
|
||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||
if google_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GOOGLE_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
google_sso = GoogleSSO(
|
||||
client_id=google_client_id,
|
||||
client_secret=google_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
)
|
||||
verbose_proxy_logger.info(
|
||||
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
||||
)
|
||||
with google_sso:
|
||||
return await google_sso.get_login_redirect()
|
||||
# Microsoft SSO Auth
|
||||
elif microsoft_client_id is not None:
|
||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||
|
||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||
if microsoft_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
microsoft_sso = MicrosoftSSO(
|
||||
client_id=microsoft_client_id,
|
||||
client_secret=microsoft_client_secret,
|
||||
tenant=microsoft_tenant,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
)
|
||||
with microsoft_sso:
|
||||
return await microsoft_sso.get_login_redirect()
|
||||
elif generic_client_id is not None:
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(
|
||||
" "
|
||||
)
|
||||
generic_authorization_endpoint = os.getenv(
|
||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||
)
|
||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||
if generic_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_authorization_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_token_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_TOKEN_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_userinfo_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GENERIC_USERINFO_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||
)
|
||||
discovery = DiscoveryDocument(
|
||||
authorization_endpoint=generic_authorization_endpoint,
|
||||
token_endpoint=generic_token_endpoint,
|
||||
userinfo_endpoint=generic_userinfo_endpoint,
|
||||
)
|
||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||
generic_sso = SSOProvider(
|
||||
client_id=generic_client_id,
|
||||
client_secret=generic_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
scope=generic_scope,
|
||||
)
|
||||
with generic_sso:
|
||||
# TODO: state should be a random string and added to the user session with cookie
|
||||
# or a cryptographicly signed state that we can verify stateless
|
||||
# For simplification we are using a static state, this is not perfect but some
|
||||
# SSO providers do not allow stateless verification
|
||||
redirect_params = {}
|
||||
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
||||
|
||||
if state:
|
||||
redirect_params["state"] = state
|
||||
elif "okta" in generic_authorization_endpoint:
|
||||
redirect_params["state"] = (
|
||||
uuid.uuid4().hex
|
||||
) # set state param for okta - required
|
||||
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
||||
raise ValueError(
|
||||
"Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_use_sso_handler(
|
||||
google_client_id: Optional[str] = None,
|
||||
microsoft_client_id: Optional[str] = None,
|
||||
generic_client_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
if (
|
||||
google_client_id is not None
|
||||
or microsoft_client_id is not None
|
||||
or generic_client_id is not None
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_redirect_url_for_sso(
|
||||
request: Request,
|
||||
sso_callback_route: str,
|
||||
) -> str:
|
||||
"""
|
||||
Get the redirect URL for SSO
|
||||
"""
|
||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||
if redirect_url.endswith("/"):
|
||||
redirect_url += sso_callback_route
|
||||
else:
|
||||
redirect_url += "/" + sso_callback_route
|
||||
return redirect_url
|
||||
|
||||
|
||||
class MicrosoftSSOHandler:
|
||||
"""
|
||||
Handles Microsoft SSO callback response and returns a CustomOpenID object
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def get_microsoft_callback_response(
|
||||
request: Request,
|
||||
microsoft_client_id: str,
|
||||
redirect_url: str,
|
||||
jwt_handler: JWTHandler,
|
||||
return_raw_sso_response: bool = False,
|
||||
) -> Union[CustomOpenID, OpenID, dict]:
|
||||
"""
|
||||
Get the Microsoft SSO callback response
|
||||
|
||||
Args:
|
||||
return_raw_sso_response: If True, return the raw SSO response
|
||||
"""
|
||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||
|
||||
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
||||
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
||||
if microsoft_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if microsoft_tenant is None:
|
||||
raise ProxyException(
|
||||
message="MICROSOFT_TENANT not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="MICROSOFT_TENANT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
microsoft_sso = MicrosoftSSO(
|
||||
client_id=microsoft_client_id,
|
||||
client_secret=microsoft_client_secret,
|
||||
tenant=microsoft_tenant,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
)
|
||||
original_msft_result = await microsoft_sso.verify_and_process(
|
||||
request=request,
|
||||
convert_response=False,
|
||||
)
|
||||
|
||||
# if user is trying to get the raw sso response for debugging, return the raw sso response
|
||||
if return_raw_sso_response:
|
||||
return original_msft_result or {}
|
||||
|
||||
result = MicrosoftSSOHandler.openid_from_response(
|
||||
response=original_msft_result,
|
||||
jwt_handler=jwt_handler,
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def openid_from_response(
|
||||
response: Optional[dict], jwt_handler: JWTHandler
|
||||
|
@ -811,3 +912,181 @@ class MicrosoftSSOHandler:
|
|||
)
|
||||
verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
|
||||
return openid_response
|
||||
|
||||
|
||||
class GoogleSSOHandler:
|
||||
"""
|
||||
Handles Google SSO callback response and returns a CustomOpenID object
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def get_google_callback_response(
|
||||
request: Request,
|
||||
google_client_id: str,
|
||||
redirect_url: str,
|
||||
return_raw_sso_response: bool = False,
|
||||
) -> Union[OpenID, dict]:
|
||||
"""
|
||||
Get the Google SSO callback response
|
||||
|
||||
Args:
|
||||
return_raw_sso_response: If True, return the raw SSO response
|
||||
"""
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
|
||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
||||
if google_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="GOOGLE_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
google_sso = GoogleSSO(
|
||||
client_id=google_client_id,
|
||||
redirect_uri=redirect_url,
|
||||
client_secret=google_client_secret,
|
||||
)
|
||||
|
||||
# if user is trying to get the raw sso response for debugging, return the raw sso response
|
||||
if return_raw_sso_response:
|
||||
return (
|
||||
await google_sso.verify_and_process(
|
||||
request=request,
|
||||
convert_response=False,
|
||||
)
|
||||
or {}
|
||||
)
|
||||
|
||||
result = await google_sso.verify_and_process(request)
|
||||
return result or {}
|
||||
|
||||
|
||||
@router.get("/sso/debug/login", tags=["experimental"], include_in_schema=False)
|
||||
async def debug_sso_login(request: Request):
|
||||
"""
|
||||
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
|
||||
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
|
||||
Example:
|
||||
"""
|
||||
from litellm.proxy.proxy_server import premium_user
|
||||
|
||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
|
||||
####### Check if user is a Enterprise / Premium User #######
|
||||
if (
|
||||
microsoft_client_id is not None
|
||||
or google_client_id is not None
|
||||
or generic_client_id is not None
|
||||
):
|
||||
if premium_user is not True:
|
||||
raise ProxyException(
|
||||
message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="premium_user",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
# get url from request
|
||||
redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
|
||||
request=request,
|
||||
sso_callback_route="sso/debug/callback",
|
||||
)
|
||||
|
||||
# Check if we should use SSO handler
|
||||
if (
|
||||
SSOAuthenticationHandler.should_use_sso_handler(
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
google_client_id=google_client_id,
|
||||
generic_client_id=generic_client_id,
|
||||
)
|
||||
is True
|
||||
):
|
||||
return await SSOAuthenticationHandler.get_sso_login_redirect(
|
||||
redirect_url=redirect_url,
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
google_client_id=google_client_id,
|
||||
generic_client_id=generic_client_id,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sso/debug/callback", tags=["experimental"], include_in_schema=False)
|
||||
async def debug_sso_callback(request: Request):
|
||||
"""
|
||||
Returns the OpenID object returned by the SSO provider
|
||||
"""
|
||||
import json
|
||||
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from litellm.proxy.proxy_server import jwt_handler
|
||||
|
||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
|
||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||
if redirect_url.endswith("/"):
|
||||
redirect_url += "sso/debug/callback"
|
||||
else:
|
||||
redirect_url += "/sso/debug/callback"
|
||||
|
||||
result = None
|
||||
if google_client_id is not None:
|
||||
result = await GoogleSSOHandler.get_google_callback_response(
|
||||
request=request,
|
||||
google_client_id=google_client_id,
|
||||
redirect_url=redirect_url,
|
||||
return_raw_sso_response=True,
|
||||
)
|
||||
elif microsoft_client_id is not None:
|
||||
result = await MicrosoftSSOHandler.get_microsoft_callback_response(
|
||||
request=request,
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
redirect_url=redirect_url,
|
||||
jwt_handler=jwt_handler,
|
||||
return_raw_sso_response=True,
|
||||
)
|
||||
elif generic_client_id is not None:
|
||||
result = await get_generic_sso_response(
|
||||
request=request,
|
||||
jwt_handler=jwt_handler,
|
||||
generic_client_id=generic_client_id,
|
||||
redirect_url=redirect_url,
|
||||
)
|
||||
|
||||
# If result is None, return a basic error message
|
||||
if result is None:
|
||||
return HTMLResponse(
|
||||
content="<h1>SSO Authentication Failed</h1><p>No data was returned from the SSO provider.</p>",
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Convert the OpenID object to a dictionary
|
||||
if hasattr(result, "__dict__"):
|
||||
result_dict = result.__dict__
|
||||
else:
|
||||
result_dict = dict(result)
|
||||
|
||||
# Filter out any None values and convert to JSON serializable format
|
||||
filtered_result = {}
|
||||
for key, value in result_dict.items():
|
||||
if value is not None and not key.startswith("_"):
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
filtered_result[key] = value
|
||||
else:
|
||||
try:
|
||||
# Try to convert to string or another JSON serializable format
|
||||
filtered_result[key] = str(value)
|
||||
except Exception as e:
|
||||
filtered_result[key] = f"Complex value (not displayable): {str(e)}"
|
||||
|
||||
# Replace the placeholder in the template with the actual data
|
||||
html_content = jwt_display_template.replace(
|
||||
"const userData = SSO_DATA;",
|
||||
f"const userData = {json.dumps(filtered_result, indent=2)};",
|
||||
)
|
||||
|
||||
return HTMLResponse(content=html_content)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
@ -5,15 +6,19 @@ from typing import Optional, cast
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
0, os.path.abspath("../../../")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||
from litellm.proxy.management_endpoints.ui_sso import MicrosoftSSOHandler
|
||||
from litellm.proxy.management_endpoints.ui_sso import (
|
||||
GoogleSSOHandler,
|
||||
MicrosoftSSOHandler,
|
||||
)
|
||||
|
||||
|
||||
def test_microsoft_sso_handler_openid_from_response():
|
||||
|
@ -79,3 +84,125 @@ def test_microsoft_sso_handler_with_empty_response():
|
|||
|
||||
# Make sure the JWT handler was called with an empty dict
|
||||
mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with({})
|
||||
|
||||
|
||||
def test_get_microsoft_callback_response():
|
||||
# Arrange
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_jwt_handler = MagicMock(spec=JWTHandler)
|
||||
mock_response = {
|
||||
"mail": "microsoft_user@example.com",
|
||||
"displayName": "Microsoft User",
|
||||
"id": "msft123",
|
||||
"givenName": "Microsoft",
|
||||
"surname": "User",
|
||||
}
|
||||
|
||||
future = asyncio.Future()
|
||||
future.set_result(mock_response)
|
||||
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"MICROSOFT_CLIENT_SECRET": "mock_secret", "MICROSOFT_TENANT": "mock_tenant"},
|
||||
):
|
||||
with patch(
|
||||
"fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process",
|
||||
return_value=future,
|
||||
):
|
||||
# Act
|
||||
result = asyncio.run(
|
||||
MicrosoftSSOHandler.get_microsoft_callback_response(
|
||||
request=mock_request,
|
||||
microsoft_client_id="mock_client_id",
|
||||
redirect_url="http://mock_redirect_url",
|
||||
jwt_handler=mock_jwt_handler,
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, CustomOpenID)
|
||||
assert result.email == "microsoft_user@example.com"
|
||||
assert result.display_name == "Microsoft User"
|
||||
assert result.provider == "microsoft"
|
||||
assert result.id == "msft123"
|
||||
assert result.first_name == "Microsoft"
|
||||
assert result.last_name == "User"
|
||||
|
||||
|
||||
def test_get_microsoft_callback_response_raw_sso_response():
|
||||
# Arrange
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_jwt_handler = MagicMock(spec=JWTHandler)
|
||||
mock_response = {
|
||||
"mail": "microsoft_user@example.com",
|
||||
"displayName": "Microsoft User",
|
||||
"id": "msft123",
|
||||
"givenName": "Microsoft",
|
||||
"surname": "User",
|
||||
}
|
||||
|
||||
future = asyncio.Future()
|
||||
future.set_result(mock_response)
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"MICROSOFT_CLIENT_SECRET": "mock_secret", "MICROSOFT_TENANT": "mock_tenant"},
|
||||
):
|
||||
with patch(
|
||||
"fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process",
|
||||
return_value=future,
|
||||
):
|
||||
# Act
|
||||
result = asyncio.run(
|
||||
MicrosoftSSOHandler.get_microsoft_callback_response(
|
||||
request=mock_request,
|
||||
microsoft_client_id="mock_client_id",
|
||||
redirect_url="http://mock_redirect_url",
|
||||
jwt_handler=mock_jwt_handler,
|
||||
return_raw_sso_response=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
print("result from verify_and_process", result)
|
||||
assert isinstance(result, dict)
|
||||
assert result["mail"] == "microsoft_user@example.com"
|
||||
assert result["displayName"] == "Microsoft User"
|
||||
assert result["id"] == "msft123"
|
||||
assert result["givenName"] == "Microsoft"
|
||||
assert result["surname"] == "User"
|
||||
|
||||
|
||||
def test_get_google_callback_response():
|
||||
# Arrange
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_response = {
|
||||
"email": "google_user@example.com",
|
||||
"name": "Google User",
|
||||
"sub": "google123",
|
||||
"given_name": "Google",
|
||||
"family_name": "User",
|
||||
}
|
||||
|
||||
future = asyncio.Future()
|
||||
future.set_result(mock_response)
|
||||
|
||||
with patch.dict(os.environ, {"GOOGLE_CLIENT_SECRET": "mock_secret"}):
|
||||
with patch(
|
||||
"fastapi_sso.sso.google.GoogleSSO.verify_and_process", return_value=future
|
||||
):
|
||||
# Act
|
||||
result = asyncio.run(
|
||||
GoogleSSOHandler.get_google_callback_response(
|
||||
request=mock_request,
|
||||
google_client_id="mock_client_id",
|
||||
redirect_url="http://mock_redirect_url",
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, dict)
|
||||
assert result.get("email") == "google_user@example.com"
|
||||
assert result.get("name") == "Google User"
|
||||
assert result.get("sub") == "google123"
|
||||
assert result.get("given_name") == "Google"
|
||||
assert result.get("family_name") == "User"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue