diff --git a/litellm/tests/test_key_generate_dynamodb.py b/litellm/tests/test_key_generate_dynamodb.py index 5363bcd09a..ff2e9940b4 100644 --- a/litellm/tests/test_key_generate_dynamodb.py +++ b/litellm/tests/test_key_generate_dynamodb.py @@ -33,6 +33,7 @@ from litellm.proxy._types import NewUserRequest, DynamoDBArgs from litellm.proxy.utils import DBClient from starlette.datastructures import URL + db_args = { "ssl_verify": False, "billing_mode": "PAY_PER_REQUEST", @@ -87,7 +88,7 @@ def test_call_with_invalid_key(): generated_key = "bad-key" bearer_token = "Bearer " + generated_key - request = Request(scope={"type": "http"}) + request = Request(scope={"type": "http"}, receive=None) request._url = URL(url="/chat/completions") # use generated key to auth in @@ -102,18 +103,70 @@ def test_call_with_invalid_key(): pass -# def test_call_with_invalid_model(): -# # 3. Make a call to a key with an invalid model - expect to fail -# key = new_user(ValidNewUserRequest()) -# result = user_auth(InvalidModelRequest(key)) -# assert result is False +def test_call_with_invalid_model(): + # 3. Make a call to a key with an invalid model - expect to fail + setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + try: + + async def test(): + request = NewUserRequest(models=["mistral"]) + key = await new_user(request) + print(key) + + generated_key = key.key + bearer_token = "Bearer " + generated_key + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + async def return_body(): + return b'{"model": "gemini-pro-vision"}' + + request.body = return_body + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + pytest.fail(f"This should have failed!. IT's an invalid model") + + asyncio.run(test()) + except Exception as e: + assert ( + e.detail + == "Authentication Error, API Key not allowed to access model. This token can only access models=['mistral']. Tried to access gemini-pro-vision" + ) + pass -# def test_call_with_valid_model(): -# # 4. Make a call to a key with a valid model - expect to pass -# key = new_user(ValidNewUserRequest()) -# result = user_auth(ValidModelRequest(key)) -# assert result is True +def test_call_with_valid_model(): + # 4. Make a call to a key with a valid model - expect to pass + setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + try: + + async def test(): + request = NewUserRequest(models=["mistral"]) + key = await new_user(request) + print(key) + + generated_key = key.key + bearer_token = "Bearer " + generated_key + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + async def return_body(): + return b'{"model": "mistral"}' + + request.body = return_body + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + print("result from user auth with new key", result) + + asyncio.run(test()) + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}") # def test_call_with_expired_key():