From a1f6b9b531a43546d47f142aec1a269320abc107 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 21 Nov 2023 13:47:00 -0800
Subject: [PATCH] refactor(proxy_server.py): refactoring background rq worker

---
 litellm/__init__.py              |  2 +-
 litellm/proxy/proxy_server.py    | 31 +++++++++++++++++++++----------
 litellm/proxy/queue/rq.py        | 32 --------------------------------
 litellm/proxy/queue/rq_worker.py | 30 ++++++++++++++++++++++++++++++
 4 files changed, 52 insertions(+), 43 deletions(-)
 delete mode 100644 litellm/proxy/queue/rq.py
 create mode 100644 litellm/proxy/queue/rq_worker.py

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 06101a522..6ada6480d 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -393,5 +393,5 @@ from .exceptions import (
 )
 from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server
-from .proxy.queue.rq import start_rq_worker_in_background
 from .router import Router
+from .proxy.proxy_server import litellm_queue_completion
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 9c3831732..383ee7cfb 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -98,7 +98,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.security import OAuth2PasswordBearer
 import json
 import logging
-from litellm import start_rq_worker_in_background
+# from litellm.proxy.queue import start_rq_worker_in_background
 
 app = FastAPI(docs_url="/", title="LiteLLM API")
 router = APIRouter()
@@ -149,6 +149,12 @@ def print_verbose(print_statement):
     print(print_statement)
 
 
+def litellm_queue_completion(*args, **kwargs):
+    call_type = kwargs.pop("call_type")
+    llm_router: litellm.Router = kwargs.pop("llm_router")
+
+    return llm_router.completion(**kwargs)
+
 def usage_telemetry(
     feature: str,
 ):  # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
@@ -215,7 +221,7 @@ def rq_setup(use_queue: bool):
         from rq.job import Job
         redis_job = Job
 
-        start_rq_worker_in_background()
+        # start_rq_worker_in_background()
         redis_connection = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
         request_queue = Queue(connection=redis_connection)
 
@@ -387,10 +393,11 @@ def initialize(
     if experimental:
         pass
     if save:
-        save_params_to_config(dynamic_config)
-        with open(user_config_path) as f:
-            print(f.read())
-        print("\033[1;32mDone successfully\033[0m")
+        pass
+        # save_params_to_config(dynamic_config)
+        # with open(user_config_path) as f:
+        #     print(f.read())
+        # print("\033[1;32mDone successfully\033[0m")
     user_telemetry = telemetry
     usage_telemetry(feature="local_proxy_server")
 
@@ -435,7 +442,7 @@ def litellm_completion(*args, **kwargs):
     if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     return response
-    
+
 @app.on_event("startup")
 async def startup_event():
     global prisma_client
@@ -575,7 +582,7 @@ async def generate_key_fn(request: Request):
 
 @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
 async def async_chat_completions(request: Request):
-    global request_queue
+    global request_queue, llm_router
     body = await request.body()
     body_str = body.decode()
     try:
@@ -588,7 +595,8 @@ async def async_chat_completions(request: Request):
             or data["model"] # default passed in http request
         )
         data["call_type"] = "chat_completion"
-        job = request_queue.enqueue(litellm_completion, **data)
+        data["llm_router"] = llm_router
+        job = request_queue.enqueue(litellm.litellm_queue_completion, **data)
         return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
     pass
 
@@ -597,8 +605,11 @@ async def async_chat_completions(request: Request, task_id: str):
     global redis_connection, redis_job
     try:
         job = redis_job.fetch(id=task_id, connection=redis_connection)
-        print(f"job status: {job.get_status()}")
         result = job.result
+        status = job.get_status()
+        print(f"job status: {status}; job result: {result}")
+        if status == "failed":
+            print(f"job: {job.exc_info}")
         status = "queued"
         if result is not None:
             status = "finished"
diff --git a/litellm/proxy/queue/rq.py b/litellm/proxy/queue/rq.py
deleted file mode 100644
index a76b279c1..000000000
--- a/litellm/proxy/queue/rq.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import os
-import subprocess
-import sys
-import multiprocessing
-from dotenv import load_dotenv
-load_dotenv()
-
-def run_rq_worker(redis_url):
-    command = ["rq", "worker", "--url", redis_url]
-    subprocess.run(command)
-
-def start_rq_worker_in_background():
-    # Set OBJC_DISABLE_INITIALIZE_FORK_SAFETY to YES
-    os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "YES"
-
-    # Check if required environment variables are set
-    required_vars = ["REDIS_USERNAME", "REDIS_PASSWORD", "REDIS_HOST", "REDIS_PORT"]
-    missing_vars = [var for var in required_vars if var not in os.environ]
-
-    if missing_vars:
-        print(f"Error: Redis environment variables not set. Please set {', '.join(missing_vars)}.")
-        sys.exit(1)
-
-    # Construct Redis URL
-    REDIS_URL = f"redis://{os.environ['REDIS_USERNAME']}:{os.environ['REDIS_PASSWORD']}@{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
-
-    # Run rq worker in a separate process
-    worker_process = multiprocessing.Process(target=run_rq_worker, args=(REDIS_URL,))
-    worker_process.start()
-
-if __name__ == "__main__":
-    start_rq_worker_in_background()
diff --git a/litellm/proxy/queue/rq_worker.py b/litellm/proxy/queue/rq_worker.py
new file mode 100644
index 000000000..86e1ebf76
--- /dev/null
+++ b/litellm/proxy/queue/rq_worker.py
@@ -0,0 +1,30 @@
+import sys, os
+from rq import Worker, Queue, Connection
+from redis import Redis
+from dotenv import load_dotenv
+load_dotenv()
+# Add the path to the local folder to sys.path
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path - for litellm local dev
+
+
+# # Import your local module
+# import litellm
+# from litellm import litellm_queue_completion
+
+# Set up RQ connection
+redis_conn = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
+print(redis_conn.ping())  # Should print True if connected successfully
+# Create a worker and add the queue
+try:
+    queue = Queue(connection=redis_conn)
+    worker = Worker([queue], connection=redis_conn)
+except Exception as e:
+    print(f"Error setting up worker: {e}")
+    exit()
+
+# Run the worker
+if __name__ == '__main__':
+    with Connection(redis_conn):
+        worker.work()