Merge pull request #2970 from BerriAI/litellm_keys

fix(handle_jwt.py): User cost tracking via JWT Auth
Krish Dholakia, 2024-04-11 21:44:15 -07:00, committed by GitHub
commit d89644d46c
7 changed files with 263 additions and 35 deletions

View file

@@ -40,9 +40,10 @@ general_settings:
   allow_user_auth: true
   alerting: ["slack"]
   # store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
-  # proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
+  proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
   enable_jwt_auth: True
   alerting: ["slack"]
   litellm_jwtauth:
     admin_jwt_scope: "litellm_proxy_admin"
     public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL
+    user_id_jwt_field: "sub"
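
With `user_id_jwt_field: "sub"` set, the proxy reads the internal user id from the token's `sub` claim (the team id still comes from the default `client_id` field). A minimal sketch of a token payload that satisfies this config (claim values are illustrative, not from the PR):

# illustrative JWT claims for the config above
payload = {
    "sub": "user-123",        # -> user_id_jwt_field, tracked in LiteLLM_UserTable
    "client_id": "team-abc",  # -> team_id_jwt_field (default), LiteLLM_TeamTable
    "scope": "litellm_team",  # proxy team scope
    "exp": 1712899455,        # unix expiry timestamp
}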

View file

@@ -124,7 +124,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
     - team_jwt_scope: The JWT scope required for proxy team roles.
     - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
     - team_allowed_routes: list of allowed routes for proxy team roles.
-    - end_user_id_jwt_field: Default - `sub`. The field in the JWT token that stores the end-user ID. Turn this off by setting to `None`. Enables end-user cost tracking.
+    - user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLM_UserTable`). Use this for internal employees.
+    - end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLM_EndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
     - public_key_ttl: Default - 600s. TTL for caching public JWT keys.

     See `auth_checks.py` for the specific routes
@@ -139,7 +140,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
     team_allowed_routes: List[
         Literal["openai_routes", "info_routes", "management_routes"]
     ] = ["openai_routes", "info_routes"]
-    end_user_id_jwt_field: Optional[str] = "sub"
+    user_id_jwt_field: Optional[str] = None
+    end_user_id_jwt_field: Optional[str] = None
     public_key_ttl: float = 600

     def __init__(self, **kwargs: Any) -> None:
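
Both id fields now default to `None`, so neither internal-user nor end-user cost tracking is active unless the proxy config opts in. A minimal sketch, assuming the class's module path in this repo:

from litellm.proxy._types import LiteLLM_JWTAuth

# opt in to internal-employee tracking only
jwt_auth = LiteLLM_JWTAuth(user_id_jwt_field="sub")
assert jwt_auth.user_id_jwt_field == "sub"
assert jwt_auth.end_user_id_jwt_field is None  # end-user tracking stays off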

View file

@@ -26,6 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
 def common_checks(
     request_body: dict,
     team_object: LiteLLM_TeamTable,
+    user_object: Optional[LiteLLM_UserTable],
     end_user_object: Optional[LiteLLM_EndUserTable],
     global_proxy_spend: Optional[float],
     general_settings: dict,
@@ -37,7 +38,8 @@ def common_checks(
     1. If team is blocked
     2. If team can call model
     3. If team is in budget
-    4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
-    5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
-    6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
+    4. If user passed in (JWT or key.user_id) - is in budget
+    5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
+    6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
+    7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
     """
@@ -69,14 +71,20 @@ def common_checks(
         raise Exception(
             f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
         )
-    # 4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
+    if user_object is not None and user_object.max_budget is not None:
+        user_budget = user_object.max_budget
+        if user_object.spend > user_budget:
+            raise Exception(
+                f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}"
+            )
+    # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
     if end_user_object is not None and end_user_object.litellm_budget_table is not None:
         end_user_budget = end_user_object.litellm_budget_table.max_budget
         if end_user_budget is not None and end_user_object.spend > end_user_budget:
             raise Exception(
                 f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
             )
-    # 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
+    # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
     if (
         general_settings.get("enforce_user_param", None) is not None
         and general_settings["enforce_user_param"] == True
@@ -85,7 +93,7 @@ def common_checks(
         raise Exception(
             f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
         )
-    # 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
+    # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
     if litellm.max_budget > 0 and global_proxy_spend is not None:
         if global_proxy_spend > litellm.max_budget:
             raise Exception(
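
The new user check (#4) only trips when spend strictly exceeds `max_budget`, and users with no `max_budget` are never blocked. A self-contained sketch of that semantics (toy stand-in class, not litellm code):

# toy stand-in for LiteLLM_UserTable to show check #4's behavior
class FakeUser:
    def __init__(self, user_id: str, spend: float, max_budget):
        self.user_id, self.spend, self.max_budget = user_id, spend, max_budget

def check_user_budget(user_object) -> None:
    if user_object is not None and user_object.max_budget is not None:
        if user_object.spend > user_object.max_budget:
            raise Exception(f"ExceededBudget: User={user_object.user_id}")

check_user_budget(None)                       # no user resolved -> check skipped
check_user_budget(FakeUser("u1", 5.0, None))  # no budget set -> passes
check_user_budget(FakeUser("u2", 9.0, 10.0))  # under budget -> passes
# FakeUser("u3", 12.0, 10.0) would raise ExceededBudget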
@@ -204,19 +212,24 @@ async def get_end_user_object(
         return None


-async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
+async def get_user_object(
+    user_id: str,
+    prisma_client: Optional[PrismaClient],
+    user_api_key_cache: DualCache,
+) -> Optional[LiteLLM_UserTable]:
     """
     - Check if user id in proxy User Table
     - if valid, return LiteLLM_UserTable object with defined limits
     - if not, then raise an error
     """
-    if self.prisma_client is None:
-        raise Exception(
-            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
-        )
+    if prisma_client is None:
+        raise Exception("No db connected")
+
+    if user_id is None:
+        return None

     # check if in cache
-    cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
+    cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
     if cached_user_obj is not None:
         if isinstance(cached_user_obj, dict):
             return LiteLLM_UserTable(**cached_user_obj)
@@ -224,7 +237,7 @@ async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
         return cached_user_obj

     # else, check db
     try:
-        response = await self.prisma_client.db.litellm_usertable.find_unique(
+        response = await prisma_client.db.litellm_usertable.find_unique(
             where={"user_id": user_id}
         )
@@ -232,9 +245,9 @@ async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
             raise Exception

         return LiteLLM_UserTable(**response.dict())
-    except Exception as e:
+    except Exception as e:  # if user not in db
         raise Exception(
-            f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
+            f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
         )
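
`get_user_object` is now a module-level helper rather than a method, so call sites inject the Prisma client and cache explicitly. A minimal call-site sketch (the import path and globals are assumed from this repo's layout):

from litellm.proxy.auth.auth_checks import get_user_object

# prisma_client / user_api_key_cache are the proxy's module globals here
user_object = await get_user_object(
    user_id="user-123",
    prisma_client=prisma_client,
    user_api_key_cache=user_api_key_cache,
)
# returns None when user_id is None; raises if the user isn't in the db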

View file

@@ -74,6 +74,16 @@ class JWTHandler:
             team_id = default_value
         return team_id

+    def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
+        try:
+            if self.litellm_jwtauth.user_id_jwt_field is not None:
+                user_id = token[self.litellm_jwtauth.user_id_jwt_field]
+            else:
+                user_id = None
+        except KeyError:
+            user_id = default_value
+        return user_id
+
     def get_scopes(self, token: dict) -> list:
         try:
             if isinstance(token["scope"], str):
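
A quick sketch of `get_user_id`'s three outcomes, mirroring the handler setup used in the test below (token contents are hypothetical):

jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"

jwt_handler.get_user_id(token={"sub": "user-123"}, default_value=None)  # -> "user-123"
jwt_handler.get_user_id(token={}, default_value="anon")                 # -> "anon" (KeyError path)

jwt_handler.litellm_jwtauth.user_id_jwt_field = None
jwt_handler.get_user_id(token={"sub": "user-123"}, default_value=None)  # -> None (tracking off)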
@@ -101,7 +111,11 @@ class JWTHandler:
         if cached_keys is None:
             response = await self.http_handler.get(keys_url)

-            keys = response.json()["keys"]
+            response_json = response.json()
+            if "keys" in response_json:
+                keys = response_json["keys"]
+            else:
+                keys = response_json

             await self.user_api_key_cache.async_set_cache(
                 key="litellm_jwt_auth_keys",

View file

@@ -422,12 +422,21 @@ async def user_api_key_auth(
                 user_api_key_cache=user_api_key_cache,
             )

-            # common checks
-            # allow request
-
-            # get the request body
-            request_data = await _read_request_body(request=request)
+            # [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
+            user_object = None
+            user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
+            if user_id is not None:
+                # get the user object
+                user_object = await get_user_object(
+                    user_id=user_id,
+                    prisma_client=prisma_client,
+                    user_api_key_cache=user_api_key_cache,
+                )
+                # save the user object to cache
+                await user_api_key_cache.async_set_cache(
+                    key=user_id, value=user_object
+                )

+            # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
             end_user_object = None
             end_user_id = jwt_handler.get_end_user_id(
                 token=valid_token, default_value=None
@@ -445,7 +454,6 @@ async def user_api_key_auth(
                 )

             global_proxy_spend = None
             if litellm.max_budget > 0:  # user set proxy max budget
-
                 # check cache
                 global_proxy_spend = await user_api_key_cache.async_get_cache(
@@ -480,16 +488,20 @@ async def user_api_key_auth(
                     )
                 )

+            # get the request body
+            request_data = await _read_request_body(request=request)
+
             # run through common checks
             _ = common_checks(
                 request_body=request_data,
                 team_object=team_object,
+                user_object=user_object,
                 end_user_object=end_user_object,
                 general_settings=general_settings,
                 global_proxy_spend=global_proxy_spend,
                 route=route,
             )
-            # save user object in cache
+
+            # save team object in cache
             await user_api_key_cache.async_set_cache(
                 key=team_object.team_id, value=team_object
             )
@@ -502,6 +514,7 @@ async def user_api_key_auth(
                 team_rpm_limit=team_object.rpm_limit,
                 team_models=team_object.models,
                 user_role="app_owner",
+                user_id=user_id,
             )
         #### ELSE ####
         if master_key is None:
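
End to end, the JWT branch now resolves up to three identities before `common_checks` runs: the team (scope plus team id field), the internal user (`user_id_jwt_field`, cached after lookup), and the optional end-user. A toy simulation of that ordering (stubs only, not litellm code):

import asyncio

async def lookup_user(user_id: str) -> dict:  # stub for get_user_object + cache write
    return {"user_id": user_id, "spend": 0.0, "max_budget": None}

async def jwt_branch(claims: dict) -> dict:
    team_id = claims.get("client_id")  # team_id_jwt_field default
    user_id = claims.get("sub")        # user_id_jwt_field
    user_object = await lookup_user(user_id) if user_id is not None else None
    # budget gates (team / user / end-user) would run via common_checks() here
    return {"team_id": team_id, "user_id": user_id, "user_object": user_object}

print(asyncio.run(jwt_branch({"sub": "user-123", "client_id": "team-abc"})))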
@@ -954,6 +967,7 @@ async def user_api_key_auth(
                 _ = common_checks(
                     request_body=request_data,
                     team_object=_team_obj,
+                    user_object=None,
                     end_user_object=_end_user_object,
                     general_settings=general_settings,
                     global_proxy_spend=global_proxy_spend,
@@ -1328,8 +1342,6 @@ async def update_database(
             existing_token_obj = await user_api_key_cache.async_get_cache(
                 key=hashed_token
             )
-            if existing_token_obj is None:
-                return
             existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
             if existing_user_obj is not None and isinstance(existing_user_obj, dict):
                 existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
@@ -1351,7 +1363,9 @@ async def update_database(
                 if end_user_id is not None:
                     prisma_client.end_user_list_transactons[end_user_id] = (
                         response_cost
-                        + prisma_client.user_list_transactons.get(end_user_id, 0)
+                        + prisma_client.end_user_list_transactons.get(
+                            end_user_id, 0
+                        )
                     )
         elif custom_db_client is not None:
             for id in user_ids:
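
This last hunk fixes a wrong-dict read: end-user spend was seeded from `user_list_transactons`, so running totals in `end_user_list_transactons` could silently reset between batch writes. The corrected accumulation pattern in isolation (toy numbers):

# toy illustration of the in-memory end-user spend accumulator
end_user_list_transactons: dict = {}

for response_cost in (0.002, 0.003):  # two requests from the same end-user
    end_user_list_transactons["end-user-1"] = response_cost + end_user_list_transactons.get(
        "end-user-1", 0
    )

assert abs(end_user_list_transactons["end-user-1"] - 0.005) < 1e-9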

View file

@@ -345,3 +345,187 @@ async def test_team_token_output(prisma_client):
     assert team_result.team_tpm_limit == 100
     assert team_result.team_rpm_limit == 99
     assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
+
+
+@pytest.mark.asyncio
+async def test_user_token_output(prisma_client):
+    """
+    - If user required, check if it exists
+    - fail initial request (when user doesn't exist)
+    - create user
+    - retry -> it should pass now
+    """
+    import jwt, json
+    from cryptography.hazmat.primitives import serialization
+    from cryptography.hazmat.primitives.asymmetric import rsa
+    from cryptography.hazmat.backends import default_backend
+    from fastapi import Request
+    from starlette.datastructures import URL
+    from litellm.proxy.proxy_server import user_api_key_auth, new_team, new_user
+    from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest
+    import litellm
+    import uuid
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    # Generate a private / public key pair using RSA algorithm
+    key = rsa.generate_private_key(
+        public_exponent=65537, key_size=2048, backend=default_backend()
+    )
+
+    # Get private key in PEM format
+    private_key = key.private_bytes(
+        encoding=serialization.Encoding.PEM,
+        format=serialization.PrivateFormat.PKCS8,
+        encryption_algorithm=serialization.NoEncryption(),
+    )
+
+    # Get public key in PEM format
+    public_key = key.public_key().public_bytes(
+        encoding=serialization.Encoding.PEM,
+        format=serialization.PublicFormat.SubjectPublicKeyInfo,
+    )
+
+    public_key_obj = serialization.load_pem_public_key(
+        public_key, backend=default_backend()
+    )
+
+    # Convert RSA public key object to JWK (JSON Web Key)
+    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
+
+    assert isinstance(public_jwk, dict)
+
+    # set cache
+    cache = DualCache()
+
+    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
+
+    jwt_handler = JWTHandler()
+    jwt_handler.user_api_key_cache = cache
+    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
+    jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
+
+    # VALID TOKEN
+    ## GENERATE A TOKEN
+    # Assuming the current time is in UTC
+    expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
+
+    team_id = f"team123_{uuid.uuid4()}"
+    user_id = f"user123_{uuid.uuid4()}"
+    payload = {
+        "sub": user_id,
+        "exp": expiration_time,  # set the token to expire in 10 minutes
+        "scope": "litellm_team",
+        "client_id": team_id,
+    }
+
+    # Generate the JWT token
+    # But before, you should convert bytes to string
+    private_key_str = private_key.decode("utf-8")
+
+    ## team token
+    token = jwt.encode(payload, private_key_str, algorithm="RS256")
+
+    ## admin token
+    payload = {
+        "sub": user_id,
+        "exp": expiration_time,  # set the token to expire in 10 minutes
+        "scope": "litellm_proxy_admin",
+    }
+
+    admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
+
+    ## VERIFY IT WORKS
+    # verify token
+    response = await jwt_handler.auth_jwt(token=token)
+
+    ## RUN IT THROUGH USER API KEY AUTH
+    """
+    - 1. Initial call should fail -> team doesn't exist
+    - 2. Create team via admin token
+    - 3. 2nd call w/ same team -> call should fail -> user doesn't exist
+    - 4. Create user via admin token
+    - 5. 3rd call w/ same team, same user -> call should succeed
+    - 6. assert user api key auth format
+    """
+    bearer_token = "Bearer " + token
+
+    request = Request(scope={"type": "http"})
+    request._url = URL(url="/chat/completions")
+
+    ## 1. INITIAL TEAM CALL - should fail
+    # use generated key to auth in
+    setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
+    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
+    try:
+        result = await user_api_key_auth(request=request, api_key=bearer_token)
+        pytest.fail("Team doesn't exist. This should fail")
+    except Exception as e:
+        pass
+
+    ## 2. CREATE TEAM W/ ADMIN TOKEN - should succeed
+    try:
+        bearer_token = "Bearer " + admin_token
+
+        request._url = URL(url="/team/new")
+        result = await user_api_key_auth(request=request, api_key=bearer_token)
+        await new_team(
+            data=NewTeamRequest(
+                team_id=team_id,
+                tpm_limit=100,
+                rpm_limit=99,
+                models=["gpt-3.5-turbo", "gpt-4"],
+            ),
+            user_api_key_dict=result,
+        )
+    except Exception as e:
+        pytest.fail(f"This should not fail - {str(e)}")
+
+    ## 3. 2nd CALL W/ TEAM TOKEN - should fail
+    bearer_token = "Bearer " + token
+    request._url = URL(url="/chat/completions")
+    try:
+        team_result: UserAPIKeyAuth = await user_api_key_auth(
+            request=request, api_key=bearer_token
+        )
+        pytest.fail("User doesn't exist. This should fail")
+    except Exception as e:
+        pass
+
+    ## 4. Create user
+    try:
+        bearer_token = "Bearer " + admin_token
+
+        request._url = URL(url="/team/new")
+        result = await user_api_key_auth(request=request, api_key=bearer_token)
+        await new_user(
+            data=NewUserRequest(
+                user_id=user_id,
+            ),
+        )
+    except Exception as e:
+        pytest.fail(f"This should not fail - {str(e)}")
+
+    ## 5. 3rd call w/ same team, same user -> call should succeed
+    bearer_token = "Bearer " + token
+    request._url = URL(url="/chat/completions")
+    try:
+        team_result: UserAPIKeyAuth = await user_api_key_auth(
+            request=request, api_key=bearer_token
+        )
+    except Exception as e:
+        pytest.fail(f"Team exists. This should not fail - {e}")
+
+    ## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
+    assert team_result.team_tpm_limit == 100
+    assert team_result.team_rpm_limit == 99
+    assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
+    assert team_result.user_id == user_id

View file

@@ -64,12 +64,12 @@ litellm_settings:
   telemetry: False
   context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]

-router_settings:
-  routing_strategy: usage-based-routing-v2
-  redis_host: os.environ/REDIS_HOST
-  redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
-  enable_pre_call_checks: true
+# router_settings:
+#   routing_strategy: usage-based-routing-v2
+#   redis_host: os.environ/REDIS_HOST
+#   redis_password: os.environ/REDIS_PASSWORD
+#   redis_port: os.environ/REDIS_PORT
+#   enable_pre_call_checks: true

 general_settings:
   master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys