Merge pull request #2129 from adrien-f/sso

update generic SSO login
This commit is contained in:
Ishaan Jaff 2024-02-22 13:21:53 -08:00 committed by GitHub
commit b603d9784d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 58 additions and 29 deletions

View file

@ -5097,7 +5097,15 @@ async def google_login(request: Request):
scope=generic_scope,
)
with generic_sso:
return await generic_sso.get_login_redirect()
# TODO: state should be a random string and added to the user session with cookie
# or a cryptographicly signed state that we can verify stateless
# For simplification we are using a static state, this is not perfect but some
# SSO providers do not allow stateless verification
redirect_params = {}
state = os.getenv("GENERIC_CLIENT_STATE", None)
if state:
redirect_params['state'] = state
return await generic_sso.get_login_redirect(**redirect_params)
elif ui_username is not None:
# No Google, Microsoft SSO
# Use UI Credentials set in .env
@ -5265,7 +5273,7 @@ async def auth_callback(request: Request):
result = await microsoft_sso.verify_and_process(request)
elif generic_client_id is not None:
# make generic sso provider
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@ -5274,6 +5282,9 @@ async def auth_callback(request: Request):
)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
generic_include_client_id = (
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
@ -5308,12 +5319,36 @@ async def auth_callback(request: Request):
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email")
generic_user_email_attribute_name = os.getenv(
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
)
generic_user_role_attribute_name = os.getenv(
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
)
verbose_proxy_logger.debug(
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
def response_convertor(response, client):
return OpenID(
id=response.get(generic_user_email_attribute_name),
display_name=response.get(generic_user_email_attribute_name),
email=response.get(generic_user_email_attribute_name),
)
SSOProvider = create_provider(
name="oidc",
discovery_document=discovery,
response_convertor=response_convertor,
)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
@ -5322,43 +5357,36 @@ async def auth_callback(request: Request):
scope=generic_scope,
)
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
request_body = await request.body()
request_query_params = request.query_params
# get "code" from query params
code = request_query_params.get("code")
result = await generic_sso.verify_and_process(request)
result = await generic_sso.verify_and_process(
request, params={"include_client_id": generic_include_client_id}
)
verbose_proxy_logger.debug(f"generic result: {result}")
# User is Authe'd in - generate key for the UI to access Proxy
user_email = getattr(result, "email", None)
user_id = getattr(result, "id", None)
# generic client id
if generic_client_id is not None:
generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email")
generic_user_email_attribute_name = os.getenv(
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
)
generic_user_role_attribute_name = os.getenv(
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
)
verbose_proxy_logger.debug(
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
)
user_id = getattr(result, generic_user_id_attribute_name, None)
user_email = getattr(result, generic_user_email_attribute_name, None)
user_id = result.id
user_email = result.email
user_role = getattr(result, generic_user_role_attribute_name, None)
if user_id is None:
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
# get user_info from litellm DB
user_info = None
if prisma_client is not None:
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
user_id_models: List = []
if user_info is not None:
user_id_models = getattr(user_info, "models", [])
user_id_models = []
# User might not be already created on first generation of key
# But if it is, we want its models preferences
try:
if prisma_client is not None:
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
if user_info is not None:
user_id_models = getattr(user_info, "models", [])
except Exception as e:
pass
response = await generate_key_helper_fn(
**{