mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
build(schema.prisma): add new sso_user_id
to LiteLLM_UserTable (#8167)
* build(schema.prisma): add new `sso_user_id` to LiteLLM_UserTable easier way to store sso id for existing user Allows existing user added to team, to login via SSO * test(test_auth_checks.py): add unit testing for fuzzy user object get * fix(handle_jwt.py): fix merge conflicts
This commit is contained in:
parent
2147cad307
commit
8d0db8b379
5 changed files with 144 additions and 3 deletions
|
@ -1576,6 +1576,7 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
|
||||||
user_role: Optional[str] = None
|
user_role: Optional[str] = None
|
||||||
organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None
|
organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None
|
||||||
teams: List[str] = []
|
teams: List[str] = []
|
||||||
|
sso_user_id: Optional[str] = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -28,6 +28,7 @@ from litellm.proxy._types import (
|
||||||
CallInfo,
|
CallInfo,
|
||||||
LiteLLM_EndUserTable,
|
LiteLLM_EndUserTable,
|
||||||
LiteLLM_JWTAuth,
|
LiteLLM_JWTAuth,
|
||||||
|
LiteLLM_OrganizationMembershipTable,
|
||||||
LiteLLM_OrganizationTable,
|
LiteLLM_OrganizationTable,
|
||||||
LiteLLM_TeamTable,
|
LiteLLM_TeamTable,
|
||||||
LiteLLM_TeamTableCachedObj,
|
LiteLLM_TeamTableCachedObj,
|
||||||
|
@ -425,14 +426,55 @@ def get_role_based_models(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_fuzzy_user_object(
|
||||||
|
prisma_client: PrismaClient,
|
||||||
|
sso_user_id: Optional[str] = None,
|
||||||
|
user_email: Optional[str] = None,
|
||||||
|
) -> Optional[LiteLLM_UserTable]:
|
||||||
|
"""
|
||||||
|
Checks if sso user is in db.
|
||||||
|
|
||||||
|
Called when user id match is not found in db.
|
||||||
|
|
||||||
|
- Check if sso_user_id is user_id in db
|
||||||
|
- Check if sso_user_id is sso_user_id in db
|
||||||
|
- Check if user_email is user_email in db
|
||||||
|
- If not, create new user with user_email and sso_user_id and user_id = sso_user_id
|
||||||
|
"""
|
||||||
|
response = None
|
||||||
|
if sso_user_id is not None:
|
||||||
|
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||||
|
where={"sso_user_id": sso_user_id},
|
||||||
|
include={"organization_memberships": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response is None and user_email is not None:
|
||||||
|
response = await prisma_client.db.litellm_usertable.find_first(
|
||||||
|
where={"user_email": user_email},
|
||||||
|
include={"organization_memberships": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response is not None and sso_user_id is not None: # update sso_user_id
|
||||||
|
asyncio.create_task( # background task to update user with sso id
|
||||||
|
prisma_client.db.litellm_usertable.update(
|
||||||
|
where={"user_id": response.user_id},
|
||||||
|
data={"sso_user_id": sso_user_id},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@log_db_metrics
|
@log_db_metrics
|
||||||
async def get_user_object(
|
async def get_user_object(
|
||||||
user_id: str,
|
user_id: Optional[str],
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
user_api_key_cache: DualCache,
|
user_api_key_cache: DualCache,
|
||||||
user_id_upsert: bool,
|
user_id_upsert: bool,
|
||||||
parent_otel_span: Optional[Span] = None,
|
parent_otel_span: Optional[Span] = None,
|
||||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||||
|
sso_user_id: Optional[str] = None,
|
||||||
|
user_email: Optional[str] = None,
|
||||||
) -> Optional[LiteLLM_UserTable]:
|
) -> Optional[LiteLLM_UserTable]:
|
||||||
"""
|
"""
|
||||||
- Check if user id in proxy User Table
|
- Check if user id in proxy User Table
|
||||||
|
@ -465,6 +507,14 @@ async def get_user_object(
|
||||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||||
where={"user_id": user_id}, include={"organization_memberships": True}
|
where={"user_id": user_id}, include={"organization_memberships": True}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
response = await _get_fuzzy_user_object(
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
sso_user_id=sso_user_id,
|
||||||
|
user_email=user_email,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
response = None
|
response = None
|
||||||
|
|
||||||
|
@ -483,7 +533,7 @@ async def get_user_object(
|
||||||
):
|
):
|
||||||
# dump each organization membership to type LiteLLM_OrganizationMembershipTable
|
# dump each organization membership to type LiteLLM_OrganizationMembershipTable
|
||||||
_dumped_memberships = [
|
_dumped_memberships = [
|
||||||
membership.model_dump()
|
LiteLLM_OrganizationMembershipTable(**membership.model_dump())
|
||||||
for membership in response.organization_memberships
|
for membership in response.organization_memberships
|
||||||
if membership is not None
|
if membership is not None
|
||||||
]
|
]
|
||||||
|
|
|
@ -621,6 +621,7 @@ class JWTAuthManager:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_objects(
|
async def get_objects(
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
|
user_email: Optional[str],
|
||||||
org_id: Optional[str],
|
org_id: Optional[str],
|
||||||
end_user_id: Optional[str],
|
end_user_id: Optional[str],
|
||||||
valid_user_email: Optional[bool],
|
valid_user_email: Optional[bool],
|
||||||
|
@ -661,6 +662,8 @@ class JWTAuthManager:
|
||||||
),
|
),
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
user_email=user_email,
|
||||||
|
sso_user_id=user_id,
|
||||||
)
|
)
|
||||||
if user_id
|
if user_id
|
||||||
else None
|
else None
|
||||||
|
@ -704,7 +707,7 @@ class JWTAuthManager:
|
||||||
|
|
||||||
# Get basic user info
|
# Get basic user info
|
||||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||||
user_id, _, valid_user_email = await JWTAuthManager.get_user_info(
|
user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
|
||||||
jwt_handler, jwt_valid_token
|
jwt_handler, jwt_valid_token
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -748,6 +751,7 @@ class JWTAuthManager:
|
||||||
# Get other objects
|
# Get other objects
|
||||||
user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
|
user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
user_email=user_email,
|
||||||
org_id=org_id,
|
org_id=org_id,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
valid_user_email=valid_user_email,
|
valid_user_email=valid_user_email,
|
||||||
|
|
|
@ -103,6 +103,7 @@ model LiteLLM_UserTable {
|
||||||
user_id String @id
|
user_id String @id
|
||||||
user_alias String?
|
user_alias String?
|
||||||
team_id String?
|
team_id String?
|
||||||
|
sso_user_id String? @unique
|
||||||
organization_id String?
|
organization_id String?
|
||||||
password String?
|
password String?
|
||||||
teams String[] @default([])
|
teams String[] @default([])
|
||||||
|
|
|
@ -548,3 +548,88 @@ async def test_can_user_call_model():
|
||||||
|
|
||||||
args["model"] = "gpt-3.5-turbo"
|
args["model"] = "gpt-3.5-turbo"
|
||||||
await can_user_call_model(**args)
|
await can_user_call_model(**args)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_fuzzy_user_object():
|
||||||
|
from litellm.proxy.auth.auth_checks import _get_fuzzy_user_object
|
||||||
|
from litellm.proxy.utils import PrismaClient
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
# Setup mock Prisma client
|
||||||
|
mock_prisma = MagicMock()
|
||||||
|
mock_prisma.db = MagicMock()
|
||||||
|
mock_prisma.db.litellm_usertable = MagicMock()
|
||||||
|
|
||||||
|
# Mock user data
|
||||||
|
test_user = LiteLLM_UserTable(
|
||||||
|
user_id="test_123",
|
||||||
|
sso_user_id="sso_123",
|
||||||
|
user_email="test@example.com",
|
||||||
|
organization_memberships=[],
|
||||||
|
max_budget=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 1: Find user by SSO ID
|
||||||
|
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user)
|
||||||
|
result = await _get_fuzzy_user_object(
|
||||||
|
prisma_client=mock_prisma, sso_user_id="sso_123", user_email="test@example.com"
|
||||||
|
)
|
||||||
|
assert result == test_user
|
||||||
|
mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
|
||||||
|
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 2: SSO ID not found, find by email
|
||||||
|
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
|
||||||
|
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user)
|
||||||
|
mock_prisma.db.litellm_usertable.update = AsyncMock()
|
||||||
|
|
||||||
|
result = await _get_fuzzy_user_object(
|
||||||
|
prisma_client=mock_prisma,
|
||||||
|
sso_user_id="new_sso_456",
|
||||||
|
user_email="test@example.com",
|
||||||
|
)
|
||||||
|
assert result == test_user
|
||||||
|
mock_prisma.db.litellm_usertable.find_first.assert_called_with(
|
||||||
|
where={"user_email": "test@example.com"},
|
||||||
|
include={"organization_memberships": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 3: Verify background SSO update task when user found by email
|
||||||
|
await asyncio.sleep(0.1) # Allow time for background task
|
||||||
|
mock_prisma.db.litellm_usertable.update.assert_called_with(
|
||||||
|
where={"user_id": "test_123"}, data={"sso_user_id": "new_sso_456"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 4: User not found by either method
|
||||||
|
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
|
||||||
|
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
result = await _get_fuzzy_user_object(
|
||||||
|
prisma_client=mock_prisma,
|
||||||
|
sso_user_id="unknown_sso",
|
||||||
|
user_email="unknown@example.com",
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
# Test 5: Only email provided (no SSO ID)
|
||||||
|
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user)
|
||||||
|
result = await _get_fuzzy_user_object(
|
||||||
|
prisma_client=mock_prisma, user_email="test@example.com"
|
||||||
|
)
|
||||||
|
assert result == test_user
|
||||||
|
mock_prisma.db.litellm_usertable.find_first.assert_called_with(
|
||||||
|
where={"user_email": "test@example.com"},
|
||||||
|
include={"organization_memberships": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 6: Only SSO ID provided (no email)
|
||||||
|
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user)
|
||||||
|
result = await _get_fuzzy_user_object(
|
||||||
|
prisma_client=mock_prisma, sso_user_id="sso_123"
|
||||||
|
)
|
||||||
|
assert result == test_user
|
||||||
|
mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
|
||||||
|
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue