def append_query_params(url: str, params: dict) -> str:
    """Return *url* with *params* merged into its query string.

    Parameters already present in the URL are kept, except that a key which
    also appears in *params* is replaced by the value from *params*.

    Args:
        url: The URL to extend (e.g. a database connection string).
        params: Mapping of query parameter names to values. A value may be a
            scalar or a sequence of values (encoded as a repeated key).

    Returns:
        The URL with the merged, re-encoded query string. All other URL
        components (scheme, netloc, path, fragment) are left untouched.
    """
    parsed_url = urlparse.urlparse(url)
    # keep_blank_values=True: parse_qs drops "key=" entries by default, which
    # would silently remove blank-valued parameters already in the URL.
    merged_query = urlparse.parse_qs(parsed_url.query, keep_blank_values=True)
    merged_query.update(params)
    # doseq=True expands the list values produced by parse_qs into repeated
    # "key=value" pairs instead of encoding the list's repr.
    encoded_query = urlparse.urlencode(merged_query, doseq=True)
    return urlparse.urlunparse(parsed_url._replace(query=encoded_query))
interval, so multiple workers avoid resetting budget at the same time scheduler.add_job(reset_budget, "interval", seconds=interval, args=[prisma_client]) scheduler.start() diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 84b09d726..178403e44 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -391,13 +391,7 @@ class PrismaClient: # Now you can import the Prisma Client from prisma import Prisma # type: ignore - self.db = Prisma( - http={ - "limits": httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ) - } - ) # Client to connect to Prisma db + self.db = Prisma() # Client to connect to Prisma db def hash_token(self, token: str): # Hash the string using SHA-256 diff --git a/tests/test_keys.py b/tests/test_keys.py index da7dc19b4..28ce02511 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -431,9 +431,9 @@ async def test_key_info_spend_values_image_generation(): @pytest.mark.asyncio async def test_key_with_budgets(): """ - - Create key with budget and 5s duration + - Create key with budget and 5min duration - Get 'reset_at' value - - wait 5s + - wait 10min (budget reset runs every 10mins.) 
- Check if value updated """ from litellm.proxy.utils import hash_token @@ -449,8 +449,8 @@ async def test_key_with_budgets(): reset_at_init_value = key_info["info"]["budget_reset_at"] reset_at_new_value = None i = 0 + await asyncio.sleep(610) while i < 3: - await asyncio.sleep(30) key_info = await get_key_info(session=session, get_key=key, call_key=key) reset_at_new_value = key_info["info"]["budget_reset_at"] try: @@ -458,6 +458,7 @@ async def test_key_with_budgets(): break except: i + 1 + await asyncio.sleep(5) assert reset_at_init_value != reset_at_new_value @@ -481,7 +482,7 @@ async def test_key_crossing_budget(): response = await chat_completion(session=session, key=key) print("response 1: ", response) - await asyncio.sleep(2) + await asyncio.sleep(10) try: response = await chat_completion(session=session, key=key) pytest.fail("Should have failed - Key crossed it's budget") diff --git a/tests/test_spend_logs.py b/tests/test_spend_logs.py index 6db0e301e..4d7ad175f 100644 --- a/tests/test_spend_logs.py +++ b/tests/test_spend_logs.py @@ -113,6 +113,7 @@ async def test_spend_logs(): await get_spend_logs(session=session, request_id=response["id"]) +@pytest.mark.skip(reason="High traffic load test, meant to be run locally") @pytest.mark.asyncio async def test_spend_logs_high_traffic(): """ @@ -155,9 +156,12 @@ async def test_spend_logs_high_traffic(): successful_completions = [c for c in chat_completions if c is not None] print(f"Num successful completions: {len(successful_completions)}") await asyncio.sleep(10) - response = await get_spend_logs(session=session, api_key=key) - print(f"response: {response}") - print(f"len responses: {len(response)}") - assert len(response) == n - print(n, time.time() - start, len(response)) + try: + response = await retry_request(get_spend_logs, session=session, api_key=key) + print(f"response: {response}") + print(f"len responses: {len(response)}") + assert len(response) == n + print(n, time.time() - start, len(response)) + except: 
+ print(n, time.time() - start, 0) raise Exception("it worked!")