forked from phoenix/litellm-mirror
feat(proxy_server.py): EXPERIMENTAL: adding queuing endpoints to openai proxy server
This commit is contained in:
parent
5dabcc21c9
commit
c6a4744947
3 changed files with 81 additions and 1 deletions
|
@ -393,4 +393,5 @@ from .exceptions import (
|
||||||
)
|
)
|
||||||
from .budget_manager import BudgetManager
|
from .budget_manager import BudgetManager
|
||||||
from .proxy.proxy_cli import run_server
|
from .proxy.proxy_cli import run_server
|
||||||
|
from .proxy.queue.rq import start_rq_worker_in_background
|
||||||
from .router import Router
|
from .router import Router
|
||||||
|
|
|
@ -96,6 +96,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from litellm import start_rq_worker_in_background
|
||||||
|
|
||||||
app = FastAPI(docs_url="/", title="LiteLLM API")
|
app = FastAPI(docs_url="/", title="LiteLLM API")
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -135,6 +136,10 @@ log_file = "api_log.json"
|
||||||
worker_config = None
|
worker_config = None
|
||||||
master_key = None
|
master_key = None
|
||||||
prisma_client = None
|
prisma_client = None
|
||||||
|
### REDIS QUEUE ###
|
||||||
|
redis_job = None
|
||||||
|
redis_connection = None
|
||||||
|
request_queue = None # Redis Queue for handling requests
|
||||||
#### HELPER FUNCTIONS ####
|
#### HELPER FUNCTIONS ####
|
||||||
def print_verbose(print_statement):
|
def print_verbose(print_statement):
|
||||||
global user_debug
|
global user_debug
|
||||||
|
@ -199,6 +204,19 @@ def prisma_setup(database_url: Optional[str]):
|
||||||
from prisma import Client
|
from prisma import Client
|
||||||
prisma_client = Client()
|
prisma_client = Client()
|
||||||
|
|
||||||
|
def rq_setup(use_queue: bool):
|
||||||
|
global request_queue, redis_connection, redis_job
|
||||||
|
print(f"value of use_queue: {use_queue}")
|
||||||
|
if use_queue:
|
||||||
|
from redis import Redis
|
||||||
|
from rq import Queue
|
||||||
|
from rq.job import Job
|
||||||
|
|
||||||
|
redis_job = Job
|
||||||
|
start_rq_worker_in_background()
|
||||||
|
redis_connection = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
|
||||||
|
request_queue = Queue(connection=redis_connection)
|
||||||
|
|
||||||
def run_ollama_serve():
|
def run_ollama_serve():
|
||||||
command = ['ollama', 'serve']
|
command = ['ollama', 'serve']
|
||||||
|
|
||||||
|
@ -234,7 +252,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
### CONNECT TO DATABASE ###
|
### CONNECT TO DATABASE ###
|
||||||
database_url = general_settings.get("database_url", None)
|
database_url = general_settings.get("database_url", None)
|
||||||
prisma_setup(database_url=database_url)
|
prisma_setup(database_url=database_url)
|
||||||
|
### START REDIS QUEUE ###
|
||||||
|
use_queue = general_settings.get("use_queue", False)
|
||||||
|
rq_setup(use_queue=use_queue)
|
||||||
|
|
||||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||||
litellm_settings = config.get('litellm_settings', None)
|
litellm_settings = config.get('litellm_settings', None)
|
||||||
|
@ -551,6 +571,33 @@ async def generate_key_fn(request: Request):
|
||||||
detail={"error": "models param must be a list"},
|
detail={"error": "models param must be a list"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@router.post("/queue/chat/completions", dependencies=[Depends(user_api_key_auth)])
|
||||||
|
async def async_chat_completions(request: Request):
|
||||||
|
global request_queue
|
||||||
|
body = await request.body()
|
||||||
|
body_str = body.decode()
|
||||||
|
try:
|
||||||
|
data = ast.literal_eval(body_str)
|
||||||
|
except:
|
||||||
|
data = json.loads(body_str)
|
||||||
|
data["model"] = (
|
||||||
|
server_settings.get("completion_model", None) # server default
|
||||||
|
or user_model # model name passed via cli args
|
||||||
|
or data["model"] # default passed in http request
|
||||||
|
)
|
||||||
|
data["call_type"] = "chat_completion"
|
||||||
|
job = request_queue.enqueue(litellm_completion, **data)
|
||||||
|
return {"id": job.id, "url": f"/queue/chat/completions/{job.id}", "eta": 5, "status": "queued"}
|
||||||
|
pass
|
||||||
|
|
||||||
|
@router.get("/queue/response/{task_id}", dependencies=[Depends(user_api_key_auth)])
|
||||||
|
async def async_chat_completions(request: Request, task_id: str):
|
||||||
|
global redis_connection, redis_job
|
||||||
|
job = redis_job.fetch(id=task_id, connection=redis_connection)
|
||||||
|
print(f"job status: {job.get_status()}")
|
||||||
|
result = job.result
|
||||||
|
return {"status": job.get_status(), "result": result}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
|
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
|
||||||
async def retrieve_server_log(request: Request):
|
async def retrieve_server_log(request: Request):
|
||||||
|
|
32
litellm/proxy/queue/rq.py
Normal file
32
litellm/proxy/queue/rq.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import multiprocessing
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
def run_rq_worker(redis_url):
|
||||||
|
command = ["rq", "worker", "--url", redis_url]
|
||||||
|
subprocess.run(command)
|
||||||
|
|
||||||
|
def start_rq_worker_in_background():
|
||||||
|
# Set OBJC_DISABLE_INITIALIZE_FORK_SAFETY to YES
|
||||||
|
os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "YES"
|
||||||
|
|
||||||
|
# Check if required environment variables are set
|
||||||
|
required_vars = ["REDIS_USERNAME", "REDIS_PASSWORD", "REDIS_HOST", "REDIS_PORT"]
|
||||||
|
missing_vars = [var for var in required_vars if var not in os.environ]
|
||||||
|
|
||||||
|
if missing_vars:
|
||||||
|
print(f"Error: Redis environment variables not set. Please set {', '.join(missing_vars)}.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Construct Redis URL
|
||||||
|
REDIS_URL = f"redis://{os.environ['REDIS_USERNAME']}:{os.environ['REDIS_PASSWORD']}@{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
|
||||||
|
|
||||||
|
# Run rq worker in a separate process
|
||||||
|
worker_process = multiprocessing.Process(target=run_rq_worker, args=(REDIS_URL,))
|
||||||
|
worker_process.start()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
start_rq_worker_in_background()
|
Loading…
Add table
Add a link
Reference in a new issue