forked from phoenix/litellm-mirror
refactor(proxy_server.py): refactoring background rq worker
This commit is contained in:
parent
7c79d10e9f
commit
a1f6b9b531
4 changed files with 52 additions and 43 deletions
|
@ -393,5 +393,5 @@ from .exceptions import (
|
||||||
)
|
)
|
||||||
from .budget_manager import BudgetManager
|
from .budget_manager import BudgetManager
|
||||||
from .proxy.proxy_cli import run_server
|
from .proxy.proxy_cli import run_server
|
||||||
from .proxy.queue.rq import start_rq_worker_in_background
|
|
||||||
from .router import Router
|
from .router import Router
|
||||||
|
from .proxy.proxy_server import litellm_queue_completion
|
||||||
|
|
|
@ -98,7 +98,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from litellm import start_rq_worker_in_background
|
# from litellm.proxy.queue import start_rq_worker_in_background
|
||||||
|
|
||||||
app = FastAPI(docs_url="/", title="LiteLLM API")
|
app = FastAPI(docs_url="/", title="LiteLLM API")
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -149,6 +149,12 @@ def print_verbose(print_statement):
|
||||||
print(print_statement)
|
print(print_statement)
|
||||||
|
|
||||||
|
|
||||||
|
def litellm_queue_completion(*args, **kwargs):
    """Run a queued completion request through the router carried in ``kwargs``.

    The enqueueing endpoint adds two bookkeeping entries to the request
    payload before handing it to the queue: ``call_type`` and ``llm_router``.
    Both are removed here so that only genuine completion arguments are
    forwarded to the router.

    Args:
        *args: Ignored; accepted so the queue may invoke the job with
            positional arguments without failing.
        **kwargs: Completion parameters plus the ``llm_router`` entry
            (required) and the ``call_type`` entry (optional).

    Returns:
        Whatever ``llm_router.completion(**kwargs)`` returns.

    Raises:
        ValueError: If no router was supplied (missing or ``None``).
    """
    # ``call_type`` is routing metadata, not a completion argument. Tolerate
    # its absence so callers that enqueue without it still work.
    kwargs.pop("call_type", None)

    llm_router = kwargs.pop("llm_router", None)
    if llm_router is None:
        # Fail with an explicit message instead of an opaque
        # KeyError/AttributeError inside the worker.
        raise ValueError(
            "litellm_queue_completion requires an 'llm_router' keyword argument"
        )

    return llm_router.completion(**kwargs)
|
||||||
|
|
||||||
def usage_telemetry(
|
def usage_telemetry(
|
||||||
feature: str,
|
feature: str,
|
||||||
): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
|
): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
|
||||||
|
@ -215,7 +221,7 @@ def rq_setup(use_queue: bool):
|
||||||
from rq.job import Job
|
from rq.job import Job
|
||||||
|
|
||||||
redis_job = Job
|
redis_job = Job
|
||||||
start_rq_worker_in_background()
|
# start_rq_worker_in_background()
|
||||||
redis_connection = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
|
redis_connection = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
|
||||||
request_queue = Queue(connection=redis_connection)
|
request_queue = Queue(connection=redis_connection)
|
||||||
|
|
||||||
|
@ -387,10 +393,11 @@ def initialize(
|
||||||
if experimental:
|
if experimental:
|
||||||
pass
|
pass
|
||||||
if save:
|
if save:
|
||||||
save_params_to_config(dynamic_config)
|
pass
|
||||||
with open(user_config_path) as f:
|
# save_params_to_config(dynamic_config)
|
||||||
print(f.read())
|
# with open(user_config_path) as f:
|
||||||
print("\033[1;32mDone successfully\033[0m")
|
# print(f.read())
|
||||||
|
# print("\033[1;32mDone successfully\033[0m")
|
||||||
user_telemetry = telemetry
|
user_telemetry = telemetry
|
||||||
usage_telemetry(feature="local_proxy_server")
|
usage_telemetry(feature="local_proxy_server")
|
||||||
|
|
||||||
|
@ -575,7 +582,7 @@ async def generate_key_fn(request: Request):
|
||||||
|
|
||||||
@router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
|
@router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
|
||||||
async def async_chat_completions(request: Request):
|
async def async_chat_completions(request: Request):
|
||||||
global request_queue
|
global request_queue, llm_router
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
body_str = body.decode()
|
body_str = body.decode()
|
||||||
try:
|
try:
|
||||||
|
@ -588,7 +595,8 @@ async def async_chat_completions(request: Request):
|
||||||
or data["model"] # default passed in http request
|
or data["model"] # default passed in http request
|
||||||
)
|
)
|
||||||
data["call_type"] = "chat_completion"
|
data["call_type"] = "chat_completion"
|
||||||
job = request_queue.enqueue(litellm_completion, **data)
|
data["llm_router"] = llm_router
|
||||||
|
job = request_queue.enqueue(litellm.litellm_queue_completion, **data)
|
||||||
return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
|
return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -597,8 +605,11 @@ async def async_chat_completions(request: Request, task_id: str):
|
||||||
global redis_connection, redis_job
|
global redis_connection, redis_job
|
||||||
try:
|
try:
|
||||||
job = redis_job.fetch(id=task_id, connection=redis_connection)
|
job = redis_job.fetch(id=task_id, connection=redis_connection)
|
||||||
print(f"job status: {job.get_status()}")
|
|
||||||
result = job.result
|
result = job.result
|
||||||
|
status = job.get_status()
|
||||||
|
print(f"job status: {status}; job result: {result}")
|
||||||
|
if status == "failed":
|
||||||
|
print(f"job: {job.exc_info}")
|
||||||
status = "queued"
|
status = "queued"
|
||||||
if result is not None:
|
if result is not None:
|
||||||
status = "finished"
|
status = "finished"
|
||||||
|
|
|
@ -1,32 +0,0 @@
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import multiprocessing
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
def run_rq_worker(redis_url):
    """Invoke the ``rq worker`` command line against ``redis_url``.

    The call blocks until the spawned command returns.
    """
    subprocess.run(["rq", "worker", "--url", redis_url])
|
|
||||||
|
|
||||||
def start_rq_worker_in_background():
    """Start an ``rq worker`` in a separate process using Redis settings from the environment.

    Reads ``REDIS_USERNAME``, ``REDIS_PASSWORD``, ``REDIS_HOST`` and
    ``REDIS_PORT`` from ``os.environ``, builds a ``redis://`` URL from them
    and starts ``run_rq_worker`` with that URL in a ``multiprocessing.Process``.

    If any of the four variables is missing, an error naming them is printed
    and the interpreter exits with status 1.
    """
    from urllib.parse import quote

    # Set OBJC_DISABLE_INITIALIZE_FORK_SAFETY to YES before the child process
    # is created.
    os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "YES"

    # Check that the required environment variables are set.
    required_vars = ["REDIS_USERNAME", "REDIS_PASSWORD", "REDIS_HOST", "REDIS_PORT"]
    missing_vars = [var for var in required_vars if var not in os.environ]

    if missing_vars:
        print(f"Error: Redis environment variables not set. Please set {', '.join(missing_vars)}.")
        sys.exit(1)

    # Percent-encode the credentials: characters such as "@", ":" or "/" in a
    # username/password would otherwise be parsed as URL structure.
    username = quote(os.environ["REDIS_USERNAME"], safe="")
    password = quote(os.environ["REDIS_PASSWORD"], safe="")
    redis_url = (
        f"redis://{username}:{password}"
        f"@{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
    )

    # Run the rq worker in a separate process.
    worker_process = multiprocessing.Process(target=run_rq_worker, args=(redis_url,))
    worker_process.start()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
start_rq_worker_in_background()
|
|
30
litellm/proxy/queue/rq_worker.py
Normal file
30
litellm/proxy/queue/rq_worker.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
"""Stand-alone RQ worker script (litellm/proxy/queue/rq_worker.py).

Connects to the Redis instance described by the REDIS_HOST, REDIS_PORT and
REDIS_PASSWORD environment variables (loaded via ``dotenv``), builds an RQ
queue and worker on that connection, and starts processing jobs when the
file is executed as a script.
"""
import os
import sys

from dotenv import load_dotenv
from redis import Redis
from rq import Connection, Queue, Worker

load_dotenv()

# Add the repository root to sys.path - for litellm local dev.
# Resolved relative to this file (three directories up) rather than the
# current working directory, so it works wherever the script is launched from.
sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
)

# # Import your local module
# import litellm
# from litellm import litellm_queue_completion

# Set up RQ connection
redis_conn = Redis(
    host=os.getenv("REDIS_HOST"),
    port=os.getenv("REDIS_PORT"),
    password=os.getenv("REDIS_PASSWORD"),
)
print(redis_conn.ping())  # Should print True if connected successfully

# Create a worker and add the queue
try:
    queue = Queue(connection=redis_conn)
    worker = Worker([queue], connection=redis_conn)
except Exception as e:
    print(f"Error setting up worker: {e}")
    # Use sys.exit with a non-zero status: the bare ``exit()`` helper is
    # injected by the ``site`` module and would report success (status 0).
    sys.exit(1)

# Run the worker
if __name__ == "__main__":
    with Connection(redis_conn):
        worker.work()
|
Loading…
Add table
Add a link
Reference in a new issue