mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
build(schema.prisma): add new sso_user_id
to LiteLLM_UserTable (#8167)
* build(schema.prisma): add new `sso_user_id` to LiteLLM_UserTable easier way to store sso id for existing user Allows existing user added to team, to login via SSO * test(test_auth_checks.py): add unit testing for fuzzy user object get * fix(handle_jwt.py): fix merge conflicts
This commit is contained in:
parent
2147cad307
commit
8d0db8b379
5 changed files with 144 additions and 3 deletions
|
@ -1576,6 +1576,7 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
|
|||
user_role: Optional[str] = None
|
||||
organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None
|
||||
teams: List[str] = []
|
||||
sso_user_id: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
|
|
@ -28,6 +28,7 @@ from litellm.proxy._types import (
|
|||
CallInfo,
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_JWTAuth,
|
||||
LiteLLM_OrganizationMembershipTable,
|
||||
LiteLLM_OrganizationTable,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_TeamTableCachedObj,
|
||||
|
@ -425,14 +426,55 @@ def get_role_based_models(
|
|||
return None
|
||||
|
||||
|
||||
async def _get_fuzzy_user_object(
|
||||
prisma_client: PrismaClient,
|
||||
sso_user_id: Optional[str] = None,
|
||||
user_email: Optional[str] = None,
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
"""
|
||||
Checks if sso user is in db.
|
||||
|
||||
Called when user id match is not found in db.
|
||||
|
||||
- Check if sso_user_id is user_id in db
|
||||
- Check if sso_user_id is sso_user_id in db
|
||||
- Check if user_email is user_email in db
|
||||
- If not, create new user with user_email and sso_user_id and user_id = sso_user_id
|
||||
"""
|
||||
response = None
|
||||
if sso_user_id is not None:
|
||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"sso_user_id": sso_user_id},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
|
||||
if response is None and user_email is not None:
|
||||
response = await prisma_client.db.litellm_usertable.find_first(
|
||||
where={"user_email": user_email},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
|
||||
if response is not None and sso_user_id is not None: # update sso_user_id
|
||||
asyncio.create_task( # background task to update user with sso id
|
||||
prisma_client.db.litellm_usertable.update(
|
||||
where={"user_id": response.user_id},
|
||||
data={"sso_user_id": sso_user_id},
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@log_db_metrics
|
||||
async def get_user_object(
|
||||
user_id: str,
|
||||
user_id: Optional[str],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
user_id_upsert: bool,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||
sso_user_id: Optional[str] = None,
|
||||
user_email: Optional[str] = None,
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
"""
|
||||
- Check if user id in proxy User Table
|
||||
|
@ -465,6 +507,14 @@ async def get_user_object(
|
|||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}, include={"organization_memberships": True}
|
||||
)
|
||||
|
||||
if response is None:
|
||||
response = await _get_fuzzy_user_object(
|
||||
prisma_client=prisma_client,
|
||||
sso_user_id=sso_user_id,
|
||||
user_email=user_email,
|
||||
)
|
||||
|
||||
else:
|
||||
response = None
|
||||
|
||||
|
@ -483,7 +533,7 @@ async def get_user_object(
|
|||
):
|
||||
# dump each organization membership to type LiteLLM_OrganizationMembershipTable
|
||||
_dumped_memberships = [
|
||||
membership.model_dump()
|
||||
LiteLLM_OrganizationMembershipTable(**membership.model_dump())
|
||||
for membership in response.organization_memberships
|
||||
if membership is not None
|
||||
]
|
||||
|
|
|
@ -621,6 +621,7 @@ class JWTAuthManager:
|
|||
@staticmethod
|
||||
async def get_objects(
|
||||
user_id: Optional[str],
|
||||
user_email: Optional[str],
|
||||
org_id: Optional[str],
|
||||
end_user_id: Optional[str],
|
||||
valid_user_email: Optional[bool],
|
||||
|
@ -661,6 +662,8 @@ class JWTAuthManager:
|
|||
),
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
user_email=user_email,
|
||||
sso_user_id=user_id,
|
||||
)
|
||||
if user_id
|
||||
else None
|
||||
|
@ -704,7 +707,7 @@ class JWTAuthManager:
|
|||
|
||||
# Get basic user info
|
||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||
user_id, _, valid_user_email = await JWTAuthManager.get_user_info(
|
||||
user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
|
||||
jwt_handler, jwt_valid_token
|
||||
)
|
||||
|
||||
|
@ -748,6 +751,7 @@ class JWTAuthManager:
|
|||
# Get other objects
|
||||
user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
org_id=org_id,
|
||||
end_user_id=end_user_id,
|
||||
valid_user_email=valid_user_email,
|
||||
|
|
|
@ -103,6 +103,7 @@ model LiteLLM_UserTable {
|
|||
user_id String @id
|
||||
user_alias String?
|
||||
team_id String?
|
||||
sso_user_id String? @unique
|
||||
organization_id String?
|
||||
password String?
|
||||
teams String[] @default([])
|
||||
|
|
|
@ -548,3 +548,88 @@ async def test_can_user_call_model():
|
|||
|
||||
args["model"] = "gpt-3.5-turbo"
|
||||
await can_user_call_model(**args)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_fuzzy_user_object():
|
||||
from litellm.proxy.auth.auth_checks import _get_fuzzy_user_object
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
# Setup mock Prisma client
|
||||
mock_prisma = MagicMock()
|
||||
mock_prisma.db = MagicMock()
|
||||
mock_prisma.db.litellm_usertable = MagicMock()
|
||||
|
||||
# Mock user data
|
||||
test_user = LiteLLM_UserTable(
|
||||
user_id="test_123",
|
||||
sso_user_id="sso_123",
|
||||
user_email="test@example.com",
|
||||
organization_memberships=[],
|
||||
max_budget=None,
|
||||
)
|
||||
|
||||
# Test 1: Find user by SSO ID
|
||||
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user)
|
||||
result = await _get_fuzzy_user_object(
|
||||
prisma_client=mock_prisma, sso_user_id="sso_123", user_email="test@example.com"
|
||||
)
|
||||
assert result == test_user
|
||||
mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
|
||||
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
|
||||
)
|
||||
|
||||
# Test 2: SSO ID not found, find by email
|
||||
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
|
||||
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user)
|
||||
mock_prisma.db.litellm_usertable.update = AsyncMock()
|
||||
|
||||
result = await _get_fuzzy_user_object(
|
||||
prisma_client=mock_prisma,
|
||||
sso_user_id="new_sso_456",
|
||||
user_email="test@example.com",
|
||||
)
|
||||
assert result == test_user
|
||||
mock_prisma.db.litellm_usertable.find_first.assert_called_with(
|
||||
where={"user_email": "test@example.com"},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
|
||||
# Test 3: Verify background SSO update task when user found by email
|
||||
await asyncio.sleep(0.1) # Allow time for background task
|
||||
mock_prisma.db.litellm_usertable.update.assert_called_with(
|
||||
where={"user_id": "test_123"}, data={"sso_user_id": "new_sso_456"}
|
||||
)
|
||||
|
||||
# Test 4: User not found by either method
|
||||
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
|
||||
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=None)
|
||||
|
||||
result = await _get_fuzzy_user_object(
|
||||
prisma_client=mock_prisma,
|
||||
sso_user_id="unknown_sso",
|
||||
user_email="unknown@example.com",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Test 5: Only email provided (no SSO ID)
|
||||
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user)
|
||||
result = await _get_fuzzy_user_object(
|
||||
prisma_client=mock_prisma, user_email="test@example.com"
|
||||
)
|
||||
assert result == test_user
|
||||
mock_prisma.db.litellm_usertable.find_first.assert_called_with(
|
||||
where={"user_email": "test@example.com"},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
|
||||
# Test 6: Only SSO ID provided (no email)
|
||||
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user)
|
||||
result = await _get_fuzzy_user_object(
|
||||
prisma_client=mock_prisma, sso_user_id="sso_123"
|
||||
)
|
||||
assert result == test_user
|
||||
mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
|
||||
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue