build(schema.prisma): add new sso_user_id to LiteLLM_UserTable (#8167)

* build(schema.prisma): add new `sso_user_id` to LiteLLM_UserTable

easier way to store sso id for existing user

Allows existing user added to team, to login via SSO

* test(test_auth_checks.py): add unit testing for fuzzy user object get

* fix(handle_jwt.py): fix merge conflicts
This commit is contained in:
Krish Dholakia 2025-01-31 23:04:05 -08:00 committed by GitHub
parent 2147cad307
commit 8d0db8b379
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 144 additions and 3 deletions

View file

@ -1576,6 +1576,7 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
user_role: Optional[str] = None user_role: Optional[str] = None
organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None
teams: List[str] = [] teams: List[str] = []
sso_user_id: Optional[str] = None
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod

View file

@ -28,6 +28,7 @@ from litellm.proxy._types import (
CallInfo, CallInfo,
LiteLLM_EndUserTable, LiteLLM_EndUserTable,
LiteLLM_JWTAuth, LiteLLM_JWTAuth,
LiteLLM_OrganizationMembershipTable,
LiteLLM_OrganizationTable, LiteLLM_OrganizationTable,
LiteLLM_TeamTable, LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj, LiteLLM_TeamTableCachedObj,
@ -425,14 +426,55 @@ def get_role_based_models(
return None return None
async def _get_fuzzy_user_object(
prisma_client: PrismaClient,
sso_user_id: Optional[str] = None,
user_email: Optional[str] = None,
) -> Optional[LiteLLM_UserTable]:
"""
Checks if sso user is in db.
Called when user id match is not found in db.
- Check if sso_user_id is user_id in db
- Check if sso_user_id is sso_user_id in db
- Check if user_email is user_email in db
- If not, create new user with user_email and sso_user_id and user_id = sso_user_id
"""
response = None
if sso_user_id is not None:
response = await prisma_client.db.litellm_usertable.find_unique(
where={"sso_user_id": sso_user_id},
include={"organization_memberships": True},
)
if response is None and user_email is not None:
response = await prisma_client.db.litellm_usertable.find_first(
where={"user_email": user_email},
include={"organization_memberships": True},
)
if response is not None and sso_user_id is not None: # update sso_user_id
asyncio.create_task( # background task to update user with sso id
prisma_client.db.litellm_usertable.update(
where={"user_id": response.user_id},
data={"sso_user_id": sso_user_id},
)
)
return response
@log_db_metrics @log_db_metrics
async def get_user_object( async def get_user_object(
user_id: str, user_id: Optional[str],
prisma_client: Optional[PrismaClient], prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache, user_api_key_cache: DualCache,
user_id_upsert: bool, user_id_upsert: bool,
parent_otel_span: Optional[Span] = None, parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None, proxy_logging_obj: Optional[ProxyLogging] = None,
sso_user_id: Optional[str] = None,
user_email: Optional[str] = None,
) -> Optional[LiteLLM_UserTable]: ) -> Optional[LiteLLM_UserTable]:
""" """
- Check if user id in proxy User Table - Check if user id in proxy User Table
@ -465,6 +507,14 @@ async def get_user_object(
response = await prisma_client.db.litellm_usertable.find_unique( response = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}, include={"organization_memberships": True} where={"user_id": user_id}, include={"organization_memberships": True}
) )
if response is None:
response = await _get_fuzzy_user_object(
prisma_client=prisma_client,
sso_user_id=sso_user_id,
user_email=user_email,
)
else: else:
response = None response = None
@ -483,7 +533,7 @@ async def get_user_object(
): ):
# dump each organization membership to type LiteLLM_OrganizationMembershipTable # dump each organization membership to type LiteLLM_OrganizationMembershipTable
_dumped_memberships = [ _dumped_memberships = [
membership.model_dump() LiteLLM_OrganizationMembershipTable(**membership.model_dump())
for membership in response.organization_memberships for membership in response.organization_memberships
if membership is not None if membership is not None
] ]

View file

@ -621,6 +621,7 @@ class JWTAuthManager:
@staticmethod @staticmethod
async def get_objects( async def get_objects(
user_id: Optional[str], user_id: Optional[str],
user_email: Optional[str],
org_id: Optional[str], org_id: Optional[str],
end_user_id: Optional[str], end_user_id: Optional[str],
valid_user_email: Optional[bool], valid_user_email: Optional[bool],
@ -661,6 +662,8 @@ class JWTAuthManager:
), ),
parent_otel_span=parent_otel_span, parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj, proxy_logging_obj=proxy_logging_obj,
user_email=user_email,
sso_user_id=user_id,
) )
if user_id if user_id
else None else None
@ -704,7 +707,7 @@ class JWTAuthManager:
# Get basic user info # Get basic user info
scopes = jwt_handler.get_scopes(token=jwt_valid_token) scopes = jwt_handler.get_scopes(token=jwt_valid_token)
user_id, _, valid_user_email = await JWTAuthManager.get_user_info( user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
jwt_handler, jwt_valid_token jwt_handler, jwt_valid_token
) )
@ -748,6 +751,7 @@ class JWTAuthManager:
# Get other objects # Get other objects
user_object, org_object, end_user_object = await JWTAuthManager.get_objects( user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
user_id=user_id, user_id=user_id,
user_email=user_email,
org_id=org_id, org_id=org_id,
end_user_id=end_user_id, end_user_id=end_user_id,
valid_user_email=valid_user_email, valid_user_email=valid_user_email,

View file

@ -103,6 +103,7 @@ model LiteLLM_UserTable {
user_id String @id user_id String @id
user_alias String? user_alias String?
team_id String? team_id String?
sso_user_id String? @unique
organization_id String? organization_id String?
password String? password String?
teams String[] @default([]) teams String[] @default([])

View file

@ -548,3 +548,88 @@ async def test_can_user_call_model():
args["model"] = "gpt-3.5-turbo" args["model"] = "gpt-3.5-turbo"
await can_user_call_model(**args) await can_user_call_model(**args)
@pytest.mark.asyncio
async def test_get_fuzzy_user_object():
from litellm.proxy.auth.auth_checks import _get_fuzzy_user_object
from litellm.proxy.utils import PrismaClient
from unittest.mock import AsyncMock, MagicMock
# Setup mock Prisma client
mock_prisma = MagicMock()
mock_prisma.db = MagicMock()
mock_prisma.db.litellm_usertable = MagicMock()
# Mock user data
test_user = LiteLLM_UserTable(
user_id="test_123",
sso_user_id="sso_123",
user_email="test@example.com",
organization_memberships=[],
max_budget=None,
)
# Test 1: Find user by SSO ID
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user)
result = await _get_fuzzy_user_object(
prisma_client=mock_prisma, sso_user_id="sso_123", user_email="test@example.com"
)
assert result == test_user
mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
)
# Test 2: SSO ID not found, find by email
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user)
mock_prisma.db.litellm_usertable.update = AsyncMock()
result = await _get_fuzzy_user_object(
prisma_client=mock_prisma,
sso_user_id="new_sso_456",
user_email="test@example.com",
)
assert result == test_user
mock_prisma.db.litellm_usertable.find_first.assert_called_with(
where={"user_email": "test@example.com"},
include={"organization_memberships": True},
)
# Test 3: Verify background SSO update task when user found by email
await asyncio.sleep(0.1) # Allow time for background task
mock_prisma.db.litellm_usertable.update.assert_called_with(
where={"user_id": "test_123"}, data={"sso_user_id": "new_sso_456"}
)
# Test 4: User not found by either method
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=None)
result = await _get_fuzzy_user_object(
prisma_client=mock_prisma,
sso_user_id="unknown_sso",
user_email="unknown@example.com",
)
assert result is None
# Test 5: Only email provided (no SSO ID)
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user)
result = await _get_fuzzy_user_object(
prisma_client=mock_prisma, user_email="test@example.com"
)
assert result == test_user
mock_prisma.db.litellm_usertable.find_first.assert_called_with(
where={"user_email": "test@example.com"},
include={"organization_memberships": True},
)
# Test 6: Only SSO ID provided (no email)
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user)
result = await _get_fuzzy_user_object(
prisma_client=mock_prisma, sso_user_id="sso_123"
)
assert result == test_user
mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
)