diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index b81e2dbf6..5accd03c6 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -8,6 +8,8 @@ # 7. Make a call with an key that never expires, expect to pass # 8. Make a call with an expired key, expect to fail # 9. Delete a Key +# 10. Generate a key, call key/info. Assert info returned is the same as generated key info +# 11. Generate a Key, cal key/info, call key/update, call key/info # function to call to generate key - async def new_user(data: NewUserRequest): @@ -33,13 +35,20 @@ from litellm.proxy.proxy_server import ( user_api_key_auth, user_update, delete_key_fn, + info_key_fn, + update_key_fn, ) from litellm.proxy.utils import PrismaClient, ProxyLogging from litellm._logging import verbose_proxy_logger verbose_proxy_logger.setLevel(level=logging.DEBUG) -from litellm.proxy._types import NewUserRequest, DynamoDBArgs, DeleteKeyRequest +from litellm.proxy._types import ( + NewUserRequest, + DynamoDBArgs, + DeleteKeyRequest, + UpdateKeyRequest, +) from litellm.proxy.utils import DBClient from starlette.datastructures import URL from litellm.caching import DualCache @@ -477,3 +486,116 @@ def test_delete_key_auth(prisma_client): print(e.detail) assert "Authentication Error" in e.detail pass + + +def test_generate_and_call_key_info(prisma_client): + # 10. Generate a Key, cal key/info + + print("prisma client=", prisma_client) + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + try: + + async def test(): + await litellm.proxy.proxy_server.prisma_client.connect() + request = NewUserRequest( + metadata={"team": "litellm-team3", "project": "litellm-project3"} + ) + key = await new_user(request) + print(key) + + generated_key = key.key + + # use generated key to auth in + result = await info_key_fn(key=generated_key) + print("result from info_key_fn", result) + assert result["key"] == generated_key + print("\n info for key=", result["info"]) + assert result["info"].max_parallel_requests == None + assert result["info"].metadata == { + "team": "litellm-team3", + "project": "litellm-project3", + } + + # cleanup - delete key + delete_key_request = DeleteKeyRequest(keys=[generated_key]) + + # delete the key + await delete_key_fn(request=request, data=delete_key_request) + + asyncio.run(test()) + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}") + + +def test_generate_and_update_key(prisma_client): + # 11. Generate a Key, cal key/info, call key/update, call key/info + # Check if data gets updated + # Check if untouched data does not get updated + + print("prisma client=", prisma_client) + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + try: + + async def test(): + await litellm.proxy.proxy_server.prisma_client.connect() + request = NewUserRequest( + metadata={"team": "litellm-team3", "project": "litellm-project3"} + ) + key = await new_user(request) + print(key) + + generated_key = key.key + + # use generated key to auth in + result = await info_key_fn(key=generated_key) + print("result from info_key_fn", result) + assert result["key"] == generated_key + print("\n info for key=", result["info"]) + assert result["info"].max_parallel_requests == None + assert result["info"].metadata == { + "team": "litellm-team3", + "project": "litellm-project3", + } + + request = Request(scope={"type": "http"}) + request._url = URL(url="/update/key") + + # update the key + await update_key_fn( + request=Request, + data=UpdateKeyRequest( + key=generated_key, + models=["ada", "babbage", "curie", "davinci"], + ), + ) + + # get info on key after update + result = await info_key_fn(key=generated_key) + print("result from info_key_fn", result) + assert result["key"] == generated_key + print("\n info for key=", result["info"]) + assert result["info"].max_parallel_requests == None + assert result["info"].metadata == { + "team": "litellm-team3", + "project": "litellm-project3", + } + assert result["info"].models == ["ada", "babbage", "curie", "davinci"] + + # cleanup - delete key + delete_key_request = DeleteKeyRequest(keys=[generated_key]) + + request = Request(scope={"type": "http"}, receive=None) + request._url = URL(url="/chat/completions") + + # delete the key + await delete_key_fn(request=request, data=delete_key_request) + + asyncio.run(test()) + except Exception as e: + print("Got Exception", e) + print(e.detail) + pytest.fail(f"An exception occurred - {str(e)}")