update generic SSO login

During implementation for Okta, noticed a few things:

- Some providers require a state parameter to be sent
- Some providers require that the client_id is not included in the body

Moreover, the OpenID response converter was not implemented which
was returning an empty response.

Finally, there was an order where there's a fetch of user information
but on first usage, it is not created yet.
This commit is contained in:
Adrien Fillon 2024-02-22 14:34:16 +01:00
parent 0e26c59623
commit 4e29e5460b
2 changed files with 58 additions and 29 deletions

View file

@ -133,7 +133,8 @@ The following can be used to customize attribute names when interacting with the
GENERIC_USER_ID_ATTRIBUTE = "given_name" GENERIC_USER_ID_ATTRIBUTE = "given_name"
GENERIC_USER_EMAIL_ATTRIBUTE = "family_name" GENERIC_USER_EMAIL_ATTRIBUTE = "family_name"
GENERIC_USER_ROLE_ATTRIBUTE = "given_role" GENERIC_USER_ROLE_ATTRIBUTE = "given_role"
GENERIC_CLIENT_STATE = "some-state" # if the provider needs a state parameter
GENERIC_INCLUDE_CLIENT_ID = "false" # some providers enforce that the client_id is not in the body
GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not enough to retrieve basic user info like first_name and last_name located in profile scope GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not enough to retrieve basic user info like first_name and last_name located in profile scope
``` ```

View file

@ -4959,7 +4959,15 @@ async def google_login(request: Request):
scope=generic_scope, scope=generic_scope,
) )
with generic_sso: with generic_sso:
return await generic_sso.get_login_redirect() # TODO: state should be a random string and added to the user session with cookie
# or a cryptographicly signed state that we can verify stateless
# For simplification we are using a static state, this is not perfect but some
# SSO providers do not allow stateless verification
redirect_params = {}
state = os.getenv("GENERIC_CLIENT_STATE", None)
if state:
redirect_params['state'] = state
return await generic_sso.get_login_redirect(**redirect_params)
elif ui_username is not None: elif ui_username is not None:
# No Google, Microsoft SSO # No Google, Microsoft SSO
# Use UI Credentials set in .env # Use UI Credentials set in .env
@ -5104,7 +5112,7 @@ async def auth_callback(request: Request):
result = await microsoft_sso.verify_and_process(request) result = await microsoft_sso.verify_and_process(request)
elif generic_client_id is not None: elif generic_client_id is not None:
# make generic sso provider # make generic sso provider
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@ -5113,6 +5121,9 @@ async def auth_callback(request: Request):
) )
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
generic_include_client_id = (
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
)
if generic_client_secret is None: if generic_client_secret is None:
raise ProxyException( raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file", message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
@ -5147,12 +5158,36 @@ async def auth_callback(request: Request):
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
) )
generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email")
generic_user_email_attribute_name = os.getenv(
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
)
generic_user_role_attribute_name = os.getenv(
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
)
verbose_proxy_logger.debug(
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
)
discovery = DiscoveryDocument( discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint, authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint, token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint, userinfo_endpoint=generic_userinfo_endpoint,
) )
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
def response_convertor(response, client):
return OpenID(
id=response.get(generic_user_email_attribute_name),
display_name=response.get(generic_user_email_attribute_name),
email=response.get(generic_user_email_attribute_name),
)
SSOProvider = create_provider(
name="oidc",
discovery_document=discovery,
response_convertor=response_convertor,
)
generic_sso = SSOProvider( generic_sso = SSOProvider(
client_id=generic_client_id, client_id=generic_client_id,
client_secret=generic_client_secret, client_secret=generic_client_secret,
@ -5161,43 +5196,36 @@ async def auth_callback(request: Request):
scope=generic_scope, scope=generic_scope,
) )
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process") verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
request_body = await request.body() result = await generic_sso.verify_and_process(
request_query_params = request.query_params request, params={"include_client_id": generic_include_client_id}
# get "code" from query params )
code = request_query_params.get("code")
result = await generic_sso.verify_and_process(request)
verbose_proxy_logger.debug(f"generic result: {result}") verbose_proxy_logger.debug(f"generic result: {result}")
# User is Authe'd in - generate key for the UI to access Proxy # User is Authe'd in - generate key for the UI to access Proxy
user_email = getattr(result, "email", None) user_email = getattr(result, "email", None)
user_id = getattr(result, "id", None) user_id = getattr(result, "id", None)
# generic client id # generic client id
if generic_client_id is not None: if generic_client_id is not None:
generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email") user_id = result.id
generic_user_email_attribute_name = os.getenv( user_email = result.email
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
)
generic_user_role_attribute_name = os.getenv(
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
)
verbose_proxy_logger.debug(
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
)
user_id = getattr(result, generic_user_id_attribute_name, None)
user_email = getattr(result, generic_user_email_attribute_name, None)
user_role = getattr(result, generic_user_role_attribute_name, None) user_role = getattr(result, generic_user_role_attribute_name, None)
if user_id is None: if user_id is None:
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "") user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
# get user_info from litellm DB
user_info = None user_info = None
if prisma_client is not None: user_id_models = []
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
user_id_models: List = [] # User might not be already created on first generation of key
if user_info is not None: # But if it is, we want its models preferences
user_id_models = getattr(user_info, "models", []) try:
if prisma_client is not None:
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
if user_info is not None:
user_id_models = getattr(user_info, "models", [])
except Exception as e:
pass
response = await generate_key_helper_fn( response = await generate_key_helper_fn(
**{ **{