From 8a9e8838e24a2cc957383a64cbf29817c20aa7ab Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Thu, 6 Feb 2025 23:15:33 -0800 Subject: [PATCH] Litellm dev 02 06 2025 p3 (#8343) * feat(handle_jwt.py): initial commit to allow scope based model access * feat(handle_jwt.py): allow model access based on token scopes allow admin to control model access from IDP * test(test_jwt.py): add unit testing for scope based model access * docs(token_auth.md): add scope based model access to docs * docs(token_auth.md): update docs * docs(token_auth.md): update docs * build: add gemini commercial rate limits * fix: fix linting error --- docs/my-website/docs/proxy/token_auth.md | 66 +++++++++++++++++++++++- litellm/proxy/_new_secret_config.yaml | 14 +++++ litellm/proxy/_types.py | 32 +++++++++++- litellm/proxy/auth/auth_checks.py | 18 +++++-- litellm/proxy/auth/handle_jwt.py | 51 +++++++++++++++++- tests/proxy_unit_tests/test_jwt.py | 64 +++++++++++++++++++++++ 6 files changed, 238 insertions(+), 7 deletions(-) diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md index 0e65900b28..9df0462281 100644 --- a/docs/my-website/docs/proxy/token_auth.md +++ b/docs/my-website/docs/proxy/token_auth.md @@ -370,4 +370,68 @@ Supported internal roles: - `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'. - `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token. -### [Architecture Diagram (Control Model Access)](./jwt_auth_arch) \ No newline at end of file +### [Architecture Diagram (Control Model Access)](./jwt_auth_arch) + +## [BETA] Control Model Access with Scopes + +Control which models a JWT can access. Set `enforce_scope_based_access: true` to enforce scope-based access control. + +### 1. Setup config.yaml with scope mappings. + + +```yaml +model_list: + - model_name: anthropic-claude + litellm_params: + model: anthropic/claude-3-5-sonnet + api_key: os.environ/ANTHROPIC_API_KEY + - model_name: gpt-3.5-turbo-testing + litellm_params: + model: gpt-3.5-turbo + api_key: os.environ/OPENAI_API_KEY + +general_settings: + enable_jwt_auth: True + litellm_jwtauth: + team_id_jwt_field: "client_id" # 👈 set the field in the JWT token that contains the team id + team_id_upsert: true # 👈 upsert the team to db, if team id is not found in db + scope_mappings: + - scope: litellm.api.consumer + models: ["anthropic-claude"] + - scope: litellm.api.gpt_3_5_turbo + models: ["gpt-3.5-turbo-testing"] + enforce_scope_based_access: true # 👈 enforce scope-based access control + enforce_rbac: true # 👈 enforces only a Team/User/ProxyAdmin can access the proxy. +``` + +#### Scope Mapping Spec + +- `scope`: The scope to be used for the JWT token. +- `models`: The models that the JWT token can access. Value is the `model_name` in `model_list`. Note: Wildcard routes are not currently supported. + +### 2. Create a JWT with the correct scopes. + +Expected Token: + +``` +{ + "scope": ["litellm.api.consumer", "litellm.api.gpt_3_5_turbo"] +} +``` + +### 3. Test the flow. + +```bash +curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer eyJhbGci...' \ +-d '{ + "model": "gpt-3.5-turbo-testing", + "messages": [ + { + "role": "user", + "content": "Hey, how'\''s it going 1234?" 
+ } + ] +}' +``` \ No newline at end of file diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 6e0850af50..0a738bc29d 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -27,3 +27,17 @@ model_list: model: openai/fake api_key: fake-key api_base: https://exampleopenaiendpoint-production.up.railway.app/ + +general_settings: + enable_jwt_auth: True + litellm_jwtauth: + team_id_jwt_field: "client_id" + team_id_upsert: true + scope_mappings: + - scope: litellm.api.consumer + models: ["anthropic-claude"] + routes: ["/v1/chat/completions"] + - scope: litellm.api.gpt_3_5_turbo + models: ["gpt-3.5-turbo-testing"] + enforce_scope_based_access: true + enforce_rbac: true diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index a131e6ce85..3e43692158 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1040,6 +1040,13 @@ class LiteLLM_TeamTable(TeamBase): "model_max_budget", "model_aliases", ] + + if ( + isinstance(values.get("members_with_roles"), dict) + and not values["members_with_roles"] + ): + values["members_with_roles"] = [] + for field in dict_fields: value = values.get(field) if value is not None and isinstance(value, str): @@ -2279,11 +2286,14 @@ RBAC_ROLES = Literal[ ] -class RoleBasedPermissions(LiteLLMPydanticObjectBase): - role: RBAC_ROLES +class OIDCPermissions(LiteLLMPydanticObjectBase): models: Optional[List[str]] = None routes: Optional[List[str]] = None + +class RoleBasedPermissions(OIDCPermissions): + role: RBAC_ROLES + model_config = { "extra": "forbid", } @@ -2294,6 +2304,14 @@ class RoleMapping(BaseModel): internal_role: RBAC_ROLES +class ScopeMapping(OIDCPermissions): + scope: str + + model_config = { + "extra": "forbid", + } + + class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): """ A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth. @@ -2323,6 +2341,7 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): "info_routes", ] team_id_jwt_field: Optional[str] = None + team_id_upsert: bool = False team_ids_jwt_field: Optional[str] = None upsert_sso_user_to_team: bool = False team_allowed_routes: List[ @@ -2351,6 +2370,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): object_id_jwt_field: Optional[str] = ( None # can be either user / team, inferred from the role mapping ) + scope_mappings: Optional[List[ScopeMapping]] = None + enforce_scope_based_access: bool = False def __init__(self, **kwargs: Any) -> None: # get the attribute names for this Pydantic model @@ -2361,6 +2382,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): user_allowed_roles = kwargs.get("user_allowed_roles") object_id_jwt_field = kwargs.get("object_id_jwt_field") role_mappings = kwargs.get("role_mappings") + scope_mappings = kwargs.get("scope_mappings") + enforce_scope_based_access = kwargs.get("enforce_scope_based_access") if invalid_keys: raise ValueError( @@ -2378,4 +2401,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): "if object_id_jwt_field is set, role_mappings must also be set. Needed to infer if the caller is a user or team." ) + if scope_mappings is not None and not enforce_scope_based_access: + raise ValueError( + "scope_mappings must be set if enforce_scope_based_access is true." 
+ ) + super().__init__(**kwargs) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 43815c357b..517cc7c73b 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -655,11 +655,20 @@ async def _delete_cache_key_object( @log_db_metrics -async def _get_team_db_check(team_id: str, prisma_client: PrismaClient): - return await prisma_client.db.litellm_teamtable.find_unique( +async def _get_team_db_check( + team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None +): + response = await prisma_client.db.litellm_teamtable.find_unique( where={"team_id": team_id} ) + if response is None and team_id_upsert: + response = await prisma_client.db.litellm_teamtable.create( + data={"team_id": team_id} + ) + + return response + async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient): return await prisma_client.db.litellm_teamtable.find_unique( @@ -675,6 +684,7 @@ async def _get_team_object_from_user_api_key_cache( db_cache_expiry: int, proxy_logging_obj: Optional[ProxyLogging], key: str, + team_id_upsert: Optional[bool] = None, ) -> LiteLLM_TeamTableCachedObj: db_access_time_key = key should_check_db = _should_check_db( @@ -684,7 +694,7 @@ async def _get_team_object_from_user_api_key_cache( ) if should_check_db: response = await _get_team_db_check( - team_id=team_id, prisma_client=prisma_client + team_id=team_id, prisma_client=prisma_client, team_id_upsert=team_id_upsert ) else: response = None @@ -752,6 +762,7 @@ async def get_team_object( proxy_logging_obj: Optional[ProxyLogging] = None, check_cache_only: Optional[bool] = None, check_db_only: Optional[bool] = None, + team_id_upsert: Optional[bool] = None, ) -> LiteLLM_TeamTableCachedObj: """ - Check if team id in proxy Team Table @@ -795,6 +806,7 @@ async def get_team_object( last_db_access_time=last_db_access_time, db_cache_expiry=db_cache_expiry, key=key, + team_id_upsert=team_id_upsert, ) except Exception: raise Exception( diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 1d2a6fe5cd..c60d41faee 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -30,6 +30,7 @@ from litellm.proxy._types import ( LiteLLM_TeamTable, LiteLLM_UserTable, LitellmUserRoles, + ScopeMapping, Span, ) from litellm.proxy.utils import PrismaClient, ProxyLogging @@ -318,7 +319,7 @@ class JWTHandler: org_id = default_value return org_id - def get_scopes(self, token: dict) -> list: + def get_scopes(self, token: dict) -> List[str]: try: if isinstance(token["scope"], str): # Assuming the scopes are stored in 'scope' claim and are space-separated @@ -543,6 +544,40 @@ class JWTAuthManager: return True + @staticmethod + def check_scope_based_access( + scope_mappings: List[ScopeMapping], + scopes: List[str], + request_data: dict, + general_settings: dict, + ) -> None: + """ + Check if scope allows access to the requested model + """ + if not scope_mappings: + return None + + allowed_models = [] + for sm in scope_mappings: + if sm.scope in scopes and sm.models: + allowed_models.extend(sm.models) + + requested_model = request_data.get("model") + + if not requested_model: + return None + + if requested_model not in allowed_models: + raise HTTPException( + status_code=403, + detail={ + "error": "model={} not allowed. 
Allowed_models={}".format( + requested_model, allowed_models + ) + }, + ) + return None + @staticmethod async def check_rbac_role( jwt_handler: JWTHandler, @@ -636,6 +671,7 @@ class JWTAuthManager: user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, + team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert, ) return individual_team_id, team_object @@ -829,6 +865,19 @@ class JWTAuthManager: rbac_role, ) + # Check Scope Based Access + scopes = jwt_handler.get_scopes(token=jwt_valid_token) + if ( + jwt_handler.litellm_jwtauth.enforce_scope_based_access + and jwt_handler.litellm_jwtauth.scope_mappings + ): + JWTAuthManager.check_scope_based_access( + scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings, + scopes=scopes, + request_data=request_data, + general_settings=general_settings, + ) + object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None) # Get basic user info diff --git a/tests/proxy_unit_tests/test_jwt.py b/tests/proxy_unit_tests/test_jwt.py index f35091afa6..a168a91c12 100644 --- a/tests/proxy_unit_tests/test_jwt.py +++ b/tests/proxy_unit_tests/test_jwt.py @@ -1183,3 +1183,67 @@ def test_can_rbac_role_call_route(): }, route="/v1/embeddings", ) + + +@pytest.mark.parametrize( + "requested_model, should_work", + [ + ("gpt-3.5-turbo-testing", True), + ("gpt-4o", False), + ], +) +def test_check_scope_based_access(requested_model, should_work): + from litellm.proxy.auth.handle_jwt import JWTAuthManager + from litellm.proxy._types import ScopeMapping + + args = { + "scope_mappings": [ + ScopeMapping( + models=["anthropic-claude"], + routes=["/v1/chat/completions"], + scope="litellm.api.consumer", + ), + ScopeMapping( + models=["gpt-3.5-turbo-testing"], + routes=None, + scope="litellm.api.gpt_3_5_turbo", + ), + ], + "scopes": [ + "profile", + "groups-scope", + "email", + "litellm.api.gpt_3_5_turbo", + "litellm.api.consumer", + ], + "request_data": { + "model": requested_model, + "messages": [{"role": "user", "content": "Hey, how's it going 1234?"}], + }, + "general_settings": { + "enable_jwt_auth": True, + "litellm_jwtauth": { + "team_id_jwt_field": "client_id", + "team_id_upsert": True, + "scope_mappings": [ + { + "scope": "litellm.api.consumer", + "models": ["anthropic-claude"], + "routes": ["/v1/chat/completions"], + }, + { + "scope": "litellm.api.gpt_3_5_turbo", + "models": ["gpt-3.5-turbo-testing"], + }, + ], + "enforce_scope_based_access": True, + "enforce_rbac": True, + }, + }, + } + + if should_work: + JWTAuthManager.check_scope_based_access(**args) + else: + with pytest.raises(HTTPException): + JWTAuthManager.check_scope_based_access(**args)