diff --git a/litellm/__init__.py b/litellm/__init__.py index 6ada6480d..c53f1e43f 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -393,5 +393,4 @@ from .exceptions import ( ) from .budget_manager import BudgetManager from .proxy.proxy_cli import run_server -from .router import Router -from .proxy.proxy_server import litellm_queue_completion +from .router import Router \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 52f4febd4..114bebcd7 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -148,13 +148,6 @@ def print_verbose(print_statement): if user_debug: print(print_statement) - -def litellm_queue_completion(*args, **kwargs): - call_type = kwargs.pop("call_type") - llm_router: litellm.Router = kwargs.pop("llm_router") - - return llm_router.completion(**kwargs) - def usage_telemetry( feature: str, ): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off @@ -228,14 +221,11 @@ def rq_setup(use_queue: bool): global request_queue, redis_connection, redis_job print(f"value of use_queue: {use_queue}") if use_queue: - from redis import Redis - from rq import Queue - from rq.job import Job - - redis_job = Job - # start_rq_worker_in_background() - redis_connection = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD")) - request_queue = Queue(connection=redis_connection) + from litellm.proxy.queue.celery_app import celery_app, process_job + from celery.result import AsyncResult + request_queue = process_job + redis_job = AsyncResult + redis_connection = celery_app def run_ollama_serve(): command = ['ollama', 'serve'] @@ -597,7 +587,7 @@ async def generate_key_fn(request: Request): @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)]) async def async_chat_completions(request: Request): - global request_queue, llm_router + global request_queue, llm_model_list body = await request.body() body_str = body.decode() try: @@ -609,9 +599,9 @@ async def async_chat_completions(request: Request): or user_model # model name passed via cli args or data["model"] # default passed in http request ) - data["call_type"] = "chat_completion" - data["llm_router"] = llm_router # this is dynamic - we should load the llm_router from the user_api_key_auth - job = request_queue.enqueue(litellm.litellm_queue_completion, **data) + data["llm_model_list"] = llm_model_list + print(f"data: {data}") + job = request_queue.apply_async(kwargs=data) return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"} pass @@ -619,16 +609,11 @@ async def async_chat_completions(request: Request): async def async_chat_completions(request: Request, task_id: str): global redis_connection, redis_job try: - job = redis_job.fetch(id=task_id, connection=redis_connection) - result = job.result - status = job.get_status() - print(f"job status: {status}; job result: {result}") - if status == "failed": - print(f"job: {job.exc_info}") - status = "queued" - if result is not None: - status = "finished" - return {"status": status, "result": result} + job = redis_job(task_id, app=redis_connection) + if job.ready(): + return job.result + else: + return {'status': 'processing'} except Exception as e: return {"status": "finished", "result": str(e)} diff --git a/litellm/proxy/queue/celery_app.py b/litellm/proxy/queue/celery_app.py new file mode 100644 index 000000000..f06ca15f9 --- /dev/null +++ b/litellm/proxy/queue/celery_app.py @@ -0,0 +1,39 @@ +from dotenv import load_dotenv +load_dotenv() +import json +import redis +from celery import Celery +import time +import sys, os +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path - for litellm local dev +import litellm + +# Redis connection setup +pool = redis.ConnectionPool(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"), db=0, max_connections=10) +redis_client = redis.Redis(connection_pool=pool) + +# Celery setup +celery_app = Celery('tasks', broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}", backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}") +celery_app.conf.update( + broker_pool_limit = None, + broker_transport_options = {'connection_pool': pool}, + result_backend_transport_options = {'connection_pool': pool}, +) + + +# Celery task +@celery_app.task(name='process_job') +def process_job(*args, **kwargs): + try: + llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) + response = llm_router.completion(*args, **kwargs) + if isinstance(response, litellm.ModelResponse): + response = response.model_dump_json() + return json.loads(response) + return str(response) + except Exception as e: + print(e) + raise e + \ No newline at end of file diff --git a/litellm/proxy/queue/celery_task.py b/litellm/proxy/queue/celery_task.py new file mode 100644 index 000000000..ea2075db0 --- /dev/null +++ b/litellm/proxy/queue/celery_task.py @@ -0,0 +1,15 @@ +from dotenv import load_dotenv +load_dotenv() + +import sys, os +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path - for litellm local dev +import litellm +from litellm.proxy.queue.celery_app import celery_app + +# Celery task +@celery_app.task(name='process_job') +def process_job(*args, **kwargs): + llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) + return llm_router.completion(*args, **kwargs) \ No newline at end of file diff --git a/litellm/proxy/queue/celery_worker.py b/litellm/proxy/queue/celery_worker.py new file mode 100644 index 000000000..97179bd84 --- /dev/null +++ b/litellm/proxy/queue/celery_worker.py @@ -0,0 +1,9 @@ +import os +from multiprocessing import Process + +def run_worker(): + os.system("celery worker -A your_project_name.celery_app --concurrency=10 --loglevel=info") + +if __name__ == "__main__": + worker_process = Process(target=run_worker) + worker_process.start() \ No newline at end of file diff --git a/litellm/proxy/queue/rq_worker.py b/litellm/proxy/queue/rq_worker.py index eae4e93b8..6e8ce29ae 100644 --- a/litellm/proxy/queue/rq_worker.py +++ b/litellm/proxy/queue/rq_worker.py @@ -18,4 +18,9 @@ def start_rq_worker(): worker = Worker([queue], connection=redis_conn) except Exception as e: print(f"Error setting up worker: {e}") - exit() \ No newline at end of file + exit() + + with Connection(redis_conn): + worker.work() + +start_rq_worker() \ No newline at end of file