[Feat UI sso] store 'provider' in user metadata (#5856)

* store sso provider in user metadata

* store user metadata

* store user auth_provider in user metadata

* add "metadata" for LiteLLM_UserTable

* fix sso test
This commit is contained in:
Ishaan Jaff 2024-09-23 17:49:36 -07:00 committed by GitHub
parent 922c8ac758
commit 391b107909
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 16 additions and 0 deletions

View file

@ -883,6 +883,7 @@ async def generate_key_helper_fn(
"user_role": user_role, "user_role": user_role,
"spend": spend, "spend": spend,
"models": models, "models": models,
"metadata": metadata_json,
"max_parallel_requests": max_parallel_requests, "max_parallel_requests": max_parallel_requests,
"tpm_limit": tpm_limit, "tpm_limit": tpm_limit,
"rpm_limit": rpm_limit, "rpm_limit": rpm_limit,

View file

@ -12,6 +12,7 @@ from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi_sso.sso.base import OpenID
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
@ -498,6 +499,7 @@ async def auth_callback(request: Request):
else: else:
# user not in DB, insert User into LiteLLM DB # user not in DB, insert User into LiteLLM DB
user_role = await insert_sso_user( user_role = await insert_sso_user(
result_openid=result,
user_defined_values=user_defined_values, user_defined_values=user_defined_values,
) )
except Exception as e: except Exception as e:
@ -575,10 +577,15 @@ async def auth_callback(request: Request):
async def insert_sso_user( async def insert_sso_user(
result_openid: Optional[OpenID],
user_defined_values: Optional[SSOUserDefinedValues] = None, user_defined_values: Optional[SSOUserDefinedValues] = None,
) -> str: ) -> str:
""" """
Helper function to create a New User in LiteLLM DB after a successful SSO login Helper function to create a New User in LiteLLM DB after a successful SSO login
Args:
result_openid (OpenID): User information in OpenID format if the login was successful.
user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read
""" """
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"Inserting SSO user into DB. User values: {user_defined_values}" f"Inserting SSO user into DB. User values: {user_defined_values}"
@ -610,6 +617,9 @@ async def insert_sso_user(
budget_duration=user_defined_values["budget_duration"], budget_duration=user_defined_values["budget_duration"],
) )
if result_openid:
new_user_request.metadata = {"auth_provider": result_openid.provider}
await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth()) await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY

View file

@ -109,6 +109,7 @@ model LiteLLM_UserTable {
spend Float @default(0.0) spend Float @default(0.0)
user_email String? user_email String?
models String[] models String[]
metadata Json @default("{}")
max_parallel_requests Int? max_parallel_requests Int?
tpm_limit BigInt? tpm_limit BigInt?
rpm_limit BigInt? rpm_limit BigInt?

View file

@ -109,6 +109,7 @@ model LiteLLM_UserTable {
spend Float @default(0.0) spend Float @default(0.0)
user_email String? user_email String?
models String[] models String[]
metadata Json @default("{}")
max_parallel_requests Int? max_parallel_requests Int?
tpm_limit BigInt? tpm_limit BigInt?
rpm_limit BigInt? rpm_limit BigInt?

View file

@ -78,6 +78,7 @@ async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_cli
mock_sso_result = MagicMock() mock_sso_result = MagicMock()
mock_sso_result.email = "newuser@example.com" mock_sso_result.email = "newuser@example.com"
mock_sso_result.id = unique_user_id mock_sso_result.id = unique_user_id
mock_sso_result.provider = "google"
mock_google_sso.return_value.verify_and_process = AsyncMock( mock_google_sso.return_value.verify_and_process = AsyncMock(
return_value=mock_sso_result return_value=mock_sso_result
) )
@ -110,6 +111,7 @@ async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_cli
assert user is not None assert user is not None
assert user.user_email == "newuser@example.com" assert user.user_email == "newuser@example.com"
assert user.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY assert user.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
assert user.metadata == {"auth_provider": "google"}
finally: finally:
# Clean up: Delete the user from the database # Clean up: Delete the user from the database
@ -148,6 +150,7 @@ async def test_auth_callback_new_user_with_sso_default(
mock_sso_result = MagicMock() mock_sso_result = MagicMock()
mock_sso_result.email = "newuser@example.com" mock_sso_result.email = "newuser@example.com"
mock_sso_result.id = unique_user_id mock_sso_result.id = unique_user_id
mock_sso_result.provider = "google"
mock_google_sso.return_value.verify_and_process = AsyncMock( mock_google_sso.return_value.verify_and_process = AsyncMock(
return_value=mock_sso_result return_value=mock_sso_result
) )