fix: fix typing

This commit is contained in:
Krrish Dholakia 2024-11-23 02:36:09 +05:30
parent 31943bf2ad
commit f8a46b5950
2 changed files with 43 additions and 11 deletions

View file

@ -2,6 +2,7 @@ import enum
import json import json
import os import os
import sys import sys
import traceback
import uuid import uuid
from dataclasses import fields from dataclasses import fields
from datetime import datetime from datetime import datetime
@ -890,11 +891,7 @@ class DeleteCustomerRequest(LiteLLMBase):
user_ids: List[str] user_ids: List[str]
class Member(LiteLLMBase): class MemberBase(LiteLLMBase):
role: Literal[
"admin",
"user",
]
user_id: Optional[str] = None user_id: Optional[str] = None
user_email: Optional[str] = None user_email: Optional[str] = None
@ -908,6 +905,21 @@ class Member(LiteLLMBase):
return values return values
class Member(MemberBase):
role: Literal[
"admin",
"user",
]
class OrgMember(MemberBase):
role: Literal[
LitellmUserRoles.ORG_ADMIN,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
]
class TeamBase(LiteLLMBase): class TeamBase(LiteLLMBase):
team_alias: Optional[str] = None team_alias: Optional[str] = None
team_id: Optional[str] = None team_id: Optional[str] = None
@ -1970,6 +1982,26 @@ class MemberAddRequest(LiteLLMBase):
# Replace member_data with the single Member object # Replace member_data with the single Member object
data["member"] = member data["member"] = member
# Call the superclass __init__ method to initialize the object # Call the superclass __init__ method to initialize the object
traceback.print_stack()
super().__init__(**data)
class OrgMemberAddRequest(LiteLLMBase):
member: Union[List[OrgMember], OrgMember]
def __init__(self, **data):
member_data = data.get("member")
if isinstance(member_data, list):
# If member is a list of dictionaries, convert each dictionary to a Member object
members = [OrgMember(**item) for item in member_data]
# Replace member_data with the list of Member objects
data["member"] = members
elif isinstance(member_data, dict):
# If member is a dictionary, convert it to a single Member object
member = OrgMember(**member_data)
# Replace member_data with the single Member object
data["member"] = member
# Call the superclass __init__ method to initialize the object
super().__init__(**data) super().__init__(**data)
@ -2021,7 +2053,7 @@ class TeamMemberUpdateResponse(MemberUpdateResponse):
# Organization Member Requests # Organization Member Requests
class OrganizationMemberAddRequest(MemberAddRequest): class OrganizationMemberAddRequest(OrgMemberAddRequest):
organization_id: str organization_id: str
max_budget_in_organization: Optional[float] = ( max_budget_in_organization: Optional[float] = (
None # Users max budget within the organization None # Users max budget within the organization

View file

@ -160,7 +160,7 @@ async def test_create_new_user_in_organization(prisma_client, user_role):
response = await organization_member_add( response = await organization_member_add(
data=OrganizationMemberAddRequest( data=OrganizationMemberAddRequest(
organization_id=org_id, organization_id=org_id,
member=Member(role=user_role, user_id=created_user_id), member=OrgMember(role=user_role, user_id=created_user_id),
), ),
http_request=None, http_request=None,
) )
@ -220,7 +220,7 @@ async def test_org_admin_create_team_permissions(prisma_client):
response = await organization_member_add( response = await organization_member_add(
data=OrganizationMemberAddRequest( data=OrganizationMemberAddRequest(
organization_id=org_id, organization_id=org_id,
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
), ),
http_request=None, http_request=None,
) )
@ -292,7 +292,7 @@ async def test_org_admin_create_user_permissions(prisma_client):
response = await organization_member_add( response = await organization_member_add(
data=OrganizationMemberAddRequest( data=OrganizationMemberAddRequest(
organization_id=org_id, organization_id=org_id,
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
), ),
http_request=None, http_request=None,
) )
@ -323,7 +323,7 @@ async def test_org_admin_create_user_permissions(prisma_client):
response = await organization_member_add( response = await organization_member_add(
data=OrganizationMemberAddRequest( data=OrganizationMemberAddRequest(
organization_id=org_id, organization_id=org_id,
member=Member( member=OrgMember(
role=LitellmUserRoles.INTERNAL_USER, user_id=new_internal_user_for_org role=LitellmUserRoles.INTERNAL_USER, user_id=new_internal_user_for_org
), ),
), ),
@ -375,7 +375,7 @@ async def test_org_admin_create_user_team_wrong_org_permissions(prisma_client):
response = await organization_member_add( response = await organization_member_add(
data=OrganizationMemberAddRequest( data=OrganizationMemberAddRequest(
organization_id=org1_id, organization_id=org1_id,
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
), ),
http_request=None, http_request=None,
) )