refactor(openai.py): support aiohttp streaming

Krrish Dholakia 2023-11-09 16:15:21 -08:00
parent bba62b56d3
commit c053782d96
5 changed files with 108 additions and 42 deletions


@@ -1,11 +1,10 @@
 from typing import Optional, Union
 import types, requests
 from .base import BaseLLM
-from litellm.utils import ModelResponse, Choices, Message
+from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper
 from typing import Callable, Optional
 import aiohttp
 
 class OpenAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -219,7 +218,12 @@ class OpenAIChatCompletion(BaseLLM):
         )
         try:
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if acompletion is True:
+                if optional_params.get("stream", False):
+                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                else:
+                    return self.acompletion(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+            elif "stream" in optional_params and optional_params["stream"] == True:
                 response = self._client_session.post(
                     url=api_base,
                     json=data,
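
Not part of the diff: a minimal sketch of the four-way dispatch this hunk introduces, with string labels standing in for the four code paths (the helper name and labels are illustrative only, not part of the commit).

def route(acompletion: bool, optional_params: dict) -> str:
    # Mirrors the branch order above: async mode is checked first,
    # then the stream flag within each mode.
    if acompletion:
        return "async_streaming" if optional_params.get("stream", False) else "acompletion"
    if optional_params.get("stream", False):
        return "sync_streaming"
    return "sync_completion"

assert route(True, {"stream": True}) == "async_streaming"
assert route(True, {}) == "acompletion"
assert route(False, {"stream": True}) == "sync_streaming"
assert route(False, {}) == "sync_completion"
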
@@ -231,8 +235,7 @@
                 ## RESPONSE OBJECT
                 return response.iter_lines()
-            elif acompletion is True:
-                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
             else:
                 response = self._client_session.post(
                     url=api_base,
@@ -273,7 +276,12 @@
             import traceback
             raise OpenAIError(status_code=500, message=traceback.format_exc())
 
-    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
+    async def acompletion(self,
+                          logging_obj,
+                          api_base: str,
+                          data: dict, headers: dict,
+                          model_response: ModelResponse,
+                          model: str):
         async with aiohttp.ClientSession() as session:
             async with session.post(api_base, json=data, headers=headers) as response:
                 response_json = await response.json()
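
Not part of the diff: a hedged call sketch for the updated non-streaming async path. acompletion is a coroutine, so callers await it; the endpoint, payload, placeholder key, and logging object below are assumptions for illustration, and the class is assumed to be constructible without arguments.

import asyncio
from litellm.utils import ModelResponse

async def main():
    provider = OpenAIChatCompletion()  # assumed importable and no-arg constructible
    result = await provider.acompletion(
        logging_obj=None,  # stand-in; real callers pass litellm's logging object
        api_base="https://api.openai.com/v1/chat/completions",
        data={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
        headers={"Authorization": "Bearer sk-..."},  # placeholder key
        model_response=ModelResponse(),
        model="gpt-3.5-turbo",
    )
    print(result)

asyncio.run(main())
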
@@ -284,6 +292,25 @@
                 ## RESPONSE OBJECT
                 return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
 
+    async def async_streaming(self,
+                              logging_obj,
+                              api_base: str,
+                              data: dict, headers: dict,
+                              model_response: ModelResponse,
+                              model: str):
+        async with aiohttp.ClientSession() as session:
+            async with session.post(api_base, json=data, headers=headers) as response:
+                # Surface HTTP errors before attempting to stream
+                if response.status != 200:
+                    raise OpenAIError(status_code=response.status, message=await response.text())
+
+                # Wrap the raw aiohttp response: CustomStreamWrapper parses each
+                # streamed chunk into litellm's unified response format as it arrives
+                streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai", logging_obj=logging_obj)
+                async for transformed_chunk in streamwrapper:
+                    yield transformed_chunk
+
     def embedding(self,
                   model: str,
                   input: list,
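
Not part of the diff: a sketch of consuming the new streaming path. Because async_streaming contains a yield, calling it returns an async generator, so it is iterated with async for rather than awaited; every name passed into the helper below is assumed to come from the caller.

import asyncio

async def consume(provider, logging_obj, api_base, data, headers, model_response, model):
    # Chunks arrive as CustomStreamWrapper-transformed objects, one per
    # streamed line from the aiohttp response.
    async for chunk in provider.async_streaming(
        logging_obj=logging_obj,
        api_base=api_base,
        data=data,
        headers=headers,
        model_response=model_response,
        model=model,
    ):
        print(chunk)

This is also why the dispatch hunk above returns self.async_streaming(...) without awaiting it: the caller receives the generator directly and drives the iteration itself.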