diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 004aef63a..3f844335b 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -33,6 +33,7 @@ import pytest, logging, asyncio import litellm, asyncio from litellm.proxy.proxy_server import ( new_user, + generate_key_fn, user_api_key_auth, user_update, delete_key_fn, @@ -49,6 +50,7 @@ from litellm.proxy._types import ( DynamoDBArgs, DeleteKeyRequest, UpdateKeyRequest, + GenerateKeyRequest, ) from litellm.proxy.utils import DBClient from starlette.datastructures import URL @@ -593,3 +595,82 @@ def test_generate_and_update_key(prisma_client): print("Got Exception", e) print(e.detail) pytest.fail(f"An exception occurred - {str(e)}") + + +def test_key_generate_with_custom_auth(prisma_client): + # custom - generate key function + async def custom_generate_key_fn(data: GenerateKeyRequest): + """ + Asynchronously decides if a key should be generated or not based on the provided data. + + Args: + data (GenerateKeyRequest): The data to be used for decision making. + + Returns: + bool: True if a key should be generated, False otherwise. + """ + # decide if a key should be generated or not + print("using custom auth function!") + data_json = data.json() # type: ignore + + # Unpacking variables + team_id = data_json.get("team_id") + duration = data_json.get("duration") + models = data_json.get("models") + aliases = data_json.get("aliases") + config = data_json.get("config") + spend = data_json.get("spend") + user_id = data_json.get("user_id") + max_parallel_requests = data_json.get("max_parallel_requests") + metadata = data_json.get("metadata") + tpm_limit = data_json.get("tpm_limit") + rpm_limit = data_json.get("rpm_limit") + + if team_id is not None and team_id == "litellm-core-infra@gmail.com": + # only team_id="litellm-core-infra@gmail.com" can make keys + return { + "decision": True, + } + else: + print("Failed custom auth") + return { + "decision": False, + "message": "This violates LiteLLM Proxy Rules. No team id provided.", + } + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr( + litellm.proxy.proxy_server, "user_custom_key_generate", custom_generate_key_fn + ) + try: + + async def test(): + await litellm.proxy.proxy_server.prisma_client.connect() + try: + request = GenerateKeyRequest() + key = await generate_key_fn(request) + pytest.fail(f"Expected an exception. Got {key}") + except Exception as e: + # this should fail + print("Got Exception", e) + print(e.detail) + print("First request failed!. This is expected") + assert ( + "This violates LiteLLM Proxy Rules. No team id provided." + in e.detail + ) + + request_2 = GenerateKeyRequest( + team_id="litellm-core-infra@gmail.com", + ) + + key = await generate_key_fn(request_2) + print(key) + generated_key = key.key + + asyncio.run(test()) + except Exception as e: + print("Got Exception", e) + print(e.detail) + pytest.fail(f"An exception occurred - {str(e)}")