feat(proxy_server.py): EXPERIMENTAL: adding queuing endpoints to openai proxy server

Krrish Dholakia 2023-11-21 12:06:15 -08:00
parent 5dabcc21c9
commit c6a4744947
3 changed files with 81 additions and 1 deletion


@@ -96,6 +96,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer
import json
import logging
from litellm import start_rq_worker_in_background
app = FastAPI(docs_url="/", title="LiteLLM API")
router = APIRouter()
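The new import pulls start_rq_worker_in_background from litellm, presumably so the proxy can spin up an RQ worker next to the FastAPI app. The sketch below is only a guess at what such a helper could look like, assuming it shells out to the standard `rq worker` CLI against the same REDIS_* environment variables used in rq_setup further down; the function name here is hypothetical and not taken from litellm's source.

import os
import subprocess

def start_rq_worker_in_background_sketch():
    # hypothetical helper: launch `rq worker` as a child process so queued
    # jobs get picked up while the proxy keeps serving requests
    redis_url = "redis://:{password}@{host}:{port}".format(
        password=os.getenv("REDIS_PASSWORD", ""),
        host=os.getenv("REDIS_HOST", "localhost"),
        port=os.getenv("REDIS_PORT", "6379"),
    )
    # `rq worker` blocks, so it has to run in its own process
    return subprocess.Popen(["rq", "worker", "--url", redis_url])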
@@ -135,6 +136,10 @@ log_file = "api_log.json"
worker_config = None
master_key = None
prisma_client = None
### REDIS QUEUE ###
redis_job = None
redis_connection = None
request_queue = None # Redis Queue for handling requests
#### HELPER FUNCTIONS ####
def print_verbose(print_statement):
    global user_debug
@@ -199,6 +204,19 @@ def prisma_setup(database_url: Optional[str]):
        from prisma import Client
        prisma_client = Client()

def rq_setup(use_queue: bool):
    global request_queue, redis_connection, redis_job
    print(f"value of use_queue: {use_queue}")
    if use_queue:
        from redis import Redis
        from rq import Queue
        from rq.job import Job
        redis_job = Job
        start_rq_worker_in_background()
        redis_connection = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
        request_queue = Queue(connection=redis_connection)
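rq_setup only runs when use_queue is set and expects REDIS_HOST, REDIS_PORT, and REDIS_PASSWORD in the environment. Below is a minimal smoke test of that same connection path, assuming a local Redis server; the function name is illustrative, not part of the commit.

import os
from redis import Redis
from rq import Queue

def check_redis_queue():
    # mirrors the connection built in rq_setup() above
    conn = Redis(
        host=os.getenv("REDIS_HOST", "localhost"),
        port=int(os.getenv("REDIS_PORT", "6379")),
        password=os.getenv("REDIS_PASSWORD"),
    )
    conn.ping()  # raises redis.exceptions.ConnectionError if Redis is unreachable
    queue = Queue(connection=conn)
    print(f"redis ok, {len(queue)} job(s) waiting in the default queue")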
def run_ollama_serve():
    command = ['ollama', 'serve']
@@ -234,7 +252,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
        ### CONNECT TO DATABASE ###
        database_url = general_settings.get("database_url", None)
        prisma_setup(database_url=database_url)
        ### START REDIS QUEUE ###
        use_queue = general_settings.get("use_queue", False)
        rq_setup(use_queue=use_queue)

    ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
    litellm_settings = config.get('litellm_settings', None)
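Since use_queue is read from the same general_settings block as master_key and database_url, a config that enables queuing would presumably parse to something like the dict below (values are illustrative, not taken from this commit):

# illustrative parsed config.yaml; only the keys referenced above are meaningful
config = {
    "general_settings": {
        "master_key": "sk-1234",   # example value
        "database_url": None,      # leave unset to skip prisma_setup
        "use_queue": True,         # triggers rq_setup() above
    },
    "litellm_settings": {
        "drop_params": True,       # example litellm module setting
    },
}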
@@ -551,6 +571,33 @@ async def generate_key_fn(request: Request):
            detail={"error": "models param must be a list"},
        )
@router.post("/queue/chat/completions", dependencies=[Depends(user_api_key_auth)])
async def async_chat_completions(request: Request):
    global request_queue
    body = await request.body()
    body_str = body.decode()
    try:
        data = ast.literal_eval(body_str)
    except:
        data = json.loads(body_str)
    data["model"] = (
        server_settings.get("completion_model", None)  # server default
        or user_model  # model name passed via cli args
        or data["model"]  # default passed in http request
    )
    data["call_type"] = "chat_completion"
    job = request_queue.enqueue(litellm_completion, **data)
    return {"id": job.id, "url": f"/queue/chat/completions/{job.id}", "eta": 5, "status": "queued"}

@router.get("/queue/response/{task_id}", dependencies=[Depends(user_api_key_auth)])
async def async_chat_completions(request: Request, task_id: str):
    global redis_connection, redis_job
    job = redis_job.fetch(id=task_id, connection=redis_connection)
    print(f"job status: {job.get_status()}")
    result = job.result
    return {"status": job.get_status(), "result": result}
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
async def retrieve_server_log(request: Request):