diff --git a/docs/my-website/docs/proxy/ui.md b/docs/my-website/docs/proxy/ui.md index 1c1931f8f8..bc669e322a 100644 --- a/docs/my-website/docs/proxy/ui.md +++ b/docs/my-website/docs/proxy/ui.md @@ -133,7 +133,8 @@ The following can be used to customize attribute names when interacting with the GENERIC_USER_ID_ATTRIBUTE = "given_name" GENERIC_USER_EMAIL_ATTRIBUTE = "family_name" GENERIC_USER_ROLE_ATTRIBUTE = "given_role" - +GENERIC_CLIENT_STATE = "some-state" # if the provider needs a state parameter +GENERIC_INCLUDE_CLIENT_ID = "false" # some providers enforce that the client_id is not in the body GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not enough to retrieve basic user info like first_name and last_name located in profile scope ``` diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f615232b75..dbb17bbcd5 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -4959,7 +4959,15 @@ async def google_login(request: Request): scope=generic_scope, ) with generic_sso: - return await generic_sso.get_login_redirect() + # TODO: state should be a random string and added to the user session with cookie + # or a cryptographicly signed state that we can verify stateless + # For simplification we are using a static state, this is not perfect but some + # SSO providers do not allow stateless verification + redirect_params = {} + state = os.getenv("GENERIC_CLIENT_STATE", None) + if state: + redirect_params['state'] = state + return await generic_sso.get_login_redirect(**redirect_params) elif ui_username is not None: # No Google, Microsoft SSO # Use UI Credentials set in .env @@ -5104,7 +5112,7 @@ async def auth_callback(request: Request): result = await microsoft_sso.verify_and_process(request) elif generic_client_id is not None: # make generic sso provider - from fastapi_sso.sso.generic import create_provider, DiscoveryDocument + from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") @@ -5113,6 +5121,9 @@ async def auth_callback(request: Request): ) generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) + generic_include_client_id = ( + os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true" + ) if generic_client_secret is None: raise ProxyException( message="GENERIC_CLIENT_SECRET not set. Set it in .env file", @@ -5147,12 +5158,36 @@ async def auth_callback(request: Request): verbose_proxy_logger.debug( f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" ) + + generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email") + generic_user_email_attribute_name = os.getenv( + "GENERIC_USER_EMAIL_ATTRIBUTE", "email" + ) + generic_user_role_attribute_name = os.getenv( + "GENERIC_USER_ROLE_ATTRIBUTE", "role" + ) + verbose_proxy_logger.debug( + f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}" + ) + discovery = DiscoveryDocument( authorization_endpoint=generic_authorization_endpoint, token_endpoint=generic_token_endpoint, userinfo_endpoint=generic_userinfo_endpoint, ) - SSOProvider = create_provider(name="oidc", discovery_document=discovery) + + def response_convertor(response, client): + return OpenID( + id=response.get(generic_user_email_attribute_name), + display_name=response.get(generic_user_email_attribute_name), + email=response.get(generic_user_email_attribute_name), + ) + + SSOProvider = create_provider( + name="oidc", + discovery_document=discovery, + response_convertor=response_convertor, + ) generic_sso = SSOProvider( client_id=generic_client_id, client_secret=generic_client_secret, @@ -5161,43 +5196,36 @@ async def auth_callback(request: Request): scope=generic_scope, ) verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process") - request_body = await request.body() - request_query_params = request.query_params - # get "code" from query params - code = request_query_params.get("code") - result = await generic_sso.verify_and_process(request) + result = await generic_sso.verify_and_process( + request, params={"include_client_id": generic_include_client_id} + ) verbose_proxy_logger.debug(f"generic result: {result}") + # User is Authe'd in - generate key for the UI to access Proxy user_email = getattr(result, "email", None) user_id = getattr(result, "id", None) # generic client id if generic_client_id is not None: - generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email") - generic_user_email_attribute_name = os.getenv( - "GENERIC_USER_EMAIL_ATTRIBUTE", "email" - ) - generic_user_role_attribute_name = os.getenv( - "GENERIC_USER_ROLE_ATTRIBUTE", "role" - ) - - verbose_proxy_logger.debug( - f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}" - ) - - user_id = getattr(result, generic_user_id_attribute_name, None) - user_email = getattr(result, generic_user_email_attribute_name, None) + user_id = result.id + user_email = result.email user_role = getattr(result, generic_user_role_attribute_name, None) if user_id is None: user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "") - # get user_info from litellm DB + user_info = None - if prisma_client is not None: - user_info = await prisma_client.get_data(user_id=user_id, table_name="user") - user_id_models: List = [] - if user_info is not None: - user_id_models = getattr(user_info, "models", []) + user_id_models = [] + + # User might not be already created on first generation of key + # But if it is, we want its models preferences + try: + if prisma_client is not None: + user_info = await prisma_client.get_data(user_id=user_id, table_name="user") + if user_info is not None: + user_id_models = getattr(user_info, "models", []) + except Exception as e: + pass response = await generate_key_helper_fn( **{