mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
refactor(proxy_server.py): using celery workers instead of rq for concurrency
This commit is contained in:
parent
8c98a2c899
commit
b16646e584
6 changed files with 84 additions and 32 deletions
|
@ -394,4 +394,3 @@ from .exceptions import (
|
||||||
from .budget_manager import BudgetManager
|
from .budget_manager import BudgetManager
|
||||||
from .proxy.proxy_cli import run_server
|
from .proxy.proxy_cli import run_server
|
||||||
from .router import Router
|
from .router import Router
|
||||||
from .proxy.proxy_server import litellm_queue_completion
|
|
||||||
|
|
|
@ -148,13 +148,6 @@ def print_verbose(print_statement):
|
||||||
if user_debug:
|
if user_debug:
|
||||||
print(print_statement)
|
print(print_statement)
|
||||||
|
|
||||||
|
|
||||||
def litellm_queue_completion(*args, **kwargs):
    """Run a queued completion request through the router passed in ``kwargs``.

    The queue hands the whole request payload over as keyword arguments. Two
    entries are bookkeeping rather than completion parameters and are removed
    before the call:

    * ``call_type`` -- discarded; nothing in this function reads it.
    * ``llm_router`` -- the object whose ``completion`` method is invoked.

    Positional arguments are accepted but not forwarded.

    Returns:
        Whatever ``llm_router.completion(**kwargs)`` returns.

    Raises:
        KeyError: if ``llm_router`` is missing from ``kwargs``.
    """
    # Strip the marker so it is not forwarded as a completion parameter.
    # It is optional because its value is never used here.
    kwargs.pop("call_type", None)
    llm_router: litellm.Router = kwargs.pop("llm_router")
    return llm_router.completion(**kwargs)
|
|
||||||
|
|
||||||
def usage_telemetry(
|
def usage_telemetry(
|
||||||
feature: str,
|
feature: str,
|
||||||
): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
|
): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
|
||||||
|
@ -228,14 +221,11 @@ def rq_setup(use_queue: bool):
|
||||||
global request_queue, redis_connection, redis_job
|
global request_queue, redis_connection, redis_job
|
||||||
print(f"value of use_queue: {use_queue}")
|
print(f"value of use_queue: {use_queue}")
|
||||||
if use_queue:
|
if use_queue:
|
||||||
from redis import Redis
|
from litellm.proxy.queue.celery_app import celery_app, process_job
|
||||||
from rq import Queue
|
from celery.result import AsyncResult
|
||||||
from rq.job import Job
|
request_queue = process_job
|
||||||
|
redis_job = AsyncResult
|
||||||
redis_job = Job
|
redis_connection = celery_app
|
||||||
# start_rq_worker_in_background()
|
|
||||||
redis_connection = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
|
|
||||||
request_queue = Queue(connection=redis_connection)
|
|
||||||
|
|
||||||
def run_ollama_serve():
|
def run_ollama_serve():
|
||||||
command = ['ollama', 'serve']
|
command = ['ollama', 'serve']
|
||||||
|
@ -597,7 +587,7 @@ async def generate_key_fn(request: Request):
|
||||||
|
|
||||||
@router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
|
@router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
|
||||||
async def async_chat_completions(request: Request):
|
async def async_chat_completions(request: Request):
|
||||||
global request_queue, llm_router
|
global request_queue, llm_model_list
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
body_str = body.decode()
|
body_str = body.decode()
|
||||||
try:
|
try:
|
||||||
|
@ -609,9 +599,9 @@ async def async_chat_completions(request: Request):
|
||||||
or user_model # model name passed via cli args
|
or user_model # model name passed via cli args
|
||||||
or data["model"] # default passed in http request
|
or data["model"] # default passed in http request
|
||||||
)
|
)
|
||||||
data["call_type"] = "chat_completion"
|
data["llm_model_list"] = llm_model_list
|
||||||
data["llm_router"] = llm_router # this is dynamic - we should load the llm_router from the user_api_key_auth
|
print(f"data: {data}")
|
||||||
job = request_queue.enqueue(litellm.litellm_queue_completion, **data)
|
job = request_queue.apply_async(kwargs=data)
|
||||||
return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
|
return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -619,16 +609,11 @@ async def async_chat_completions(request: Request):
|
||||||
async def async_chat_completions(request: Request, task_id: str):
|
async def async_chat_completions(request: Request, task_id: str):
|
||||||
global redis_connection, redis_job
|
global redis_connection, redis_job
|
||||||
try:
|
try:
|
||||||
job = redis_job.fetch(id=task_id, connection=redis_connection)
|
job = redis_job(task_id, app=redis_connection)
|
||||||
result = job.result
|
if job.ready():
|
||||||
status = job.get_status()
|
return job.result
|
||||||
print(f"job status: {status}; job result: {result}")
|
else:
|
||||||
if status == "failed":
|
return {'status': 'processing'}
|
||||||
print(f"job: {job.exc_info}")
|
|
||||||
status = "queued"
|
|
||||||
if result is not None:
|
|
||||||
status = "finished"
|
|
||||||
return {"status": status, "result": result}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"status": "finished", "result": str(e)}
|
return {"status": "finished", "result": str(e)}
|
||||||
|
|
||||||
|
|
39
litellm/proxy/queue/celery_app.py
Normal file
39
litellm/proxy/queue/celery_app.py
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
import json
|
||||||
|
import redis
|
||||||
|
from celery import Celery
|
||||||
|
import time
|
||||||
|
import sys, os
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../..")
|
||||||
|
) # Adds the parent directory to the system path - for litellm local dev
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
# Redis connection pool shared by the client below and by the Celery transport options.
# Host, port and password are read from the REDIS_* environment variables.
pool = redis.ConnectionPool(
    host=os.getenv("REDIS_HOST"),
    port=os.getenv("REDIS_PORT"),
    password=os.getenv("REDIS_PASSWORD"),
    db=0,
    max_connections=10,
)
redis_client = redis.Redis(connection_pool=pool)

# Celery application. The same Redis URL is used as message broker and as result backend.
_redis_url = f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}"
celery_app = Celery('tasks', broker=_redis_url, backend=_redis_url)
celery_app.conf.update(
    broker_pool_limit=None,
    broker_transport_options={'connection_pool': pool},
    result_backend_transport_options={'connection_pool': pool},
)
|
||||||
|
|
||||||
|
|
||||||
|
# Celery task
|
||||||
|
@celery_app.task(name='process_job')
def process_job(*args, **kwargs):
    """Celery task: run one completion request through a freshly built router.

    ``kwargs`` must contain ``llm_model_list``. It is removed and used as the
    ``model_list`` of a new ``litellm.Router``. All remaining positional and
    keyword arguments are forwarded to ``Router.completion``.

    Returns:
        A plain ``dict`` parsed from the response's JSON dump when the router
        returns a ``litellm.ModelResponse``; otherwise ``str(response)``.

    Raises:
        Any exception raised while building the router or running the
        completion is printed and then propagated unchanged.
    """
    try:
        llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list"))
        response = llm_router.completion(*args, **kwargs)
        if isinstance(response, litellm.ModelResponse):
            # Round-trip through JSON so the task returns a plain dict
            # (presumably so the result backend can serialize it -- confirm).
            return json.loads(response.model_dump_json())
        return str(response)
    except Exception as e:
        print(e)
        # Bare ``raise`` re-raises the active exception with its traceback intact.
        raise
|
||||||
|
|
15
litellm/proxy/queue/celery_task.py
Normal file
15
litellm/proxy/queue/celery_task.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
import sys, os
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../..")
|
||||||
|
) # Adds the parent directory to the system path - for litellm local dev
|
||||||
|
import litellm
|
||||||
|
from litellm.proxy.queue.celery_app import celery_app
|
||||||
|
|
||||||
|
# Celery task
|
||||||
|
# NOTE(review): litellm/proxy/queue/celery_app.py already registers a task under
# the name 'process_job'. Two registrations with the same name on the same app
# look unintended -- confirm which module is meant to own the task.
@celery_app.task(name='process_job')
def process_job(*args, **kwargs):
    """Celery task: run one completion request through a freshly built router.

    ``kwargs`` must contain ``llm_model_list``. It is removed and used as the
    ``model_list`` of a new ``litellm.Router``. All remaining positional and
    keyword arguments are forwarded to ``Router.completion``.

    Returns:
        A plain ``dict`` parsed from the response's JSON dump when the router
        returns a ``litellm.ModelResponse``; otherwise ``str(response)``. This
        matches the return shape of the same-named task in celery_app.py.

    Raises:
        KeyError: if ``llm_model_list`` is missing from ``kwargs``.
    """
    import json  # local import: this module has no file-level ``import json``

    llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list"))
    response = llm_router.completion(*args, **kwargs)
    if isinstance(response, litellm.ModelResponse):
        # Return a plain dict: a raw response object is not encodable by a
        # JSON result serializer.
        return json.loads(response.model_dump_json())
    return str(response)
|
9
litellm/proxy/queue/celery_worker.py
Normal file
9
litellm/proxy/queue/celery_worker.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
import os
|
||||||
|
from multiprocessing import Process
|
||||||
|
|
||||||
|
def run_worker():
    """Start a Celery worker for the proxy queue app and wait for it to exit.

    Runs the ``celery`` command line through ``os.system`` with
    ``--concurrency=10`` and ``--loglevel=info``. The command's exit status is
    discarded.
    """
    # ``-A`` is a global option and must come before the ``worker`` sub-command.
    # It points at the module that defines ``celery_app``
    # (litellm/proxy/queue/celery_app.py).
    os.system("celery -A litellm.proxy.queue.celery_app worker --concurrency=10 --loglevel=info")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: run the worker command in a separate child process.
    Process(target=run_worker).start()
|
|
@ -19,3 +19,8 @@ def start_rq_worker():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error setting up worker: {e}")
|
print(f"Error setting up worker: {e}")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
|
with Connection(redis_conn):
|
||||||
|
worker.work()
|
||||||
|
|
||||||
|
start_rq_worker()
|
Loading…
Add table
Add a link
Reference in a new issue