forked from phoenix/litellm-mirror
Merge pull request #3666 from BerriAI/litellm_jwt_fix
feat(proxy_server.py): JWT-Auth improvements
This commit is contained in:
commit
57d425aed7
6 changed files with 192 additions and 86 deletions
|
@ -36,6 +36,10 @@ litellm_settings:
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
enable_jwt_auth: True
|
enable_jwt_auth: True
|
||||||
|
litellm_jwtauth:
|
||||||
|
team_id_default: "1234"
|
||||||
|
user_id_jwt_field:
|
||||||
|
user_id_upsert: True
|
||||||
disable_reset_budget: True
|
disable_reset_budget: True
|
||||||
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
||||||
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
|
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
|
||||||
|
|
|
@ -202,13 +202,19 @@ class LiteLLM_JWTAuth(LiteLLMBase):
|
||||||
"global_spend_tracking_routes",
|
"global_spend_tracking_routes",
|
||||||
"info_routes",
|
"info_routes",
|
||||||
]
|
]
|
||||||
team_jwt_scope: str = "litellm_team"
|
team_id_jwt_field: Optional[str] = None
|
||||||
team_id_jwt_field: str = "client_id"
|
|
||||||
team_allowed_routes: List[
|
team_allowed_routes: List[
|
||||||
Literal["openai_routes", "info_routes", "management_routes"]
|
Literal["openai_routes", "info_routes", "management_routes"]
|
||||||
] = ["openai_routes", "info_routes"]
|
] = ["openai_routes", "info_routes"]
|
||||||
|
team_id_default: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="If no team_id given, default permissions/spend-tracking to this team.s",
|
||||||
|
)
|
||||||
org_id_jwt_field: Optional[str] = None
|
org_id_jwt_field: Optional[str] = None
|
||||||
user_id_jwt_field: Optional[str] = None
|
user_id_jwt_field: Optional[str] = None
|
||||||
|
user_id_upsert: bool = Field(
|
||||||
|
default=False, description="If user doesn't exist, upsert them into the db."
|
||||||
|
)
|
||||||
end_user_id_jwt_field: Optional[str] = None
|
end_user_id_jwt_field: Optional[str] = None
|
||||||
public_key_ttl: float = 600
|
public_key_ttl: float = 600
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
|
||||||
|
|
||||||
def common_checks(
|
def common_checks(
|
||||||
request_body: dict,
|
request_body: dict,
|
||||||
team_object: LiteLLM_TeamTable,
|
team_object: Optional[LiteLLM_TeamTable],
|
||||||
user_object: Optional[LiteLLM_UserTable],
|
user_object: Optional[LiteLLM_UserTable],
|
||||||
end_user_object: Optional[LiteLLM_EndUserTable],
|
end_user_object: Optional[LiteLLM_EndUserTable],
|
||||||
global_proxy_spend: Optional[float],
|
global_proxy_spend: Optional[float],
|
||||||
|
@ -45,13 +45,14 @@ def common_checks(
|
||||||
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
||||||
"""
|
"""
|
||||||
_model = request_body.get("model", None)
|
_model = request_body.get("model", None)
|
||||||
if team_object.blocked == True:
|
if team_object is not None and team_object.blocked == True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
|
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
|
||||||
)
|
)
|
||||||
# 2. If user can call model
|
# 2. If user can call model
|
||||||
if (
|
if (
|
||||||
_model is not None
|
_model is not None
|
||||||
|
and team_object is not None
|
||||||
and len(team_object.models) > 0
|
and len(team_object.models) > 0
|
||||||
and _model not in team_object.models
|
and _model not in team_object.models
|
||||||
):
|
):
|
||||||
|
@ -65,7 +66,8 @@ def common_checks(
|
||||||
)
|
)
|
||||||
# 3. If team is in budget
|
# 3. If team is in budget
|
||||||
if (
|
if (
|
||||||
team_object.max_budget is not None
|
team_object is not None
|
||||||
|
and team_object.max_budget is not None
|
||||||
and team_object.spend is not None
|
and team_object.spend is not None
|
||||||
and team_object.spend > team_object.max_budget
|
and team_object.spend > team_object.max_budget
|
||||||
):
|
):
|
||||||
|
@ -239,6 +241,7 @@ async def get_user_object(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
user_api_key_cache: DualCache,
|
user_api_key_cache: DualCache,
|
||||||
|
user_id_upsert: bool,
|
||||||
) -> Optional[LiteLLM_UserTable]:
|
) -> Optional[LiteLLM_UserTable]:
|
||||||
"""
|
"""
|
||||||
- Check if user id in proxy User Table
|
- Check if user id in proxy User Table
|
||||||
|
@ -252,7 +255,7 @@ async def get_user_object(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# check if in cache
|
# check if in cache
|
||||||
cached_user_obj = user_api_key_cache.async_get_cache(key=user_id)
|
cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||||
if cached_user_obj is not None:
|
if cached_user_obj is not None:
|
||||||
if isinstance(cached_user_obj, dict):
|
if isinstance(cached_user_obj, dict):
|
||||||
return LiteLLM_UserTable(**cached_user_obj)
|
return LiteLLM_UserTable(**cached_user_obj)
|
||||||
|
@ -260,16 +263,27 @@ async def get_user_object(
|
||||||
return cached_user_obj
|
return cached_user_obj
|
||||||
# else, check db
|
# else, check db
|
||||||
try:
|
try:
|
||||||
|
|
||||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||||
where={"user_id": user_id}
|
where={"user_id": user_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
if response is None:
|
if response is None:
|
||||||
raise Exception
|
if user_id_upsert:
|
||||||
|
response = await prisma_client.db.litellm_usertable.create(
|
||||||
|
data={"user_id": user_id}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
return LiteLLM_UserTable(**response.dict())
|
_response = LiteLLM_UserTable(**dict(response))
|
||||||
except Exception as e: # if end-user not in db
|
|
||||||
raise Exception(
|
# save the user object to cache
|
||||||
|
await user_api_key_cache.async_set_cache(key=user_id, value=_response)
|
||||||
|
|
||||||
|
return _response
|
||||||
|
except Exception as e: # if user not in db
|
||||||
|
raise ValueError(
|
||||||
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
|
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -290,7 +304,7 @@ async def get_team_object(
|
||||||
)
|
)
|
||||||
|
|
||||||
# check if in cache
|
# check if in cache
|
||||||
cached_team_obj = user_api_key_cache.async_get_cache(key=team_id)
|
cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
|
||||||
if cached_team_obj is not None:
|
if cached_team_obj is not None:
|
||||||
if isinstance(cached_team_obj, dict):
|
if isinstance(cached_team_obj, dict):
|
||||||
return LiteLLM_TeamTable(**cached_team_obj)
|
return LiteLLM_TeamTable(**cached_team_obj)
|
||||||
|
@ -305,7 +319,11 @@ async def get_team_object(
|
||||||
if response is None:
|
if response is None:
|
||||||
raise Exception
|
raise Exception
|
||||||
|
|
||||||
return LiteLLM_TeamTable(**response.dict())
|
_response = LiteLLM_TeamTable(**response.dict())
|
||||||
|
# save the team object to cache
|
||||||
|
await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
|
||||||
|
|
||||||
|
return _response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
||||||
|
|
|
@ -55,12 +55,9 @@ class JWTHandler:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def is_team(self, scopes: list) -> bool:
|
def get_end_user_id(
|
||||||
if self.litellm_jwtauth.team_jwt_scope in scopes:
|
self, token: dict, default_value: Optional[str]
|
||||||
return True
|
) -> Optional[str]:
|
||||||
return False
|
|
||||||
|
|
||||||
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
|
|
||||||
try:
|
try:
|
||||||
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
|
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
|
||||||
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
|
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
|
||||||
|
@ -70,13 +67,36 @@ class JWTHandler:
|
||||||
user_id = default_value
|
user_id = default_value
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
def is_required_team_id(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
- True: if 'team_id_jwt_field' is set
|
||||||
|
- False: if not
|
||||||
|
"""
|
||||||
|
if self.litellm_jwtauth.team_id_jwt_field is None:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
|
if self.litellm_jwtauth.team_id_jwt_field is not None:
|
||||||
|
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
|
||||||
|
elif self.litellm_jwtauth.team_id_default is not None:
|
||||||
|
team_id = self.litellm_jwtauth.team_id_default
|
||||||
|
else:
|
||||||
|
team_id = None
|
||||||
except KeyError:
|
except KeyError:
|
||||||
team_id = default_value
|
team_id = default_value
|
||||||
return team_id
|
return team_id
|
||||||
|
|
||||||
|
def is_upsert_user_id(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
- True: if 'user_id_upsert' is set
|
||||||
|
- False: if not
|
||||||
|
"""
|
||||||
|
return self.litellm_jwtauth.user_id_upsert
|
||||||
|
|
||||||
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
if self.litellm_jwtauth.user_id_jwt_field is not None:
|
if self.litellm_jwtauth.user_id_jwt_field is not None:
|
||||||
|
@ -165,7 +185,7 @@ class JWTHandler:
|
||||||
decode_options = None
|
decode_options = None
|
||||||
if audience is None:
|
if audience is None:
|
||||||
decode_options = {"verify_aud": False}
|
decode_options = {"verify_aud": False}
|
||||||
|
|
||||||
from jwt.algorithms import RSAAlgorithm
|
from jwt.algorithms import RSAAlgorithm
|
||||||
|
|
||||||
header = jwt.get_unverified_header(token)
|
header = jwt.get_unverified_header(token)
|
||||||
|
@ -207,12 +227,14 @@ class JWTHandler:
|
||||||
raise Exception(f"Validation fails: {str(e)}")
|
raise Exception(f"Validation fails: {str(e)}")
|
||||||
elif public_key is not None and isinstance(public_key, str):
|
elif public_key is not None and isinstance(public_key, str):
|
||||||
try:
|
try:
|
||||||
cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend())
|
cert = x509.load_pem_x509_certificate(
|
||||||
|
public_key.encode(), default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
# Extract public key
|
# Extract public key
|
||||||
key = cert.public_key().public_bytes(
|
key = cert.public_key().public_bytes(
|
||||||
serialization.Encoding.PEM,
|
serialization.Encoding.PEM,
|
||||||
serialization.PublicFormat.SubjectPublicKeyInfo
|
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
# decode the token using the public key
|
# decode the token using the public key
|
||||||
|
@ -221,7 +243,7 @@ class JWTHandler:
|
||||||
key,
|
key,
|
||||||
algorithms=algorithms,
|
algorithms=algorithms,
|
||||||
audience=audience,
|
audience=audience,
|
||||||
options=decode_options
|
options=decode_options,
|
||||||
)
|
)
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
|
@ -441,29 +441,32 @@ async def user_api_key_auth(
|
||||||
# get team id
|
# get team id
|
||||||
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
|
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
|
||||||
|
|
||||||
if team_id is None:
|
if team_id is None and jwt_handler.is_required_team_id() == True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||||
)
|
)
|
||||||
# check allowed team routes
|
|
||||||
is_allowed = allowed_routes_check(
|
|
||||||
user_role="team",
|
|
||||||
user_route=route,
|
|
||||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
|
||||||
)
|
|
||||||
if is_allowed == False:
|
|
||||||
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
|
|
||||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
|
||||||
raise Exception(
|
|
||||||
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# check if team in db
|
team_object: Optional[LiteLLM_TeamTable] = None
|
||||||
team_object = await get_team_object(
|
if team_id is not None:
|
||||||
team_id=team_id,
|
# check allowed team routes
|
||||||
prisma_client=prisma_client,
|
is_allowed = allowed_routes_check(
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_role="team",
|
||||||
)
|
user_route=route,
|
||||||
|
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||||
|
)
|
||||||
|
if is_allowed == False:
|
||||||
|
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
|
||||||
|
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||||
|
raise Exception(
|
||||||
|
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if team in db
|
||||||
|
team_object = await get_team_object(
|
||||||
|
team_id=team_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
)
|
||||||
|
|
||||||
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
||||||
org_id = jwt_handler.get_org_id(token=valid_token, default_value=None)
|
org_id = jwt_handler.get_org_id(token=valid_token, default_value=None)
|
||||||
|
@ -482,11 +485,9 @@ async def user_api_key_auth(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
user_id_upsert=jwt_handler.is_upsert_user_id(),
|
||||||
)
|
)
|
||||||
# save the user object to cache
|
|
||||||
await user_api_key_cache.async_set_cache(
|
|
||||||
key=user_id, value=user_object
|
|
||||||
)
|
|
||||||
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
||||||
end_user_object = None
|
end_user_object = None
|
||||||
end_user_id = jwt_handler.get_end_user_id(
|
end_user_id = jwt_handler.get_end_user_id(
|
||||||
|
@ -548,18 +549,18 @@ async def user_api_key_auth(
|
||||||
global_proxy_spend=global_proxy_spend,
|
global_proxy_spend=global_proxy_spend,
|
||||||
route=route,
|
route=route,
|
||||||
)
|
)
|
||||||
# save team object in cache
|
|
||||||
await user_api_key_cache.async_set_cache(
|
|
||||||
key=team_object.team_id, value=team_object
|
|
||||||
)
|
|
||||||
|
|
||||||
# return UserAPIKeyAuth object
|
# return UserAPIKeyAuth object
|
||||||
return UserAPIKeyAuth(
|
return UserAPIKeyAuth(
|
||||||
api_key=None,
|
api_key=None,
|
||||||
team_id=team_object.team_id,
|
team_id=team_object.team_id if team_object is not None else None,
|
||||||
team_tpm_limit=team_object.tpm_limit,
|
team_tpm_limit=(
|
||||||
team_rpm_limit=team_object.rpm_limit,
|
team_object.tpm_limit if team_object is not None else None
|
||||||
team_models=team_object.models,
|
),
|
||||||
|
team_rpm_limit=(
|
||||||
|
team_object.rpm_limit if team_object is not None else None
|
||||||
|
),
|
||||||
|
team_models=team_object.models if team_object is not None else [],
|
||||||
user_role="app_owner",
|
user_role="app_owner",
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
org_id=org_id,
|
org_id=org_id,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#### What this tests ####
|
#### What this tests ####
|
||||||
# Unit tests for JWT-Auth
|
# Unit tests for JWT-Auth
|
||||||
|
|
||||||
import sys, os, asyncio, time, random
|
import sys, os, asyncio, time, random, uuid
|
||||||
import traceback
|
import traceback
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ public_key = {
|
||||||
"alg": "RS256",
|
"alg": "RS256",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_load_config_with_custom_role_names():
|
def test_load_config_with_custom_role_names():
|
||||||
config = {
|
config = {
|
||||||
"general_settings": {
|
"general_settings": {
|
||||||
|
@ -77,7 +78,8 @@ async def test_token_single_public_key():
|
||||||
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
|
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
|
||||||
|
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_invalid_token(audience):
|
async def test_valid_invalid_token(audience):
|
||||||
"""
|
"""
|
||||||
|
@ -90,7 +92,7 @@ async def test_valid_invalid_token(audience):
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
|
||||||
os.environ.pop('JWT_AUDIENCE', None)
|
os.environ.pop("JWT_AUDIENCE", None)
|
||||||
if audience:
|
if audience:
|
||||||
os.environ["JWT_AUDIENCE"] = audience
|
os.environ["JWT_AUDIENCE"] = audience
|
||||||
|
|
||||||
|
@ -138,7 +140,7 @@ async def test_valid_invalid_token(audience):
|
||||||
"sub": "user123",
|
"sub": "user123",
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm-proxy-admin",
|
"scope": "litellm-proxy-admin",
|
||||||
"aud": audience
|
"aud": audience,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate the JWT token
|
# Generate the JWT token
|
||||||
|
@ -166,7 +168,7 @@ async def test_valid_invalid_token(audience):
|
||||||
"sub": "user123",
|
"sub": "user123",
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm-NO-SCOPE",
|
"scope": "litellm-NO-SCOPE",
|
||||||
"aud": audience
|
"aud": audience,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate the JWT token
|
# Generate the JWT token
|
||||||
|
@ -183,6 +185,7 @@ async def test_valid_invalid_token(audience):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def prisma_client():
|
def prisma_client():
|
||||||
import litellm
|
import litellm
|
||||||
|
@ -205,7 +208,7 @@ def prisma_client():
|
||||||
return prisma_client
|
return prisma_client
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_team_token_output(prisma_client, audience):
|
async def test_team_token_output(prisma_client, audience):
|
||||||
import jwt, json
|
import jwt, json
|
||||||
|
@ -222,7 +225,7 @@ async def test_team_token_output(prisma_client, audience):
|
||||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
os.environ.pop('JWT_AUDIENCE', None)
|
os.environ.pop("JWT_AUDIENCE", None)
|
||||||
if audience:
|
if audience:
|
||||||
os.environ["JWT_AUDIENCE"] = audience
|
os.environ["JWT_AUDIENCE"] = audience
|
||||||
|
|
||||||
|
@ -261,7 +264,7 @@ async def test_team_token_output(prisma_client, audience):
|
||||||
|
|
||||||
jwt_handler.user_api_key_cache = cache
|
jwt_handler.user_api_key_cache = cache
|
||||||
|
|
||||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")
|
||||||
|
|
||||||
# VALID TOKEN
|
# VALID TOKEN
|
||||||
## GENERATE A TOKEN
|
## GENERATE A TOKEN
|
||||||
|
@ -274,7 +277,7 @@ async def test_team_token_output(prisma_client, audience):
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm_team",
|
"scope": "litellm_team",
|
||||||
"client_id": team_id,
|
"client_id": team_id,
|
||||||
"aud": audience
|
"aud": audience,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate the JWT token
|
# Generate the JWT token
|
||||||
|
@ -289,7 +292,7 @@ async def test_team_token_output(prisma_client, audience):
|
||||||
"sub": "user123",
|
"sub": "user123",
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm_proxy_admin",
|
"scope": "litellm_proxy_admin",
|
||||||
"aud": audience
|
"aud": audience,
|
||||||
}
|
}
|
||||||
|
|
||||||
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
@ -315,7 +318,13 @@ async def test_team_token_output(prisma_client, audience):
|
||||||
|
|
||||||
## 1. INITIAL TEAM CALL - should fail
|
## 1. INITIAL TEAM CALL - should fail
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
|
setattr(
|
||||||
|
litellm.proxy.proxy_server,
|
||||||
|
"general_settings",
|
||||||
|
{
|
||||||
|
"enable_jwt_auth": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
|
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
|
||||||
try:
|
try:
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
@ -358,9 +367,22 @@ async def test_team_token_output(prisma_client, audience):
|
||||||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"team_id_set, default_team_id",
|
||||||
|
[(True, False), (False, True)],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("user_id_upsert", [True, False])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_token_output(prisma_client, audience):
|
async def test_user_token_output(
|
||||||
|
prisma_client, audience, team_id_set, default_team_id, user_id_upsert
|
||||||
|
):
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
args = locals()
|
||||||
|
print(f"received args - {args}")
|
||||||
|
if default_team_id:
|
||||||
|
default_team_id = "team_id_12344_{}".format(uuid.uuid4())
|
||||||
"""
|
"""
|
||||||
- If user required, check if it exists
|
- If user required, check if it exists
|
||||||
- fail initial request (when user doesn't exist)
|
- fail initial request (when user doesn't exist)
|
||||||
|
@ -373,7 +395,12 @@ async def test_user_token_output(prisma_client, audience):
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from starlette.datastructures import URL
|
from starlette.datastructures import URL
|
||||||
from litellm.proxy.proxy_server import user_api_key_auth, new_team, new_user
|
from litellm.proxy.proxy_server import (
|
||||||
|
user_api_key_auth,
|
||||||
|
new_team,
|
||||||
|
new_user,
|
||||||
|
user_info,
|
||||||
|
)
|
||||||
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest
|
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest
|
||||||
import litellm
|
import litellm
|
||||||
import uuid
|
import uuid
|
||||||
|
@ -381,7 +408,7 @@ async def test_user_token_output(prisma_client, audience):
|
||||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
os.environ.pop('JWT_AUDIENCE', None)
|
os.environ.pop("JWT_AUDIENCE", None)
|
||||||
if audience:
|
if audience:
|
||||||
os.environ["JWT_AUDIENCE"] = audience
|
os.environ["JWT_AUDIENCE"] = audience
|
||||||
|
|
||||||
|
@ -423,6 +450,11 @@ async def test_user_token_output(prisma_client, audience):
|
||||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
||||||
|
|
||||||
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
|
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
|
||||||
|
jwt_handler.litellm_jwtauth.team_id_default = default_team_id
|
||||||
|
jwt_handler.litellm_jwtauth.user_id_upsert = user_id_upsert
|
||||||
|
|
||||||
|
if team_id_set:
|
||||||
|
jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"
|
||||||
|
|
||||||
# VALID TOKEN
|
# VALID TOKEN
|
||||||
## GENERATE A TOKEN
|
## GENERATE A TOKEN
|
||||||
|
@ -436,7 +468,7 @@ async def test_user_token_output(prisma_client, audience):
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm_team",
|
"scope": "litellm_team",
|
||||||
"client_id": team_id,
|
"client_id": team_id,
|
||||||
"aud": audience
|
"aud": audience,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate the JWT token
|
# Generate the JWT token
|
||||||
|
@ -451,7 +483,7 @@ async def test_user_token_output(prisma_client, audience):
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
"scope": "litellm_proxy_admin",
|
"scope": "litellm_proxy_admin",
|
||||||
"aud": audience
|
"aud": audience,
|
||||||
}
|
}
|
||||||
|
|
||||||
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
@ -503,6 +535,16 @@ async def test_user_token_output(prisma_client, audience):
|
||||||
),
|
),
|
||||||
user_api_key_dict=result,
|
user_api_key_dict=result,
|
||||||
)
|
)
|
||||||
|
if default_team_id:
|
||||||
|
await new_team(
|
||||||
|
data=NewTeamRequest(
|
||||||
|
team_id=default_team_id,
|
||||||
|
tpm_limit=100,
|
||||||
|
rpm_limit=99,
|
||||||
|
models=["gpt-3.5-turbo", "gpt-4"],
|
||||||
|
),
|
||||||
|
user_api_key_dict=result,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"This should not fail - {str(e)}")
|
pytest.fail(f"This should not fail - {str(e)}")
|
||||||
|
|
||||||
|
@ -513,23 +555,35 @@ async def test_user_token_output(prisma_client, audience):
|
||||||
team_result: UserAPIKeyAuth = await user_api_key_auth(
|
team_result: UserAPIKeyAuth = await user_api_key_auth(
|
||||||
request=request, api_key=bearer_token
|
request=request, api_key=bearer_token
|
||||||
)
|
)
|
||||||
pytest.fail(f"User doesn't exist. this should fail")
|
if user_id_upsert == False:
|
||||||
|
pytest.fail(f"User doesn't exist. this should fail")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
## 4. Create user
|
## 4. Create user
|
||||||
try:
|
if user_id_upsert:
|
||||||
bearer_token = "Bearer " + admin_token
|
## check if user already exists
|
||||||
|
try:
|
||||||
|
bearer_token = "Bearer " + admin_token
|
||||||
|
|
||||||
request._url = URL(url="/team/new")
|
request._url = URL(url="/team/new")
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
await new_user(
|
await user_info(user_id=user_id)
|
||||||
data=NewUserRequest(
|
except Exception as e:
|
||||||
user_id=user_id,
|
pytest.fail(f"This should not fail - {str(e)}")
|
||||||
),
|
else:
|
||||||
)
|
try:
|
||||||
except Exception as e:
|
bearer_token = "Bearer " + admin_token
|
||||||
pytest.fail(f"This should not fail - {str(e)}")
|
|
||||||
|
request._url = URL(url="/team/new")
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
await new_user(
|
||||||
|
data=NewUserRequest(
|
||||||
|
user_id=user_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"This should not fail - {str(e)}")
|
||||||
|
|
||||||
## 5. 3rd call w/ same team, same user -> call should succeed
|
## 5. 3rd call w/ same team, same user -> call should succeed
|
||||||
bearer_token = "Bearer " + token
|
bearer_token = "Bearer " + token
|
||||||
|
@ -543,7 +597,8 @@ async def test_user_token_output(prisma_client, audience):
|
||||||
|
|
||||||
## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
|
## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
|
||||||
|
|
||||||
assert team_result.team_tpm_limit == 100
|
if team_id_set or default_team_id is not None:
|
||||||
assert team_result.team_rpm_limit == 99
|
assert team_result.team_tpm_limit == 100
|
||||||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
assert team_result.team_rpm_limit == 99
|
||||||
|
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||||
assert team_result.user_id == user_id
|
assert team_result.user_id == user_id
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue