From 922c8ac7589fecb805fe9965ff5ecb1a636acab1 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 23 Sep 2024 16:37:02 -0700 Subject: [PATCH] [Feat-Proxy] add service accounts backend (#5852) * service_account_settings on config * add service account checks * call service_account_checks * add testing for service accounts --- litellm/proxy/auth/service_account_checks.py | 53 ++++++++++++++++++ litellm/proxy/auth/user_api_key_auth.py | 7 +++ litellm/proxy/proxy_config.yaml | 3 +- litellm/tests/test_key_generate_prisma.py | 58 ++++++++++++++++++++ 4 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 litellm/proxy/auth/service_account_checks.py diff --git a/litellm/proxy/auth/service_account_checks.py b/litellm/proxy/auth/service_account_checks.py new file mode 100644 index 000000000..87d7d6685 --- /dev/null +++ b/litellm/proxy/auth/service_account_checks.py @@ -0,0 +1,53 @@ +""" +Checks for LiteLLM service account keys + +""" + +from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth + + +def check_if_token_is_service_account(valid_token: UserAPIKeyAuth) -> bool: + """ + Checks if the token is a service account + + Returns: + bool: True if token is a service account + + """ + if valid_token.metadata: + if "service_account_id" in valid_token.metadata: + return True + return False + + +async def service_account_checks( + valid_token: UserAPIKeyAuth, request_data: dict +) -> bool: + """ + If a virtual key is a service account, checks it's a valid service account + + A token is a service account if it has a service_account_id in its metadata + + Service Account Specific Checks: + - Check if required_params is set + """ + + if check_if_token_is_service_account(valid_token) is not True: + return True + + from litellm.proxy.proxy_server import general_settings + + if "service_account_settings" in general_settings: + service_account_settings = general_settings["service_account_settings"] + if "enforced_params" in service_account_settings: + _enforced_params = service_account_settings["enforced_params"] + for param in _enforced_params: + if param not in request_data: + raise ProxyException( + type=ProxyErrorTypes.bad_request_error.value, + code=400, + param=param, + message=f"BadRequest please pass param={param} in request body. This is a required param for service account", + ) + + return True diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 85c252d5d..a5f84b3ac 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -70,6 +70,7 @@ from litellm.proxy.auth.auth_utils import ( from litellm.proxy.auth.oauth2_check import check_oauth2_token from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request from litellm.proxy.auth.route_checks import non_admin_allowed_routes_check +from litellm.proxy.auth.service_account_checks import service_account_checks from litellm.proxy.common_utils.http_parsing_utils import _read_request_body from litellm.proxy.utils import _to_ns @@ -965,6 +966,12 @@ async def user_api_key_auth( else: _team_obj = None + # Check 9: Check if key is a service account key + await service_account_checks( + valid_token=valid_token, + request_data=request_data, + ) + user_api_key_cache.set_cache( key=valid_token.team_id, value=_team_obj ) # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index fda8deadd..829131146 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -27,5 +27,4 @@ litellm_settings: general_settings: service_account_settings: - required_params: ["user"] - + enforced_params: ["user"] diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index a25444c69..6edfeca0f 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -3249,3 +3249,61 @@ async def test_auth_vertex_ai_route(prisma_client): assert "Invalid proxy server token passed" in error_str pass + + +@pytest.mark.asyncio +async def test_service_accounts(prisma_client): + """ + Do not delete + this is the Admin UI flow + """ + # Make a call to a key with model = `all-proxy-models` this is an Alias from LiteLLM Admin UI + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr( + litellm.proxy.proxy_server, + "general_settings", + {"service_account_settings": {"enforced_params": ["user"]}}, + ) + + await litellm.proxy.proxy_server.prisma_client.connect() + + request = GenerateKeyRequest( + metadata={"service_account_id": f"prod-service-{uuid.uuid4()}"}, + ) + response = await generate_key_fn( + data=request, + ) + + print("key generated=", response) + generated_key = response.key + bearer_token = "Bearer " + generated_key + # make a bad /chat/completions call expect it to fail + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + async def return_body(): + return b'{"model": "gemini-pro-vision"}' + + request.body = return_body + + # use generated key to auth in + print("Bearer token being sent to user_api_key_auth() - {}".format(bearer_token)) + try: + result = await user_api_key_auth(request=request, api_key=bearer_token) + pytest.fail("Expected this call to fail. Bad request using service account") + except Exception as e: + print("error str=", str(e.message)) + assert "This is a required param for service account" in str(e.message) + + # make a good /chat/completions call it should pass + async def good_return_body(): + return b'{"model": "gemini-pro-vision", "user": "foo"}' + + request.body = good_return_body + + result = await user_api_key_auth(request=request, api_key=bearer_token) + print("response from user_api_key_auth", result) + + setattr(litellm.proxy.proxy_server, "general_settings", {})