fix(main.py): support async streaming for text completions endpoint
parent 7df9c8e4d8
commit 1608dd7e0b
7 changed files with 175 additions and 68 deletions
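
This change removes the synchronous litellm_completion helper and has the /completions endpoint await litellm.atext_completion (or the router's atext_completion) directly, so streamed text completions are produced without blocking the event loop. A minimal sketch of the underlying call, assuming a configured provider key and a standard text-completion model name; it is illustrative, not code from this commit:

import asyncio
import litellm

async def main():
    # With stream=True, atext_completion is expected to yield chunks
    # asynchronously; the proxy forwards such chunks as server-sent events.
    response = await litellm.atext_completion(
        model="gpt-3.5-turbo-instruct",  # assumed example model name
        prompt="Say hello",
        stream=True,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())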
@@ -797,37 +797,6 @@ async def async_data_generator(response, user_api_key_dict):
             except:
                 yield f"data: {json.dumps(chunk)}\n\n"
 
-def litellm_completion(*args, **kwargs):
-    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
-    call_type = kwargs.pop("call_type")
-    # override with user settings, these are params passed via cli
-    if user_temperature:
-        kwargs["temperature"] = user_temperature
-    if user_request_timeout:
-        kwargs["request_timeout"] = user_request_timeout
-    if user_max_tokens:
-        kwargs["max_tokens"] = user_max_tokens
-    if user_api_base:
-        kwargs["api_base"] = user_api_base
-    ## ROUTE TO CORRECT ENDPOINT ##
-    router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
-    try:
-        if llm_router is not None and kwargs["model"] in router_model_names: # model in router model list
-            if call_type == "chat_completion":
-                response = llm_router.completion(*args, **kwargs)
-            elif call_type == "text_completion":
-                response = llm_router.text_completion(*args, **kwargs)
-        else:
-            if call_type == "chat_completion":
-                response = litellm.completion(*args, **kwargs)
-            elif call_type == "text_completion":
-                response = litellm.text_completion(*args, **kwargs)
-    except Exception as e:
-        raise e
-    if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
-        return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    return response
-
 def get_litellm_model_info(model: dict = {}):
     model_info = model.get("model_info", {})
     model_to_lookup = model.get("litellm_params", {}).get("model", None)
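
The deleted helper streamed through the synchronous data_generator; the new endpoint path (below) streams through async_data_generator, whose tail is visible in the context lines above. A simplified sketch of that pattern, not the repository's exact implementation:

import json

async def async_data_generator_sketch(response, user_api_key_dict):
    # Iterate the provider stream and emit each chunk as a server-sent event.
    async for chunk in response:
        try:
            # Assumes the chunk exposes a dict()-style serializer; the real
            # code may serialize differently.
            yield f"data: {json.dumps(chunk.dict())}\n\n"
        except Exception:
            yield f"data: {json.dumps(chunk)}\n\n"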
@@ -907,7 +876,8 @@ def model_list():
 @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
-async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
+async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     try:
         body = await request.body()
         body_str = body.decode()
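
The added background_tasks parameter is FastAPI's standard BackgroundTasks dependency, which runs the logging callable after the response has been sent. A minimal, self-contained illustration of that mechanism (not proxy code):

from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

def log_request(payload: dict):
    # Executed after the response has been returned to the client.
    print(f"logged: {payload}")

@app.post("/echo")
async def echo(payload: dict, background_tasks: BackgroundTasks):
    # Mirrors background_tasks.add_task(log_input_output, request, response)
    # in the endpoint below.
    background_tasks.add_task(log_request, payload)
    return payload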
@@ -925,17 +895,44 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
         )
         if user_model:
             data["model"] = user_model
-        data["call_type"] = "text_completion"
         if "metadata" in data:
             data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         else:
             data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
-
-        return litellm_completion(
-            **data
-        )
+        # override with user settings, these are params passed via cli
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_request_timeout:
+            data["request_timeout"] = user_request_timeout
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+
+        ### CALL HOOKS ### - modify incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
+
+        ### ROUTE THE REQUEST ###
+        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+        if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            response = await llm_router.atext_completion(**data)
+        elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.atext_completion(**data, specific_deployment = True)
+        elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
+            response = await llm_router.atext_completion(**data)
+        else: # router is not set
+            response = await litellm.atext_completion(**data)
+
+        print(f"final response: {response}")
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
+
+        background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
+        return response
     except Exception as e:
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
         traceback.print_exc()
         error_traceback = traceback.format_exc()
         error_msg = f"{str(e)}\n\n{error_traceback}"
         try:
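
With the endpoint now returning a StreamingResponse built from async_data_generator, a client can consume text completions as server-sent events. A hypothetical client sketch using httpx; the base URL, port, model name, and API key below are placeholders rather than values from this commit:

import asyncio
import httpx

async def stream_completion():
    payload = {
        "model": "gpt-3.5-turbo-instruct",  # placeholder model name
        "prompt": "Write a haiku about streaming",
        "stream": True,
    }
    headers = {"Authorization": "Bearer sk-placeholder"}  # placeholder key
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            "http://localhost:8000/v1/completions",  # placeholder proxy URL
            json=payload,
            headers=headers,
        ) as resp:
            async for line in resp.aiter_lines():
                # Each SSE event arrives as a line prefixed with "data: ".
                if line.startswith("data: "):
                    print(line[len("data: "):])

asyncio.run(stream_completion())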