Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
refactor(proxy_server.py): refactoring background rq worker

parent 7c79d10e9f
commit a1f6b9b531

4 changed files with 52 additions and 43 deletions
@@ -98,7 +98,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer
import json
import logging
from litellm import start_rq_worker_in_background
# from litellm.proxy.queue import start_rq_worker_in_background

app = FastAPI(docs_url="/", title="LiteLLM API")
router = APIRouter()
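This hunk touches the import of start_rq_worker_in_background, the helper that launches an RQ worker alongside the FastAPI app (one variant imports it from litellm, the other references litellm.proxy.queue and is commented out). For readers unfamiliar with that helper, here is a minimal sketch of what such a function could look like; the _run_worker name and the daemon-process approach are illustrative assumptions, not the actual litellm implementation.

```python
# Illustrative sketch only: one possible shape for start_rq_worker_in_background().
# The helper name comes from the diff; everything else here is an assumption.
import multiprocessing
import os

from redis import Redis
from rq import Queue, Worker


def _run_worker():
    redis_connection = Redis(
        host=os.getenv("REDIS_HOST"),
        port=int(os.getenv("REDIS_PORT", "6379")),
        password=os.getenv("REDIS_PASSWORD"),
    )
    # Blocks and consumes jobs from the default queue until the process exits
    Worker([Queue(connection=redis_connection)], connection=redis_connection).work()


def start_rq_worker_in_background():
    # Daemon process, so the worker is torn down together with the proxy server
    process = multiprocessing.Process(target=_run_worker, daemon=True)
    process.start()
    return process
```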
@@ -149,6 +149,12 @@ def print_verbose(print_statement):
        print(print_statement)


def litellm_queue_completion(*args, **kwargs):
    call_type = kwargs.pop("call_type")
    llm_router: litellm.Router = kwargs.pop("llm_router")

    return llm_router.completion(**kwargs)

def usage_telemetry(
    feature: str,
): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
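The litellm_queue_completion wrapper added here is what the queue worker executes: it pops the routing metadata (call_type and llm_router) out of the job's kwargs and forwards the remaining arguments to the router. A self-contained sketch of that call pattern follows; the Router config and model name are made-up example values, and a provider API key (e.g. OPENAI_API_KEY) is assumed to be set in the environment.

```python
# Sketch of the worker-side call pattern; example values only.
import litellm


def litellm_queue_completion(*args, **kwargs):
    # Same shape as the wrapper introduced in this commit
    call_type = kwargs.pop("call_type")
    llm_router: litellm.Router = kwargs.pop("llm_router")
    return llm_router.completion(**kwargs)


router = litellm.Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
    ]
)

response = litellm_queue_completion(
    call_type="chat_completion",
    llm_router=router,
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello from the queue"}],
)
print(response)
```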
@@ -215,7 +221,7 @@ def rq_setup(use_queue: bool):
        from rq.job import Job

        redis_job = Job
        start_rq_worker_in_background()
        # start_rq_worker_in_background()
        redis_connection = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
        request_queue = Queue(connection=redis_connection)
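rq_setup wires up the three RQ pieces the queue endpoints below rely on: a Redis connection, a Queue for enqueueing work, and the Job class for fetching results later. Below is a minimal round trip using the same pieces; the REDIS_* defaults and the say_hello task are placeholders for illustration.

```python
# Minimal RQ round trip mirroring the proxy's setup.
import os

from redis import Redis
from rq import Queue
from rq.job import Job


def say_hello(name):
    return f"hello {name}"


redis_connection = Redis(
    host=os.getenv("REDIS_HOST", "localhost"),
    port=int(os.getenv("REDIS_PORT", "6379")),
    password=os.getenv("REDIS_PASSWORD"),
)
request_queue = Queue(connection=redis_connection)

job = request_queue.enqueue(say_hello, "world")              # returns immediately
fetched = Job.fetch(id=job.id, connection=redis_connection)  # same lookup the proxy does
print(fetched.get_status(), fetched.result)                  # result is None until a worker has run it
```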
@@ -387,10 +393,11 @@ def initialize(
    if experimental:
        pass
    if save:
        save_params_to_config(dynamic_config)
        with open(user_config_path) as f:
            print(f.read())
        print("\033[1;32mDone successfully\033[0m")
        pass
        # save_params_to_config(dynamic_config)
        # with open(user_config_path) as f:
        # print(f.read())
        # print("\033[1;32mDone successfully\033[0m")
    user_telemetry = telemetry
    usage_telemetry(feature="local_proxy_server")
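This hunk toggles the save branch of initialize() between an active save_params_to_config call and a stubbed-out version (pass plus the same lines kept as comments). For context, here is a hypothetical sketch of what a helper like save_params_to_config might do; the YAML format and the config path are assumptions for illustration only and may not match the real helper.

```python
# Hypothetical sketch of a save_params_to_config-style helper; the real one in
# proxy_server.py may use a different file format and location.
import os

import yaml

user_config_path = os.path.join(os.getcwd(), "litellm_config.yaml")  # assumed path


def save_params_to_config(dynamic_config: dict):
    # Persist CLI/runtime overrides so a later run can reload them from disk
    with open(user_config_path, "w") as f:
        yaml.safe_dump(dynamic_config, f)


save_params_to_config({"model": "gpt-3.5-turbo", "temperature": 0.2})
```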
@@ -435,7 +442,7 @@ def litellm_completion(*args, **kwargs):
    if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
        return StreamingResponse(data_generator(response), media_type='text/event-stream')
    return response


@app.on_event("startup")
async def startup_event():
    global prisma_client
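When the client asks for stream=True, litellm_completion hands the response to data_generator and wraps it in a StreamingResponse with the text/event-stream media type. A self-contained illustration of that pattern is below; fake_llm_stream is a stand-in for the real streamed completion chunks, not litellm's data_generator.

```python
# Minimal StreamingResponse / server-sent-events illustration.
import asyncio
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_llm_stream():
    for token in ["Hello", ", ", "world"]:
        # SSE framing, matching media_type='text/event-stream'
        yield f"data: {json.dumps({'token': token})}\n\n"
        await asyncio.sleep(0.05)


@app.get("/stream")
async def stream():
    return StreamingResponse(fake_llm_stream(), media_type="text/event-stream")
```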
@@ -575,7 +582,7 @@ async def generate_key_fn(request: Request):

@router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
async def async_chat_completions(request: Request):
    global request_queue
    global request_queue, llm_router
    body = await request.body()
    body_str = body.decode()
    try:
@@ -588,7 +595,8 @@ async def async_chat_completions(request: Request):
            or data["model"] # default passed in http request
        )
        data["call_type"] = "chat_completion"
        job = request_queue.enqueue(litellm_completion, **data)
        data["llm_router"] = llm_router
        job = request_queue.enqueue(litellm.litellm_queue_completion, **data)
        return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
        pass
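The /queue/request handler resolves the model, tags the payload with call_type (and, for the newly added litellm_queue_completion wrapper, the llm_router), enqueues the job, and returns its id plus a polling URL. Putting the two queue endpoints together from a client's point of view looks roughly like the sketch below; the base URL, the bearer key, the assumption that /queue/response/{task_id} is served via GET, and the status/result fields of the polling response are all assumptions, not the proxy's documented API.

```python
# Client-side sketch; all endpoint details beyond what the diff shows are assumptions.
import time

import requests

BASE_URL = "http://localhost:8000"             # assumed proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}  # assumed master key

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hello from the queue"}],
}

queued = requests.post(f"{BASE_URL}/queue/request", json=payload, headers=HEADERS).json()
print(queued)  # e.g. {"id": ..., "url": "/queue/response/<id>", "eta": 5, "status": "queued"}

while True:
    polled = requests.get(f"{BASE_URL}{queued['url']}", headers=HEADERS).json()
    if polled.get("status") == "finished":
        print(polled)
        break
    time.sleep(queued.get("eta", 5))
```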
@@ -597,8 +605,11 @@ async def async_chat_completions(request: Request, task_id: str):
    global redis_connection, redis_job
    try:
        job = redis_job.fetch(id=task_id, connection=redis_connection)
        print(f"job status: {job.get_status()}")
        result = job.result
        status = job.get_status()
        print(f"job status: {status}; job result: {result}")
        if status == "failed":
            print(f"job: {job.exc_info}")
        status = "queued"
        if result is not None:
            status = "finished"
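The /queue/response handler fetches the RQ job by id, logs its status and result, and then collapses everything into two client-facing states: finished once a result exists, queued otherwise (failures are only logged via job.exc_info). The same logic condensed into a small helper, for readability; the keys of the returned dict are illustrative assumptions rather than the proxy's exact schema.

```python
# Condensed restatement of the status logic above.
from redis import Redis
from rq.job import Job


def queue_job_status(task_id: str, redis_connection: Redis) -> dict:
    job = Job.fetch(id=task_id, connection=redis_connection)
    result = job.result
    status = job.get_status()
    if status == "failed":
        print(f"job: {job.exc_info}")
    # Clients only ever see "queued" or "finished"
    status = "finished" if result is not None else "queued"
    return {"status": status, "result": result}
```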