From a688df79b136e861be51fb22a4c617aef0cc08cd Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Sat, 25 Nov 2023 12:54:01 -0800
Subject: [PATCH] (feat) proxy: make chat/completions async

---
 litellm/proxy/proxy_server.py | 33 +++++++++++++++++++++++++++++----
 1 file changed, 29 insertions(+), 4 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index cb73687fed..e43534c205 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -563,6 +563,16 @@ def data_generator(response):
         except:
             yield f"data: {json.dumps(chunk)}\n\n"
 
+
+async def async_data_generator(response):
+    print_verbose("inside generator")
+    async for chunk in response:
+        print_verbose(f"returned chunk: {chunk}")
+        try:
+            yield f"data: {json.dumps(chunk.dict())}\n\n"
+        except:
+            yield f"data: {json.dumps(chunk)}\n\n"
+
 def litellm_completion(*args, **kwargs):
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     call_type = kwargs.pop("call_type")
@@ -701,11 +711,26 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             or model # for azure deployments
             or data["model"] # default passed in http request
         )
-        data["call_type"] = "chat_completion"
         data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
-        return litellm_completion(
-            **data
-        )
+
+        global user_temperature, user_request_timeout, user_max_tokens, user_api_base
+        # override with user settings, these are params passed via cli
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_request_timeout:
+            data["request_timeout"] = user_request_timeout
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+        if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            response = await llm_router.acompletion(**data)
+        else:
+            response = await litellm.acompletion(**data)
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return StreamingResponse(async_data_generator(response), media_type='text/event-stream')
+        return response
     except Exception as e:
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
         error_traceback = traceback.format_exc()
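
With this change, streamed responses are served as server-sent events produced by async_data_generator. For illustration only (not part of the patch), a minimal client sketch that consumes those events; the proxy address, port, and model name are assumptions and should match however `litellm --model ...` was started:

# hypothetical client example -- adjust URL/port/model to your proxy setup
import json
import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",  # assumed local proxy address
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    # each streamed chunk arrives as an SSE line: b"data: {...}"
    if line.startswith(b"data: "):
        chunk = json.loads(line[len(b"data: "):])
        print(chunk)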