diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 1bc1d7119..f41a0bdcd 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -40,9 +40,10 @@ general_settings: allow_user_auth: true alerting: ["slack"] # store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True" - # proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds) + proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds) enable_jwt_auth: True alerting: ["slack"] litellm_jwtauth: admin_jwt_scope: "litellm_proxy_admin" - public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL \ No newline at end of file + public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL + user_id_jwt_field: "sub" \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index a6e8b9bb7..904f930c3 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -124,7 +124,8 @@ class LiteLLM_JWTAuth(LiteLLMBase): - team_jwt_scope: The JWT scope required for proxy team roles. - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`. - team_allowed_routes: list of allowed routes for proxy team roles. - - end_user_id_jwt_field: Default - `sub`. The field in the JWT token that stores the end-user ID. Turn this off by setting to `None`. Enables end-user cost tracking. + - user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees. + - end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers. - public_key_ttl: Default - 600s. TTL for caching public JWT keys. 
See `auth_checks.py` for the specific routes @@ -139,7 +140,8 @@ class LiteLLM_JWTAuth(LiteLLMBase): team_allowed_routes: List[ Literal["openai_routes", "info_routes", "management_routes"] ] = ["openai_routes", "info_routes"] - end_user_id_jwt_field: Optional[str] = "sub" + user_id_jwt_field: Optional[str] = None + end_user_id_jwt_field: Optional[str] = None public_key_ttl: float = 600 def __init__(self, **kwargs: Any) -> None: diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 8cfa5587b..c56c48365 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -26,6 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes def common_checks( request_body: dict, team_object: LiteLLM_TeamTable, + user_object: Optional[LiteLLM_UserTable], end_user_object: Optional[LiteLLM_EndUserTable], global_proxy_spend: Optional[float], general_settings: dict, @@ -37,7 +38,8 @@ def common_checks( 1. If team is blocked 2. If team can call model 3. If team is in budget - 4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget + 4. If user passed in (JWT or key.user_id) - is in budget + 5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget 5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget """ @@ -69,14 +71,20 @@ def common_checks( raise Exception( f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}" ) - # 4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget + if user_object is not None and user_object.max_budget is not None: + user_budget = user_object.max_budget + if user_object.spend > user_budget: + raise Exception( + f"ExceededBudget: User={user_object.user_id} over budget. 
Spend={user_object.spend}, Budget={user_budget}" + ) + # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget if end_user_object is not None and end_user_object.litellm_budget_table is not None: end_user_budget = end_user_object.litellm_budget_table.max_budget if end_user_budget is not None and end_user_object.spend > end_user_budget: raise Exception( f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}" ) - # 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints + # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints if ( general_settings.get("enforce_user_param", None) is not None and general_settings["enforce_user_param"] == True @@ -85,7 +93,7 @@ def common_checks( raise Exception( f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}" ) - # 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget + # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget if litellm.max_budget > 0 and global_proxy_spend is not None: if global_proxy_spend > litellm.max_budget: raise Exception( @@ -204,19 +212,24 @@ async def get_end_user_object( return None -async def get_user_object(self, user_id: str) -> LiteLLM_UserTable: +async def get_user_object( + user_id: str, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, +) -> Optional[LiteLLM_UserTable]: """ - Check if user id in proxy User Table - if valid, return LiteLLM_UserTable object with defined limits - if not, then raise an error """ - if self.prisma_client is None: - raise Exception( - "No DB Connected. 
See - https://docs.litellm.ai/docs/proxy/virtual_keys" - ) + if prisma_client is None: + raise Exception("No db connected") + + if user_id is None: + return None # check if in cache - cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id) + cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id) if cached_user_obj is not None: if isinstance(cached_user_obj, dict): return LiteLLM_UserTable(**cached_user_obj) @@ -224,7 +237,7 @@ async def get_user_object(self, user_id: str) -> LiteLLM_UserTable: return cached_user_obj # else, check db try: - response = await prisma_client.db.litellm_usertable.find_unique( + response = await prisma_client.db.litellm_usertable.find_unique( where={"user_id": user_id} )
) diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 9758d52cc..76042ec68 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -74,6 +74,16 @@ class JWTHandler: team_id = default_value return team_id + def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: + try: + if self.litellm_jwtauth.user_id_jwt_field is not None: + user_id = token[self.litellm_jwtauth.user_id_jwt_field] + else: + user_id = None + except KeyError: + user_id = default_value + return user_id + def get_scopes(self, token: dict) -> list: try: if isinstance(token["scope"], str): @@ -101,7 +111,11 @@ class JWTHandler: if cached_keys is None: response = await self.http_handler.get(keys_url) - keys = response.json()["keys"] + response_json = response.json() + if "keys" in response_json: + keys = response.json()["keys"] + else: + keys = response_json await self.user_api_key_cache.async_set_cache( key="litellm_jwt_auth_keys", diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f7e06a63f..17ddb2f05 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -422,12 +422,21 @@ async def user_api_key_auth( user_api_key_cache=user_api_key_cache, ) - # common checks - # allow request - - # get the request body - request_data = await _read_request_body(request=request) - + # [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable` + user_object = None + user_id = jwt_handler.get_user_id(token=valid_token, default_value=None) + if user_id is not None: + # get the user object + user_object = await get_user_object( + user_id=user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + ) + # save the user object to cache + await user_api_key_cache.async_set_cache( + key=user_id, value=user_object + ) + # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable` end_user_object = None end_user_id = 
jwt_handler.get_end_user_id( token=valid_token, default_value=None @@ -445,7 +454,6 @@ async def user_api_key_auth( ) global_proxy_spend = None - if litellm.max_budget > 0: # user set proxy max budget # check cache global_proxy_spend = await user_api_key_cache.async_get_cache( @@ -480,16 +488,20 @@ async def user_api_key_auth( ) ) + # get the request body + request_data = await _read_request_body(request=request) + # run through common checks _ = common_checks( request_body=request_data, team_object=team_object, + user_object=user_object, end_user_object=end_user_object, general_settings=general_settings, global_proxy_spend=global_proxy_spend, route=route, ) - # save user object in cache + # save team object in cache await user_api_key_cache.async_set_cache( key=team_object.team_id, value=team_object ) @@ -502,6 +514,7 @@ async def user_api_key_auth( team_rpm_limit=team_object.rpm_limit, team_models=team_object.models, user_role="app_owner", + user_id=user_id, ) #### ELSE #### if master_key is None: @@ -954,6 +967,7 @@ async def user_api_key_auth( _ = common_checks( request_body=request_data, team_object=_team_obj, + user_object=None, end_user_object=_end_user_object, general_settings=general_settings, global_proxy_spend=global_proxy_spend, @@ -1328,8 +1342,6 @@ async def update_database( existing_token_obj = await user_api_key_cache.async_get_cache( key=hashed_token ) - if existing_token_obj is None: - return existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) if existing_user_obj is not None and isinstance(existing_user_obj, dict): existing_user_obj = LiteLLM_UserTable(**existing_user_obj) @@ -1351,7 +1363,9 @@ async def update_database( if end_user_id is not None: prisma_client.end_user_list_transactons[end_user_id] = ( response_cost - + prisma_client.user_list_transactons.get(end_user_id, 0) + + prisma_client.end_user_list_transactons.get( + end_user_id, 0 + ) ) elif custom_db_client is not None: for id in user_ids: diff --git 
a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 0699137cc..407814e84 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -345,3 +345,187 @@ async def test_team_token_output(prisma_client): assert team_result.team_tpm_limit == 100 assert team_result.team_rpm_limit == 99 assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"] + + +@pytest.mark.asyncio +async def test_user_token_output(prisma_client): + """ + - If user required, check if it exists + - fail initial request (when user doesn't exist) + - create user + - retry -> it should pass now + """ + import jwt, json + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.backends import default_backend + from fastapi import Request + from starlette.datastructures import URL + from litellm.proxy.proxy_server import user_api_key_auth, new_team, new_user + from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest + import litellm + import uuid + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + await litellm.proxy.proxy_server.prisma_client.connect() + + # Generate a private / public key pair using RSA algorithm + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + # Get private key in PEM format + private_key = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + # Get public key in PEM format + public_key = key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + public_key_obj = serialization.load_pem_public_key( + public_key, backend=default_backend() + ) + + # Convert RSA public key object to JWK (JSON Web Key) + public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj)) + + assert 
isinstance(public_jwk, dict) + + # set cache + cache = DualCache() + + await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + + jwt_handler = JWTHandler() + + jwt_handler.user_api_key_cache = cache + + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth() + + jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub" + + # VALID TOKEN + ## GENERATE A TOKEN + # Assuming the current time is in UTC + expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp()) + + team_id = f"team123_{uuid.uuid4()}" + user_id = f"user123_{uuid.uuid4()}" + payload = { + "sub": user_id, + "exp": expiration_time, # set the token to expire in 10 minutes + "scope": "litellm_team", + "client_id": team_id, + } + + # Generate the JWT token + # But before, you should convert bytes to string + private_key_str = private_key.decode("utf-8") + + ## team token + token = jwt.encode(payload, private_key_str, algorithm="RS256") + + ## admin token + payload = { + "sub": user_id, + "exp": expiration_time, # set the token to expire in 10 minutes + "scope": "litellm_proxy_admin", + } + + admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") + + ## VERIFY IT WORKS + + # verify token + + response = await jwt_handler.auth_jwt(token=token) + + ## RUN IT THROUGH USER API KEY AUTH + + """ + - 1. Initial call should fail -> team doesn't exist + - 2. Create team via admin token + - 3. 2nd call w/ same team -> call should fail -> user doesn't exist + - 4. Create user via admin token + - 5. 3rd call w/ same team, same user -> call should succeed + - 6. assert user api key auth format + """ + + bearer_token = "Bearer " + token + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + ## 1. 
INITIAL TEAM CALL - should fail + # use generated key to auth in + setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True}) + setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler) + try: + result = await user_api_key_auth(request=request, api_key=bearer_token) + pytest.fail("Team doesn't exist. This should fail") + except Exception as e: + pass + + ## 2. CREATE TEAM W/ ADMIN TOKEN - should succeed + try: + bearer_token = "Bearer " + admin_token + + request._url = URL(url="/team/new") + result = await user_api_key_auth(request=request, api_key=bearer_token) + await new_team( + data=NewTeamRequest( + team_id=team_id, + tpm_limit=100, + rpm_limit=99, + models=["gpt-3.5-turbo", "gpt-4"], + ), + user_api_key_dict=result, + ) + except Exception as e: + pytest.fail(f"This should not fail - {str(e)}") + + ## 3. 2nd CALL W/ TEAM TOKEN - should fail + bearer_token = "Bearer " + token + request._url = URL(url="/chat/completions") + try: + team_result: UserAPIKeyAuth = await user_api_key_auth( + request=request, api_key=bearer_token + ) + pytest.fail(f"User doesn't exist. this should fail") + except Exception as e: + pass + + ## 4. Create user + try: + bearer_token = "Bearer " + admin_token + + request._url = URL(url="/team/new") + result = await user_api_key_auth(request=request, api_key=bearer_token) + await new_user( + data=NewUserRequest( + user_id=user_id, + ), + ) + except Exception as e: + pytest.fail(f"This should not fail - {str(e)}") + + ## 5. 3rd call w/ same team, same user -> call should succeed + bearer_token = "Bearer " + token + request._url = URL(url="/chat/completions") + try: + team_result: UserAPIKeyAuth = await user_api_key_auth( + request=request, api_key=bearer_token + ) + except Exception as e: + pytest.fail(f"Team exists. This should not fail - {e}") + + ## 6. 
ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking) + + assert team_result.team_tpm_limit == 100 + assert team_result.team_rpm_limit == 99 + assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"] + assert team_result.user_id == user_id diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index 63c50953a..36c761ed3 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -64,12 +64,12 @@ litellm_settings: telemetry: False context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}] -router_settings: - routing_strategy: usage-based-routing-v2 - redis_host: os.environ/REDIS_HOST - redis_password: os.environ/REDIS_PASSWORD - redis_port: os.environ/REDIS_PORT - enable_pre_call_checks: true +# router_settings: +# routing_strategy: usage-based-routing-v2 +# redis_host: os.environ/REDIS_HOST +# redis_password: os.environ/REDIS_PASSWORD +# redis_port: os.environ/REDIS_PORT +# enable_pre_call_checks: true general_settings: master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys