refactor(proxy_server.py): using celery workers instead of rq for concurrency

Krrish Dholakia 2023-11-21 16:31:20 -08:00
parent 8c98a2c899
commit b16646e584
6 changed files with 84 additions and 32 deletions

litellm/proxy/proxy_server.py

@@ -148,13 +148,6 @@ def print_verbose(print_statement):
     if user_debug:
         print(print_statement)
-def litellm_queue_completion(*args, **kwargs):
-    call_type = kwargs.pop("call_type")
-    llm_router: litellm.Router = kwargs.pop("llm_router")
-    return llm_router.completion(**kwargs)
 def usage_telemetry(
     feature: str,
 ): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
@@ -228,14 +221,11 @@ def rq_setup(use_queue: bool):
     global request_queue, redis_connection, redis_job
     print(f"value of use_queue: {use_queue}")
     if use_queue:
-        from redis import Redis
-        from rq import Queue
-        from rq.job import Job
-        redis_job = Job
-        # start_rq_worker_in_background()
-        redis_connection = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
-        request_queue = Queue(connection=redis_connection)
+        from litellm.proxy.queue.celery_app import celery_app, process_job
+        from celery.result import AsyncResult
+        request_queue = process_job
+        redis_job = AsyncResult
+        redis_connection = celery_app
 def run_ollama_serve():
     command = ['ollama', 'serve']
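
The celery path imports `celery_app` and `process_job` from `litellm.proxy.queue.celery_app`, one of the other files touched by this commit but not shown in this excerpt. As orientation only, here is a minimal sketch of what such a module could look like, assuming Redis doubles as broker and result backend and that the worker rebuilds a Router from a serialized model list; the names, config, and task body are assumptions, not the commit's actual file.

# Sketch only, not the celery_app.py this commit adds. Assumes Redis doubles as
# broker and result backend, mirroring the REDIS_HOST / REDIS_PORT /
# REDIS_PASSWORD env vars the old rq_setup read; the task body is hypothetical.
import os

from celery import Celery

import litellm

redis_url = (
    f"redis://:{os.getenv('REDIS_PASSWORD', '')}@"
    f"{os.getenv('REDIS_HOST', 'localhost')}:{os.getenv('REDIS_PORT', '6379')}"
)

celery_app = Celery("litellm_proxy_queue", broker=redis_url, backend=redis_url)

@celery_app.task(name="process_job")
def process_job(**kwargs):
    # Rebuild a Router from the model list /queue/request attaches to the payload,
    # then run the completion synchronously inside the worker.
    llm_model_list = kwargs.pop("llm_model_list", [])
    llm_router = litellm.Router(model_list=llm_model_list)
    response = llm_router.completion(**kwargs)
    # The real task presumably serializes this response before handing it to the
    # result backend; the raw object is returned here for brevity.
    return response

A worker for that app would then be started separately, for example with something like `celery -A litellm.proxy.queue.celery_app worker`; how this commit actually launches its workers is not visible in this diff.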
@@ -597,7 +587,7 @@ async def generate_key_fn(request: Request):
 @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
 async def async_chat_completions(request: Request):
-    global request_queue, llm_router
+    global request_queue, llm_model_list
     body = await request.body()
     body_str = body.decode()
     try:
@@ -609,9 +599,9 @@ async def async_chat_completions(request: Request):
         or user_model # model name passed via cli args
         or data["model"] # default passed in http request
     )
-    data["call_type"] = "chat_completion"
-    data["llm_router"] = llm_router # this is dynamic - we should load the llm_router from the user_api_key_auth
-    job = request_queue.enqueue(litellm.litellm_queue_completion, **data)
+    data["llm_model_list"] = llm_model_list
+    print(f"data: {data}")
+    job = request_queue.apply_async(kwargs=data)
     return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
     pass
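
Two things change in the enqueue path: the payload now carries `llm_model_list` instead of a live `llm_router` object (presumably because Celery's default JSON serialization cannot ship a Router across the broker the way rq's pickled jobs could), and `request_queue.enqueue(litellm.litellm_queue_completion, **data)` becomes `request_queue.apply_async(kwargs=data)`, where `request_queue` is now the `process_job` task. A hedged usage sketch of that enqueue call, with made-up request data:

# Hedged sketch of the new enqueue call, with made-up request data.
from litellm.proxy.queue.celery_app import process_job

data = {
    "model": "gpt-3.5-turbo",                            # illustrative model name
    "messages": [{"role": "user", "content": "hello"}],
    "llm_model_list": [],                                 # the endpoint injects the proxy's model list here
}

# apply_async(kwargs=...) only hands the job to the broker and returns an
# AsyncResult right away; its .id is what /queue/response/{task_id} polls later.
job = process_job.apply_async(kwargs=data)
print({"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"})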
@@ -619,16 +609,11 @@ async def async_chat_completions(request: Request):
 async def async_chat_completions(request: Request, task_id: str):
     global redis_connection, redis_job
     try:
-        job = redis_job.fetch(id=task_id, connection=redis_connection)
-        result = job.result
-        status = job.get_status()
-        print(f"job status: {status}; job result: {result}")
-        if status == "failed":
-            print(f"job: {job.exc_info}")
-        status = "queued"
-        if result is not None:
-            status = "finished"
-        return {"status": status, "result": result}
+        job = redis_job(task_id, app=redis_connection)
+        if job.ready():
+            return job.result
+        else:
+            return {'status': 'processing'}
     except Exception as e:
         return {"status": "finished", "result": str(e)}