fix(proxy_server.py): fix linting issues

This commit is contained in:
Krrish Dholakia 2023-11-24 11:38:53 -08:00
parent d3d8b86eaa
commit 2e8d582a34
5 changed files with 44 additions and 32 deletions

View file

@ -277,10 +277,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if isinstance(litellm.success_callback, list):
import utils
print("setting litellm success callback to track cost")
if (utils.track_cost_callback) not in litellm.success_callback:
litellm.success_callback.append(utils.track_cost_callback)
if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
litellm.success_callback.append(utils.track_cost_callback) # type: ignore
else:
litellm.success_callback = utils.track_cost_callback
litellm.success_callback = utils.track_cost_callback # type: ignore
### START REDIS QUEUE ###
use_queue = general_settings.get("use_queue", False)
celery_setup(use_queue=use_queue)
@ -717,32 +717,40 @@ async def test_endpoint(request: Request):
@router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
async def async_queue_request(request: Request):
    """Enqueue a completion request on the celery queue.

    The raw request body is parsed as a Python literal first and as JSON
    second. The model is then resolved in priority order (server default
    from ``general_settings``, model passed via cli args, model named in
    the request), the configured model list is attached, and the payload
    is handed to ``celery_fn.apply_async``.

    Returns:
        A dict with the job ``id``, the polling ``url``
        (``/queue/response/<id>``), an ``eta`` of 5 and ``status`` set to
        ``"queued"``.

    Raises:
        HTTPException: 400 when the queue (``celery_fn``) is not initialized.
    """
    global celery_fn, llm_model_list

    # Guard clause: without a celery task function nothing can be enqueued.
    if celery_fn is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "Queue not initialized"},
        )

    body = await request.body()
    body_str = body.decode()
    try:
        # Python-literal syntax is tried first; JSON-only tokens such as
        # true/false/null fail here and fall through to json.loads below.
        # NOTE(review): the body is client-supplied; literal_eval does not
        # execute code but can still fail hard on pathological input.
        data = ast.literal_eval(body_str)
    except Exception:
        # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit
        # are no longer swallowed by the fallback.
        data = json.loads(body_str)

    data["model"] = (
        general_settings.get("completion_model", None)  # server default
        or user_model  # model name passed via cli args
        or data["model"]  # default passed in http request
    )
    data["llm_model_list"] = llm_model_list
    print(f"data: {data}")

    job = celery_fn.apply_async(kwargs=data)
    return {
        "id": job.id,
        "url": f"/queue/response/{job.id}",
        "eta": 5,
        "status": "queued",
    }
@router.get("/queue/response/{task_id}", dependencies=[Depends(user_api_key_auth)])
async def async_queue_response(request: Request, task_id: str):
    """Report the state of a previously queued job.

    Args:
        request: The incoming request (not read by this handler).
        task_id: Identifier of the job, as returned by the queue endpoint.

    Returns:
        ``{"status": "finished", "result": job.result}`` when the job is
        ready, ``{'status': 'queued'}`` when it is not. Any exception,
        including an uninitialized queue, is reported as
        ``{"status": "finished", "result": <exception text>}``.
    """
    global celery_app_conn, async_result
    try:
        if celery_app_conn is None or async_result is None:
            # Carry a message so the caller does not receive an empty
            # "finished" result when the queue was never set up.
            raise Exception("Queue not initialized")

        job = async_result(task_id, app=celery_app_conn)
        if job.ready():
            return {"status": "finished", "result": job.result}
        return {'status': 'queued'}
    except Exception as e:
        # Errors are surfaced in-band rather than as an HTTP error status.
        return {"status": "finished", "result": str(e)}