diff --git a/litellm/__init__.py b/litellm/__init__.py
index 78b40bcc5..06101a522 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -393,4 +393,5 @@ from .exceptions import (
 )
 from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server
+from .proxy.queue.rq import start_rq_worker_in_background
 from .router import Router
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index a7bf2690b..7f6acfe44 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -96,6 +96,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.security import OAuth2PasswordBearer
 import json
 import logging
+from litellm import start_rq_worker_in_background
 
 app = FastAPI(docs_url="/", title="LiteLLM API")
 router = APIRouter()
@@ -135,6 +136,10 @@ log_file = "api_log.json"
 worker_config = None
 master_key = None
 prisma_client = None
+### REDIS QUEUE ###
+redis_job = None
+redis_connection = None
+request_queue = None # Redis Queue for handling requests
 #### HELPER FUNCTIONS ####
 def print_verbose(print_statement):
     global user_debug
@@ -199,6 +204,19 @@ def prisma_setup(database_url: Optional[str]):
         from prisma import Client
         prisma_client = Client()
 
+def rq_setup(use_queue: bool):
+    global request_queue, redis_connection, redis_job
+    print(f"value of use_queue: {use_queue}")
+    if use_queue:
+        from redis import Redis
+        from rq import Queue
+        from rq.job import Job
+
+        redis_job = Job
+        start_rq_worker_in_background()
+        redis_connection = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
+        request_queue = Queue(connection=redis_connection)
+
 def run_ollama_serve():
     command = ['ollama', 'serve']
@@ -234,7 +252,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     ### CONNECT TO DATABASE ###
     database_url = general_settings.get("database_url", None)
     prisma_setup(database_url=database_url)
-
+    ### START REDIS QUEUE ###
+    use_queue = general_settings.get("use_queue", False)
+    rq_setup(use_queue=use_queue)
     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
@@ -551,6 +571,33 @@ async def generate_key_fn(request: Request):
             detail={"error": "models param must be a list"},
         )
 
+@router.post("/queue/chat/completions", dependencies=[Depends(user_api_key_auth)])
+async def async_chat_completions(request: Request):
+    global request_queue
+    body = await request.body()
+    body_str = body.decode()
+    try:
+        data = ast.literal_eval(body_str)
+    except:
+        data = json.loads(body_str)
+    data["model"] = (
+        server_settings.get("completion_model", None) # server default
+        or user_model # model name passed via cli args
+        or data["model"] # default passed in http request
+    )
+    data["call_type"] = "chat_completion"
+    job = request_queue.enqueue(litellm_completion, **data)
+    return {"id": job.id, "url": f"/queue/chat/completions/{job.id}", "eta": 5, "status": "queued"}
+    pass
+
+@router.get("/queue/response/{task_id}", dependencies=[Depends(user_api_key_auth)])
+async def async_chat_completions(request: Request, task_id: str):
+    global redis_connection, redis_job
+    job = redis_job.fetch(id=task_id, connection=redis_connection)
+    print(f"job status: {job.get_status()}")
+    result = job.result
+    return {"status": job.get_status(), "result": result}
+
 @router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
 async def retrieve_server_log(request: Request):
diff --git a/litellm/proxy/queue/rq.py b/litellm/proxy/queue/rq.py
new file mode 100644
index 000000000..a76b279c1
--- /dev/null
+++ b/litellm/proxy/queue/rq.py
@@ -0,0 +1,32 @@
+import os
+import subprocess
+import sys
+import multiprocessing
+from dotenv import load_dotenv
+load_dotenv()
+
+def run_rq_worker(redis_url):
+    command = ["rq", "worker", "--url", redis_url]
+    subprocess.run(command)
+
+def start_rq_worker_in_background():
+    # Set OBJC_DISABLE_INITIALIZE_FORK_SAFETY to YES
+    os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "YES"
+
+    # Check if required environment variables are set
+    required_vars = ["REDIS_USERNAME", "REDIS_PASSWORD", "REDIS_HOST", "REDIS_PORT"]
+    missing_vars = [var for var in required_vars if var not in os.environ]
+
+    if missing_vars:
+        print(f"Error: Redis environment variables not set. Please set {', '.join(missing_vars)}.")
+        sys.exit(1)
+
+    # Construct Redis URL
+    REDIS_URL = f"redis://{os.environ['REDIS_USERNAME']}:{os.environ['REDIS_PASSWORD']}@{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
+
+    # Run rq worker in a separate process
+    worker_process = multiprocessing.Process(target=run_rq_worker, args=(REDIS_URL,))
+    worker_process.start()
+
+if __name__ == "__main__":
+    start_rq_worker_in_background()
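With this patch applied and `use_queue` enabled under `general_settings` in the proxy config, the new endpoints could be exercised roughly as in the sketch below. This is a minimal client sketch, not part of the diff; the proxy address, API key, and model name are placeholder assumptions.

```python
# Sketch of a client for the queued-completion endpoints added above.
# Assumptions: the proxy listens on http://0.0.0.0:8000 and
# "sk-<your-proxy-key>" is a key accepted by user_api_key_auth.
import time
import requests

BASE_URL = "http://0.0.0.0:8000"
HEADERS = {"Authorization": "Bearer sk-<your-proxy-key>"}

# Enqueue a chat completion; the proxy returns an RQ job id immediately.
job = requests.post(
    f"{BASE_URL}/queue/chat/completions",
    json={
        "model": "gpt-3.5-turbo",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello, which LLM are you?"}],
    },
    headers=HEADERS,
).json()
print("queued:", job)

# Poll the status route defined in this diff (/queue/response/{task_id});
# note it differs from the "url" field returned by the POST above.
while True:
    status = requests.get(
        f"{BASE_URL}/queue/response/{job['id']}", headers=HEADERS
    ).json()
    if status["status"] in ("finished", "failed"):
        print(status)
        break
    time.sleep(1)
```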