import os import sys import traceback import uuid from datetime import datetime from dotenv import load_dotenv from fastapi import Request from fastapi.routing import APIRoute import io import os import time # this file is to test litellm/proxy sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import asyncio import logging load_dotenv() import pytest import uuid import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy.proxy_server import ( LitellmUserRoles, audio_transcriptions, chat_completion, completion, embeddings, image_generation, model_list, moderations, user_api_key_auth, ) from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend verbose_proxy_logger.setLevel(level=logging.DEBUG) from starlette.datastructures import URL from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update from litellm.proxy._types import LiteLLM_AuditLogs, LitellmTableNames from litellm.caching.caching import DualCache from unittest.mock import patch, AsyncMock proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) import json @pytest.mark.asyncio async def test_create_audit_log_for_update_premium_user(): """ Basic unit test for create_audit_log_for_update Test that the audit log is created when a premium user updates a team """ with patch("litellm.proxy.proxy_server.premium_user", True), patch( "litellm.store_audit_logs", True ), patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma: mock_prisma.db.litellm_auditlog.create = AsyncMock() request_data = LiteLLM_AuditLogs( id="test_id", updated_at=datetime.now(), changed_by="test_changed_by", action="updated", table_name=LitellmTableNames.TEAM_TABLE_NAME, object_id="test_object_id", updated_values=json.dumps({"key": "value"}), before_value=json.dumps({"old_key": "old_value"}), ) await create_audit_log_for_update(request_data) mock_prisma.db.litellm_auditlog.create.assert_called_once_with( data={ "id": "test_id", "updated_at": request_data.updated_at, "changed_by": request_data.changed_by, "action": request_data.action, "table_name": request_data.table_name, "object_id": request_data.object_id, "updated_values": request_data.updated_values, "before_value": request_data.before_value, } ) @pytest.fixture def prisma_client(): from litellm.proxy.proxy_cli import append_query_params ### add connection pool + pool timeout args params = {"connection_limit": 100, "pool_timeout": 60} database_url = os.getenv("DATABASE_URL") modified_url = append_query_params(database_url, params) os.environ["DATABASE_URL"] = modified_url # Assuming PrismaClient is a class that needs to be instantiated prisma_client = PrismaClient( database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj ) return prisma_client @pytest.mark.asyncio() async def test_create_audit_log_in_db(prisma_client): print("prisma client=", prisma_client) setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "premium_user", True) setattr(litellm, "store_audit_logs", True) await litellm.proxy.proxy_server.prisma_client.connect() audit_log_id = f"audit_log_id_{uuid.uuid4()}" # create a audit log for /key/generate request_data = LiteLLM_AuditLogs( id=audit_log_id, updated_at=datetime.now(), changed_by="test_changed_by", action="updated", table_name=LitellmTableNames.TEAM_TABLE_NAME, object_id="test_object_id", updated_values=json.dumps({"key": "value"}), before_value=json.dumps({"old_key": "old_value"}), ) await create_audit_log_for_update(request_data) await asyncio.sleep(1) # now read the last log from the db last_log = await prisma_client.db.litellm_auditlog.find_first( where={"id": audit_log_id} ) assert last_log.id == audit_log_id setattr(litellm, "store_audit_logs", False)