refactor(openai.py): support aiohttp streaming

This commit is contained in:
Krrish Dholakia 2023-11-09 16:15:21 -08:00
parent bba62b56d3
commit c053782d96
5 changed files with 108 additions and 42 deletions

View file

@@ -137,9 +137,13 @@ async def acompletion(model: str, messages: List = [], *args, **kwargs):
_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
if custom_llm_provider == "openai" or custom_llm_provider == "azure": # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
response = await completion(*args, **kwargs)
if (custom_llm_provider == "openai" or custom_llm_provider == "azure"): # currently implemented aiohttp calls for just azure and openai, soon all.
if kwargs.get("stream", False):
response = completion(*args, **kwargs)
else:
# Await normally
response = await completion(*args, **kwargs)
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
@@ -147,6 +151,7 @@ async def acompletion(model: str, messages: List = [], *args, **kwargs):
# do not change this
# for stream = True, always return an async generator
# See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
# return response
return(
line
async for line in response
@@ -515,7 +520,7 @@ def completion(
)
raise e
if "stream" in optional_params and optional_params["stream"] == True:
if optional_params.get("stream", False) and acompletion is False:
response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
return response
## LOGGING