diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index cf01597db..6c4a06efa 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -883,6 +883,7 @@ async def generate_key_helper_fn( "user_role": user_role, "spend": spend, "models": models, + "metadata": metadata_json, "max_parallel_requests": max_parallel_requests, "tpm_limit": tpm_limit, "rpm_limit": rpm_limit, diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 7d21b2015..e93f2e46b 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -12,6 +12,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import RedirectResponse +from fastapi_sso.sso.base import OpenID import litellm from litellm._logging import verbose_proxy_logger @@ -498,6 +499,7 @@ async def auth_callback(request: Request): else: # user not in DB, insert User into LiteLLM DB user_role = await insert_sso_user( + result_openid=result, user_defined_values=user_defined_values, ) except Exception as e: @@ -575,10 +577,15 @@ async def auth_callback(request: Request): async def insert_sso_user( + result_openid: Optional[OpenID], user_defined_values: Optional[SSOUserDefinedValues] = None, ) -> str: """ Helper function to create a New User in LiteLLM DB after a successful SSO login + + Args: + result_openid (OpenID): User information in OpenID format if the login was successful. + user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read """ verbose_proxy_logger.debug( f"Inserting SSO user into DB. User values: {user_defined_values}" @@ -610,6 +617,9 @@ async def insert_sso_user( budget_duration=user_defined_values["budget_duration"], ) + if result_openid: + new_user_request.metadata = {"auth_provider": result_openid.provider} + await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth()) return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 9f53a7efb..ff2fc68c4 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -109,6 +109,7 @@ model LiteLLM_UserTable { spend Float @default(0.0) user_email String? models String[] + metadata Json @default("{}") max_parallel_requests Int? tpm_limit BigInt? rpm_limit BigInt? diff --git a/schema.prisma b/schema.prisma index 9f53a7efb..ff2fc68c4 100644 --- a/schema.prisma +++ b/schema.prisma @@ -109,6 +109,7 @@ model LiteLLM_UserTable { spend Float @default(0.0) user_email String? models String[] + metadata Json @default("{}") max_parallel_requests Int? tpm_limit BigInt? rpm_limit BigInt? diff --git a/tests/proxy_admin_ui_tests/test_sso_sign_in.py b/tests/proxy_admin_ui_tests/test_sso_sign_in.py index e0cb57955..7ecee7879 100644 --- a/tests/proxy_admin_ui_tests/test_sso_sign_in.py +++ b/tests/proxy_admin_ui_tests/test_sso_sign_in.py @@ -78,6 +78,7 @@ async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_cli mock_sso_result = MagicMock() mock_sso_result.email = "newuser@example.com" mock_sso_result.id = unique_user_id + mock_sso_result.provider = "google" mock_google_sso.return_value.verify_and_process = AsyncMock( return_value=mock_sso_result ) @@ -110,6 +111,7 @@ async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_cli assert user is not None assert user.user_email == "newuser@example.com" assert user.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY + assert user.metadata == {"auth_provider": "google"} finally: # Clean up: Delete the user from the database @@ -148,6 +150,7 @@ async def test_auth_callback_new_user_with_sso_default( mock_sso_result = MagicMock() mock_sso_result.email = "newuser@example.com" mock_sso_result.id = unique_user_id + mock_sso_result.provider = "google" mock_google_sso.return_value.verify_and_process = AsyncMock( return_value=mock_sso_result )