async def prepare_key_update_data(data: UpdateKeyRequest, existing_key_row):
    """Build the field -> value mapping that should be written for a key update.

    Only fields the caller explicitly set on ``data`` are considered
    (``data.dict(exclude_unset=True)``). Values that look like untouched
    defaults (``None``, ``[]``, ``{}``, ``0``) are dropped so they do not reset
    stored values. Explicit booleans are always kept.

    Args:
        data: The update request. Besides ``dict()``, the attributes
            ``model_tpm_limit``, ``model_rpm_limit`` and ``guardrails`` are read.
        existing_key_row: The currently stored key row; only its ``metadata``
            attribute is read. It is never modified by this function.

    Returns:
        dict: The values to persist. ``duration`` is replaced by an absolute
        ``expires`` timestamp, ``budget_duration`` additionally produces
        ``budget_reset_at``, and the per-model limits / guardrails are folded
        into a ``metadata`` entry based on the row's existing metadata.
    """
    data_json: dict = data.dict(exclude_unset=True)
    # "key" is removed so it is not returned as one of the values to write.
    data_json.pop("key", None)

    # These fields live inside the `metadata` blob and are handled further down.
    _metadata_fields = ("model_rpm_limit", "model_tpm_limit", "guardrails")

    non_default_values = {}
    for k, v in data_json.items():
        if k in _metadata_fields or v is None:
            continue
        # `False == 0` in Python, so booleans must be checked before the
        # "looks like a default" filter; otherwise an explicit False is lost.
        if isinstance(v, bool) or v not in ([], {}, 0):
            non_default_values[k] = v

    if "duration" in non_default_values:
        # A relative duration is stored as an absolute expiry timestamp.
        duration = non_default_values.pop("duration")
        duration_s = _duration_in_seconds(duration=duration)
        non_default_values["expires"] = datetime.now(timezone.utc) + timedelta(
            seconds=duration_s
        )

    if "budget_duration" in non_default_values:
        # budget_duration itself is kept; the next reset time is derived from it.
        duration_s = _duration_in_seconds(
            duration=non_default_values["budget_duration"]
        )
        non_default_values["budget_reset_at"] = datetime.now(
            timezone.utc
        ) + timedelta(seconds=duration_s)

    # Work on copies so the caller's row object is left untouched.
    _metadata = dict(existing_key_row.metadata or {})

    for limit_field in ("model_tpm_limit", "model_rpm_limit"):
        requested_limits = getattr(data, limit_field)
        if not requested_limits:
            continue
        # Merge the new per-model limits over whatever is stored already;
        # a missing or None entry is treated as "no limits yet".
        merged_limits = dict(_metadata.get(limit_field) or {})
        merged_limits.update(requested_limits)
        _metadata[limit_field] = merged_limits
        non_default_values["metadata"] = _metadata

    if data.guardrails:
        # Guardrails replace the stored list rather than being merged.
        _metadata["guardrails"] = data.guardrails
        non_default_values["metadata"] = _metadata

    return non_default_values
_metadata["model_rpm_limit"].update(data.model_rpm_limit) - non_default_values["metadata"] = _metadata - non_default_values.pop("model_rpm_limit", None) - - if data.guardrails: - _metadata = existing_key_row.metadata or {} - _metadata["guardrails"] = data.guardrails - - # update values that will be written to the DB - non_default_values["metadata"] = _metadata - non_default_values.pop("guardrails", None) + non_default_values = await prepare_key_update_data( + data=data, existing_key_row=existing_key_row + ) response = await prisma_client.update_data( token=key, data={**non_default_values, "token": key} @@ -983,6 +979,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None) @management_endpoint_wrapper async def regenerate_key_fn( key: str, + data: Optional[RegenerateKeyRequest] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), litellm_changed_by: Optional[str] = Header( None, @@ -1041,14 +1038,26 @@ async def regenerate_key_fn( new_token_hash = hash_token(new_token) new_token_key_name = f"sk-...{new_token[-4:]}" - # update new token in DB + # Prepare the update data + update_data = { + "token": new_token_hash, + "key_name": new_token_key_name, + } + + non_default_values = {} + if data is not None: + # Update with any provided parameters from GenerateKeyRequest + non_default_values = await prepare_key_update_data( + data=data, existing_key_row=_key_in_db + ) + + update_data.update(non_default_values) + # Update the token in the database updated_token = await prisma_client.db.litellm_verificationtoken.update( where={"token": hashed_api_key}, - data={ - "token": new_token_hash, - "key_name": new_token_key_name, - }, + data=update_data, ) + updated_token_dict = {} if updated_token is not None: updated_token_dict = dict(updated_token) diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index adf0e8aea..995d5c0f7 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ 
b/litellm/tests/test_key_generate_prisma.py @@ -2946,108 +2946,6 @@ async def test_team_access_groups(prisma_client): ) -################ Unit Tests for testing regeneration of keys ########### -@pytest.mark.asyncio() -async def test_regenerate_api_key(prisma_client): - litellm.set_verbose = True - setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - await litellm.proxy.proxy_server.prisma_client.connect() - import uuid - - # generate new key - key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}" - spend = 100 - max_budget = 400 - models = ["fake-openai-endpoint"] - new_key = await generate_key_fn( - data=GenerateKeyRequest( - key_alias=key_alias, spend=spend, max_budget=max_budget, models=models - ), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - generated_key = new_key.key - print(generated_key) - - # assert the new key works as expected - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - async def return_body(): - return_string = f'{{"model": "fake-openai-endpoint"}}' - # return string as bytes - return return_string.encode() - - request.body = return_body - result = await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}") - print(result) - - # regenerate the key - new_key = await regenerate_key_fn( - key=generated_key, - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - print("response from regenerate_key_fn", new_key) - - # assert the new key works as expected - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - async def return_body_2(): - return_string = f'{{"model": "fake-openai-endpoint"}}' - # return string as bytes - return return_string.encode() - - request.body = return_body_2 - result = await 
import os
import sys
import traceback
import uuid
import datetime as dt
from datetime import datetime

from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute

load_dotenv()
import io
import os
import time

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import logging

import pytest

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.management_endpoints.internal_user_endpoints import (
    new_user,
    user_info,
    user_update,
)
from litellm.proxy.management_endpoints.key_management_endpoints import (
    delete_key_fn,
    generate_key_fn,
    generate_key_helper_fn,
    info_key_fn,
    regenerate_key_fn,
    update_key_fn,
)
from litellm.proxy.management_endpoints.team_endpoints import (
    new_team,
    team_info,
    update_team,
)
from litellm.proxy.proxy_server import (
    LitellmUserRoles,
    audio_transcriptions,
    chat_completion,
    completion,
    embeddings,
    image_generation,
    model_list,
    moderations,
    new_end_user,
    user_api_key_auth,
)
from litellm.proxy.spend_tracking.spend_management_endpoints import (
    global_spend,
    global_spend_logs,
    global_spend_models,
    global_spend_keys,
    spend_key_fn,
    spend_user_fn,
    view_spend_logs,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend

verbose_proxy_logger.setLevel(level=logging.DEBUG)

from starlette.datastructures import URL

from litellm.caching import DualCache
from litellm.proxy._types import (
    DynamoDBArgs,
    GenerateKeyRequest,
    KeyRequest,
    LiteLLM_UpperboundKeyGenerateParams,
    NewCustomerRequest,
    NewTeamRequest,
    NewUserRequest,
    ProxyErrorTypes,
    ProxyException,
    UpdateKeyRequest,
    RegenerateKeyRequest,
    UpdateTeamRequest,
    UpdateUserRequest,
    UserAPIKeyAuth,
)
from litellm.proxy.utils import DBClient

proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())


@pytest.fixture
def prisma_client():
    """Return a PrismaClient built from DATABASE_URL with pool settings appended.

    Also resets a few proxy_server module attributes before handing the
    client to the test.
    """
    from litellm.proxy.proxy_cli import append_query_params

    # Connection pool size + pool timeout are passed as URL query params.
    pool_params = {"connection_limit": 100, "pool_timeout": 60}
    os.environ["DATABASE_URL"] = append_query_params(
        os.getenv("DATABASE_URL"), pool_params
    )

    client = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )

    proxy_server = litellm.proxy.proxy_server
    proxy_server.custom_db_client = None
    proxy_server.litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}"
    proxy_server.user_custom_key_generate = None

    return client


def _admin_auth():
    """Return the proxy-admin identity every call in this module is made with."""
    return UserAPIKeyAuth(
        user_role=LitellmUserRoles.PROXY_ADMIN,
        api_key="sk-1234",
        user_id="1234",
    )


def _chat_completions_request():
    """Return a bare /chat/completions request whose body is a fixed JSON payload."""
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    async def return_body():
        # body is returned as bytes
        return '{"model": "fake-openai-endpoint"}'.encode()

    request.body = return_body
    return request


async def _attach_prisma_client(prisma_client):
    """Install the client and master key on proxy_server, then connect."""
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()


################ Unit Tests for testing regeneration of keys ###########
@pytest.mark.asyncio()
async def test_regenerate_api_key(prisma_client):
    """Regenerating a key yields a working new key, invalidates the old one
    and keeps spend, max_budget, models and key_alias."""
    await _attach_prisma_client(prisma_client)

    # generate new key
    key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}"
    spend = 100
    max_budget = 400
    models = ["fake-openai-endpoint"]
    new_key = await generate_key_fn(
        data=GenerateKeyRequest(
            key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
        ),
        user_api_key_dict=_admin_auth(),
    )

    generated_key = new_key.key
    print(generated_key)

    # the freshly generated key must authenticate
    result = await user_api_key_auth(
        request=_chat_completions_request(), api_key=f"Bearer {generated_key}"
    )
    print(result)

    # regenerate the key
    new_key = await regenerate_key_fn(
        key=generated_key,
        user_api_key_dict=_admin_auth(),
    )
    print("response from regenerate_key_fn", new_key)

    # the regenerated key must authenticate
    result = await user_api_key_auth(
        request=_chat_completions_request(), api_key=f"Bearer {new_key.key}"
    )
    print(result)

    # the old key must stop working
    stale_request = _chat_completions_request()
    try:
        result = await user_api_key_auth(
            request=stale_request, api_key=f"Bearer {generated_key}"
        )
        print(result)
        pytest.fail(f"This should have failed!. the key has been regenerated")
    except Exception as e:
        print("got expected exception", e)
        assert "Invalid proxy server token passed" in e.message

    # Check that the regenerated key has the same spend, max_budget, models and key_alias
    assert new_key.spend == spend, f"Expected spend {spend} but got {new_key.spend}"
    assert (
        new_key.max_budget == max_budget
    ), f"Expected max_budget {max_budget} but got {new_key.max_budget}"
    assert (
        new_key.key_alias == key_alias
    ), f"Expected key_alias {key_alias} but got {new_key.key_alias}"
    assert (
        new_key.models == models
    ), f"Expected models {models} but got {new_key.models}"

    assert new_key.key_name == f"sk-...{new_key.key[-4:]}"


@pytest.mark.asyncio()
async def test_regenerate_api_key_with_new_alias_and_expiration(prisma_client):
    """Regenerating with a RegenerateKeyRequest applies the new alias and a
    30 day expiry to the regenerated key."""
    await _attach_prisma_client(prisma_client)

    # generate new key
    key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}"
    spend = 100
    max_budget = 400
    models = ["fake-openai-endpoint"]
    new_key = await generate_key_fn(
        data=GenerateKeyRequest(
            key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
        ),
        user_api_key_dict=_admin_auth(),
    )

    generated_key = new_key.key
    print(generated_key)

    # regenerate the key with new alias and expiration
    regenerate_request = RegenerateKeyRequest(
        key_alias="very_new_alias",
        duration="30d",
    )
    new_key = await regenerate_key_fn(
        key=generated_key,
        data=regenerate_request,
        user_api_key_dict=_admin_auth(),
    )
    print("response from regenerate_key_fn", new_key)

    # assert the alias and duration are updated
    assert new_key.key_alias == "very_new_alias"

    # assert the new key expires 30 days from now (checked as a 29-31 day window)
    now = datetime.now(dt.timezone.utc)
    assert new_key.expires > now + dt.timedelta(days=29)
    assert new_key.expires < now + dt.timedelta(days=31)