forked from phoenix/litellm-mirror
[Feat UI sso] store 'provider' in user metadata (#5856)
* store sso provider in user metadata * store user metadata * store user auth_provider in user metadata * add "metadata" for LiteLLM_UserTable * fix sso test
This commit is contained in:
parent
922c8ac758
commit
391b107909
5 changed files with 16 additions and 0 deletions
|
@ -883,6 +883,7 @@ async def generate_key_helper_fn(
|
||||||
"user_role": user_role,
|
"user_role": user_role,
|
||||||
"spend": spend,
|
"spend": spend,
|
||||||
"models": models,
|
"models": models,
|
||||||
|
"metadata": metadata_json,
|
||||||
"max_parallel_requests": max_parallel_requests,
|
"max_parallel_requests": max_parallel_requests,
|
||||||
"tpm_limit": tpm_limit,
|
"tpm_limit": tpm_limit,
|
||||||
"rpm_limit": rpm_limit,
|
"rpm_limit": rpm_limit,
|
||||||
|
|
|
@ -12,6 +12,7 @@ from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
from fastapi_sso.sso.base import OpenID
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
@ -498,6 +499,7 @@ async def auth_callback(request: Request):
|
||||||
else:
|
else:
|
||||||
# user not in DB, insert User into LiteLLM DB
|
# user not in DB, insert User into LiteLLM DB
|
||||||
user_role = await insert_sso_user(
|
user_role = await insert_sso_user(
|
||||||
|
result_openid=result,
|
||||||
user_defined_values=user_defined_values,
|
user_defined_values=user_defined_values,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -575,10 +577,15 @@ async def auth_callback(request: Request):
|
||||||
|
|
||||||
|
|
||||||
async def insert_sso_user(
|
async def insert_sso_user(
|
||||||
|
result_openid: Optional[OpenID],
|
||||||
user_defined_values: Optional[SSOUserDefinedValues] = None,
|
user_defined_values: Optional[SSOUserDefinedValues] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Helper function to create a New User in LiteLLM DB after a successful SSO login
|
Helper function to create a New User in LiteLLM DB after a successful SSO login
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result_openid (OpenID): User information in OpenID format if the login was successful.
|
||||||
|
user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read
|
||||||
"""
|
"""
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"Inserting SSO user into DB. User values: {user_defined_values}"
|
f"Inserting SSO user into DB. User values: {user_defined_values}"
|
||||||
|
@ -610,6 +617,9 @@ async def insert_sso_user(
|
||||||
budget_duration=user_defined_values["budget_duration"],
|
budget_duration=user_defined_values["budget_duration"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if result_openid:
|
||||||
|
new_user_request.metadata = {"auth_provider": result_openid.provider}
|
||||||
|
|
||||||
await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
|
await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
|
||||||
|
|
||||||
return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||||
|
|
|
@ -109,6 +109,7 @@ model LiteLLM_UserTable {
|
||||||
spend Float @default(0.0)
|
spend Float @default(0.0)
|
||||||
user_email String?
|
user_email String?
|
||||||
models String[]
|
models String[]
|
||||||
|
metadata Json @default("{}")
|
||||||
max_parallel_requests Int?
|
max_parallel_requests Int?
|
||||||
tpm_limit BigInt?
|
tpm_limit BigInt?
|
||||||
rpm_limit BigInt?
|
rpm_limit BigInt?
|
||||||
|
|
|
@ -109,6 +109,7 @@ model LiteLLM_UserTable {
|
||||||
spend Float @default(0.0)
|
spend Float @default(0.0)
|
||||||
user_email String?
|
user_email String?
|
||||||
models String[]
|
models String[]
|
||||||
|
metadata Json @default("{}")
|
||||||
max_parallel_requests Int?
|
max_parallel_requests Int?
|
||||||
tpm_limit BigInt?
|
tpm_limit BigInt?
|
||||||
rpm_limit BigInt?
|
rpm_limit BigInt?
|
||||||
|
|
|
@ -78,6 +78,7 @@ async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_cli
|
||||||
mock_sso_result = MagicMock()
|
mock_sso_result = MagicMock()
|
||||||
mock_sso_result.email = "newuser@example.com"
|
mock_sso_result.email = "newuser@example.com"
|
||||||
mock_sso_result.id = unique_user_id
|
mock_sso_result.id = unique_user_id
|
||||||
|
mock_sso_result.provider = "google"
|
||||||
mock_google_sso.return_value.verify_and_process = AsyncMock(
|
mock_google_sso.return_value.verify_and_process = AsyncMock(
|
||||||
return_value=mock_sso_result
|
return_value=mock_sso_result
|
||||||
)
|
)
|
||||||
|
@ -110,6 +111,7 @@ async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_cli
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.user_email == "newuser@example.com"
|
assert user.user_email == "newuser@example.com"
|
||||||
assert user.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
assert user.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||||
|
assert user.metadata == {"auth_provider": "google"}
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up: Delete the user from the database
|
# Clean up: Delete the user from the database
|
||||||
|
@ -148,6 +150,7 @@ async def test_auth_callback_new_user_with_sso_default(
|
||||||
mock_sso_result = MagicMock()
|
mock_sso_result = MagicMock()
|
||||||
mock_sso_result.email = "newuser@example.com"
|
mock_sso_result.email = "newuser@example.com"
|
||||||
mock_sso_result.id = unique_user_id
|
mock_sso_result.id = unique_user_id
|
||||||
|
mock_sso_result.provider = "google"
|
||||||
mock_google_sso.return_value.verify_and_process = AsyncMock(
|
mock_google_sso.return_value.verify_and_process = AsyncMock(
|
||||||
return_value=mock_sso_result
|
return_value=mock_sso_result
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue