mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
feat(handle_jwt.py): add support for 'team_id_default
allows admin to set a default team id for spend-tracking + permissions
This commit is contained in:
parent
da2ea0ba04
commit
ed4315af38
4 changed files with 36 additions and 10 deletions
|
@ -36,6 +36,9 @@ litellm_settings:
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
enable_jwt_auth: True
|
enable_jwt_auth: True
|
||||||
|
litellm_jwtauth:
|
||||||
|
team_id_default: "1234"
|
||||||
|
upsert_users: True
|
||||||
disable_reset_budget: True
|
disable_reset_budget: True
|
||||||
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
||||||
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
|
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
|
||||||
|
|
|
@ -227,13 +227,19 @@ class LiteLLM_JWTAuth(LiteLLMBase):
|
||||||
"global_spend_tracking_routes",
|
"global_spend_tracking_routes",
|
||||||
"info_routes",
|
"info_routes",
|
||||||
]
|
]
|
||||||
team_jwt_scope: str = "litellm_team"
|
|
||||||
team_id_jwt_field: Optional[str] = None
|
team_id_jwt_field: Optional[str] = None
|
||||||
team_allowed_routes: List[
|
team_allowed_routes: List[
|
||||||
Literal["openai_routes", "info_routes", "management_routes"]
|
Literal["openai_routes", "info_routes", "management_routes"]
|
||||||
] = ["openai_routes", "info_routes"]
|
] = ["openai_routes", "info_routes"]
|
||||||
|
team_id_default: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="If no team_id given, default permissions/spend-tracking to this team.s",
|
||||||
|
)
|
||||||
org_id_jwt_field: Optional[str] = None
|
org_id_jwt_field: Optional[str] = None
|
||||||
user_id_jwt_field: Optional[str] = None
|
user_id_jwt_field: Optional[str] = None
|
||||||
|
user_id_upsert: bool = Field(
|
||||||
|
default=False, description="If user doesn't exist, upsert them into the db."
|
||||||
|
)
|
||||||
end_user_id_jwt_field: Optional[str] = None
|
end_user_id_jwt_field: Optional[str] = None
|
||||||
public_key_ttl: float = 600
|
public_key_ttl: float = 600
|
||||||
|
|
||||||
|
|
|
@ -55,11 +55,6 @@ class JWTHandler:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def is_team(self, scopes: list) -> bool:
|
|
||||||
if self.litellm_jwtauth.team_jwt_scope in scopes:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_end_user_id(
|
def get_end_user_id(
|
||||||
self, token: dict, default_value: Optional[str]
|
self, token: dict, default_value: Optional[str]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
|
@ -84,7 +79,12 @@ class JWTHandler:
|
||||||
|
|
||||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
|
if self.litellm_jwtauth.team_id_jwt_field is not None:
|
||||||
|
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
|
||||||
|
elif self.litellm_jwtauth.team_id_default is not None:
|
||||||
|
team_id = self.litellm_jwtauth.team_id_default
|
||||||
|
else:
|
||||||
|
team_id = None
|
||||||
except KeyError:
|
except KeyError:
|
||||||
team_id = default_value
|
team_id = default_value
|
||||||
return team_id
|
return team_id
|
||||||
|
|
|
@ -368,9 +368,14 @@ async def test_team_token_output(prisma_client, audience):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||||
@pytest.mark.parametrize("team_id_set", [True, False])
|
@pytest.mark.parametrize(
|
||||||
|
"team_id_set, default_team_id", [(True, None), (False, "1234")]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("user_id_upsert", [True, False])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_token_output(prisma_client, audience, team_id_set):
|
async def test_user_token_output(
|
||||||
|
prisma_client, audience, team_id_set, default_team_id, user_id_upsert
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
- If user required, check if it exists
|
- If user required, check if it exists
|
||||||
- fail initial request (when user doesn't exist)
|
- fail initial request (when user doesn't exist)
|
||||||
|
@ -433,6 +438,8 @@ async def test_user_token_output(prisma_client, audience, team_id_set):
|
||||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
||||||
|
|
||||||
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
|
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
|
||||||
|
jwt_handler.litellm_jwtauth.team_id_default = default_team_id
|
||||||
|
|
||||||
if team_id_set:
|
if team_id_set:
|
||||||
jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"
|
jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"
|
||||||
|
|
||||||
|
@ -515,6 +522,16 @@ async def test_user_token_output(prisma_client, audience, team_id_set):
|
||||||
),
|
),
|
||||||
user_api_key_dict=result,
|
user_api_key_dict=result,
|
||||||
)
|
)
|
||||||
|
if default_team_id is not None:
|
||||||
|
await new_team(
|
||||||
|
data=NewTeamRequest(
|
||||||
|
team_id=default_team_id,
|
||||||
|
tpm_limit=100,
|
||||||
|
rpm_limit=99,
|
||||||
|
models=["gpt-3.5-turbo", "gpt-4"],
|
||||||
|
),
|
||||||
|
user_api_key_dict=result,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"This should not fail - {str(e)}")
|
pytest.fail(f"This should not fail - {str(e)}")
|
||||||
|
|
||||||
|
@ -555,7 +572,7 @@ async def test_user_token_output(prisma_client, audience, team_id_set):
|
||||||
|
|
||||||
## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
|
## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
|
||||||
|
|
||||||
if team_id_set:
|
if team_id_set or default_team_id is not None:
|
||||||
assert team_result.team_tpm_limit == 100
|
assert team_result.team_tpm_limit == 100
|
||||||
assert team_result.team_rpm_limit == 99
|
assert team_result.team_rpm_limit == 99
|
||||||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue