From 204ddfc785ec9845d9bfbf3af0ed62ccef5a9db4 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 13 Nov 2024 11:02:11 -0800 Subject: [PATCH] us event hooks for key management endpoints --- .../proxy/hooks/key_management_event_hooks.py | 111 +++++++++++++++++- .../key_management_endpoints.py | 77 +++--------- 2 files changed, 124 insertions(+), 64 deletions(-) diff --git a/litellm/proxy/hooks/key_management_event_hooks.py b/litellm/proxy/hooks/key_management_event_hooks.py index ae0a80162..0977b6dd5 100644 --- a/litellm/proxy/hooks/key_management_event_hooks.py +++ b/litellm/proxy/hooks/key_management_event_hooks.py @@ -2,13 +2,20 @@ import asyncio import json import uuid from datetime import datetime, timezone -from typing import Optional +from re import A +from typing import Any, Optional + +from fastapi import status import litellm from litellm.proxy._types import ( GenerateKeyRequest, + KeyRequest, LiteLLM_AuditLogs, LitellmTableNames, + ProxyErrorTypes, + ProxyException, + UpdateKeyRequest, UserAPIKeyAuth, WebhookEvent, ) @@ -64,6 +71,108 @@ class KeyManagementEventHooks: ) ) + @staticmethod + async def async_key_updated_hook( + data: UpdateKeyRequest, + existing_key_row: Any, + response: Any, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + """ + Post /key/update processing hook + + Handles the following: + - Storing Audit Logs for key update + """ + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + ) + from litellm.proxy.proxy_server import litellm_proxy_admin_name + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = json.dumps(data.json(exclude_none=True), default=str) + + _before_value = existing_key_row.json(exclude_none=True) + _before_value = json.dumps(_before_value, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=data.key, + action="updated", + updated_values=_updated_values, + before_value=_before_value, + ) + ) + ) + pass + + @staticmethod + async def async_key_deleted_hook( + data: KeyRequest, + response: dict, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + """ + Post /key/delete processing hook + + Handles the following: + - Storing Audit Logs for key deletion + """ + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + ) + from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes + if litellm.store_audit_logs is True: + # make an audit log for each team deleted + for key in data.keys: + key_row = await prisma_client.get_data( # type: ignore + token=key, table_name="key", query_type="find_unique" + ) + + if key_row is None: + raise ProxyException( + message=f"Key {key} not found", + type=ProxyErrorTypes.bad_request_error, + param="key", + code=status.HTTP_404_NOT_FOUND, + ) + + key_row = key_row.json(exclude_none=True) + _key_row = json.dumps(key_row, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=key, + action="deleted", + updated_values="{}", + before_value=_key_row, + ) + ) + ) + pass + @staticmethod async def _send_key_created_email(response: dict): from litellm.proxy.proxy_server import general_settings, proxy_logging_obj diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 4f4bb38b3..1dcd9e800 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -374,30 +374,13 @@ async def update_key_fn( proxy_logging_obj=proxy_logging_obj, ) - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _updated_values = json.dumps(data_json, default=str) - - _before_value = existing_key_row.json(exclude_none=True) - _before_value = json.dumps(_before_value, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.KEY_TABLE_NAME, - object_id=data.key, - action="updated", - updated_values=_updated_values, - before_value=_before_value, - ) - ) - ) + await KeyManagementEventHooks.async_key_updated_hook( + data=data, + existing_key_row=existing_key_row, + response=response, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, + ) if response is None: raise ValueError("Failed to update key got response = None") @@ -482,45 +465,6 @@ async def delete_key_fn( and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN ): user_id = None # unless they're admin - - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes - if litellm.store_audit_logs is True: - # make an audit log for each team deleted - for key in data.keys: - key_row = await prisma_client.get_data( # type: ignore - token=key, table_name="key", query_type="find_unique" - ) - - if key_row is None: - raise ProxyException( - message=f"Key {key} not found", - type=ProxyErrorTypes.bad_request_error, - param="key", - code=status.HTTP_404_NOT_FOUND, - ) - - key_row = key_row.json(exclude_none=True) - _key_row = json.dumps(key_row, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.KEY_TABLE_NAME, - object_id=key, - action="deleted", - updated_values="{}", - before_value=_key_row, - ) - ) - ) - number_deleted_keys = await delete_verification_token( tokens=keys, user_id=user_id ) @@ -555,6 +499,13 @@ async def delete_key_fn( f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}" ) + await KeyManagementEventHooks.async_key_deleted_hook( + data=data, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, + response=number_deleted_keys, + ) + return {"deleted_keys": keys} except Exception as e: if isinstance(e, HTTPException):