forked from phoenix/litellm-mirror
(feat) proxy: make chat/completions async
parent 30f47d3169
commit a688df79b1
1 changed file with 29 additions and 4 deletions
@@ -563,6 +563,16 @@ def data_generator(response):
         except:
             yield f"data: {json.dumps(chunk)}\n\n"
 
+async def async_data_generator(response):
+    print_verbose("inside generator")
+    async for chunk in response:
+        print_verbose(f"returned chunk: {chunk}")
+        try:
+            yield f"data: {json.dumps(chunk.dict())}\n\n"
+        except:
+            yield f"data: {json.dumps(chunk)}\n\n"
+
+
 def litellm_completion(*args, **kwargs):
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     call_type = kwargs.pop("call_type")
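The hunk above adds async_data_generator, an async server-sent-events generator: each streamed chunk is serialized as a "data: ..." line followed by a blank line. A minimal standalone sketch of how such a generator behaves, where fake_model_stream and the chunk shape are illustrative assumptions and not part of this commit:

# Minimal sketch: exercising an SSE-style async generator outside the server.
# fake_model_stream and the chunk dicts are assumptions for illustration only.
import asyncio
import json

async def fake_model_stream():
    # stand-in for the chunks an async completion call would yield
    for token in ["Hel", "lo", "!"]:
        yield {"choices": [{"delta": {"content": token}}]}

async def async_data_generator(response):
    # same shape as the generator added in this commit: one SSE line per chunk
    async for chunk in response:
        yield f"data: {json.dumps(chunk)}\n\n"

async def main():
    async for line in async_data_generator(fake_model_stream()):
        print(line, end="")

asyncio.run(main())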
@@ -701,11 +711,26 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             or model # for azure deployments
             or data["model"] # default passed in http request
         )
-        data["call_type"] = "chat_completion"
         data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
-        return litellm_completion(
-            **data
-        )
+        global user_temperature, user_request_timeout, user_max_tokens, user_api_base
+        # override with user settings, these are params passed via cli
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_request_timeout:
+            data["request_timeout"] = user_request_timeout
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+        if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            response = await llm_router.acompletion(**data)
+        else:
+            response = await litellm.acompletion(**data)
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return StreamingResponse(async_data_generator(response), media_type='text/event-stream')
+        return response
+
     except Exception as e:
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
         error_traceback = traceback.format_exc()
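With this hunk the chat_completion route awaits llm_router.acompletion or litellm.acompletion directly and, when streaming is requested, returns a StreamingResponse driven by async_data_generator. A hedged client-side sketch of consuming that stream follows; the URL, port, route, and model name are assumptions about a local proxy setup, not taken from this commit:

# Hedged sketch: stream a chat completion from the proxy and parse the
# "data: ..." SSE lines emitted by async_data_generator.
# The URL, port, and model name below are assumptions, not from this commit.
import json
import httpx

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Say hi"}],
    "stream": True,
}

with httpx.stream("POST", "http://localhost:8000/chat/completions", json=payload, timeout=60.0) as resp:
    for line in resp.iter_lines():
        if line.startswith("data: "):
            print(json.loads(line[len("data: "):]))

Filtering on the "data: " prefix matches the text/event-stream convention the generator emits, with each chunk terminated by a blank line.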