refactor(azure.py): enabling async streaming with aiohttp

commit e12bff6d7f
parent c053782d96
Author: Krrish Dholakia
Date:   2023-11-09 16:41:06 -08:00
5 changed files with 35 additions and 15 deletions
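In short: AzureChatCompletion.completion now branches on acompletion up front. When acompletion=True and stream=True it returns the new async_streaming async generator, which POSTs through an aiohttp.ClientSession and yields chunks via CustomStreamWrapper; the old synchronous elif acompletion fallback is dropped. To match, CustomStreamWrapper.__anext__ now treats "azure" like "openai", the chunk parser short-circuits on the "data: [DONE]" sentinel, and main.completion skips re-wrapping the response when acompletion is set, since the async path already wraps it.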

litellm/llms/azure.py

@@ -1,7 +1,7 @@
 from typing import Optional, Union
 import types, requests
 from .base import BaseLLM
-from litellm.utils import ModelResponse, Choices, Message
+from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper
 from typing import Callable, Optional
 from litellm import OpenAIConfig
 import aiohttp
@@ -143,8 +143,12 @@ class AzureChatCompletion(BaseLLM):
                 "api_base": api_base,
             },
         )
-
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if acompletion is True:
+            if optional_params.get("stream", False):
+                return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+            else:
+                return self.acompletion(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+        elif "stream" in optional_params and optional_params["stream"] == True:
             response = self._client_session.post(
                 url=api_base,
                 json=data,
@@ -156,8 +160,6 @@ class AzureChatCompletion(BaseLLM):
             ## RESPONSE OBJECT
             return response.iter_lines()
-        elif acompletion is True:
-            return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
         else:
             response = self._client_session.post(
                 url=api_base,
@@ -190,6 +192,22 @@ class AzureChatCompletion(BaseLLM):
             ## RESPONSE OBJECT
             return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+    async def async_streaming(self,
+                              logging_obj,
+                              api_base: str,
+                              data: dict, headers: dict,
+                              model_response: ModelResponse,
+                              model: str):
+        async with aiohttp.ClientSession() as session:
+            async with session.post(api_base, json=data, headers=headers) as response:
+                # Check if the request was successful (status code 200)
+                if response.status != 200:
+                    raise AzureOpenAIError(status_code=response.status, message=await response.text())
+                # Handle the streamed response
+                stream_wrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure", logging_obj=logging_obj)
+                async for transformed_chunk in stream_wrapper:
+                    yield transformed_chunk
+
     def embedding(self,
         model: str,
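A side note on the shape of this API: because async_streaming contains a yield, calling it produces an async generator immediately, so the synchronous completion method can return self.async_streaming(...) without awaiting anything. A minimal, self-contained sketch of that pattern (hypothetical names, not litellm code):

import asyncio

class SketchLLM:
    # stand-in for AzureChatCompletion: a sync entry point that hands back
    # an async generator when the caller asked for async streaming
    def completion(self, acompletion: bool = False):
        if acompletion:
            return self.async_streaming()  # an async generator, not a coroutine
        return ["sync", "response"]

    async def async_streaming(self):
        for piece in ("chunk-1", "chunk-2", "chunk-3"):
            await asyncio.sleep(0)  # stand-in for reading the aiohttp stream
            yield piece

async def main():
    async for chunk in SketchLLM().completion(acompletion=True):
        print(chunk)

asyncio.run(main())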

litellm/llms/openai.py

@@ -235,7 +235,6 @@ class OpenAIChatCompletion(BaseLLM):
             ## RESPONSE OBJECT
             return response.iter_lines()
-
         else:
             response = self._client_session.post(
                 url=api_base,
@@ -304,9 +303,6 @@ class OpenAIChatCompletion(BaseLLM):
                 if response.status != 200:
                     raise OpenAIError(status_code=response.status, message=await response.text())
-                # Handle the streamed response
-                # async for line in response.content:
-                #     print(line)
                 streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai", logging_obj=logging_obj)
                 async for transformed_chunk in streamwrapper:
                     yield transformed_chunk

litellm/main.py

@@ -442,7 +442,7 @@ def completion(
             logging_obj=logging,
             acompletion=acompletion
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if optional_params.get("stream", False) and acompletion is False:
             response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
             return response
         ## LOGGING
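Why the extra acompletion is False guard: on the async streaming path the value coming back from the provider layer is already an async generator driven by CustomStreamWrapper (see async_streaming above), so wrapping it again here would double-wrap the stream. The guard leaves async results untouched for acompletion to hand back.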

litellm/tests/test_async_fn.py

@@ -45,7 +45,8 @@ def test_get_response_streaming():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages, stream=True)
+            # response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
             print(type(response))
             import inspect
@@ -69,7 +70,7 @@ def test_get_response_streaming():
     asyncio.run(test_async_call())
-# test_get_response_streaming()
+test_get_response_streaming()

 def test_get_response_non_openai_streaming():
     import asyncio
@@ -100,4 +101,4 @@ def test_get_response_non_openai_streaming():
         return response
     asyncio.run(test_async_call())
-test_get_response_non_openai_streaming()
+# test_get_response_non_openai_streaming()
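Net effect on the tests: the streaming test now targets the Azure deployment (azure/chatgpt-v-2) and is enabled at import time, while the non-OpenAI streaming test is switched off.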

litellm/utils.py

@@ -3960,7 +3960,12 @@ class CustomStreamWrapper:
         is_finished = False
         finish_reason = ""
         text = ""
-        if chunk.startswith("data:"):
+        if "data: [DONE]" in chunk:
+            text = ""
+            is_finished = True
+            finish_reason = "stop"
+            return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
+        elif chunk.startswith("data:"):
             data_json = json.loads(chunk[5:]) # chunk.startswith("data:"):
             try:
                 text = data_json["choices"][0]["delta"].get("content", "")
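One ordering detail worth calling out: the [DONE] check has to come before the startswith("data:") branch, because the sentinel line also starts with "data:" but its payload "[DONE]" is not valid JSON and would crash json.loads. A standalone re-statement of the new control flow (the function name below is made up for illustration; in litellm this logic lives inside CustomStreamWrapper):

import json

def parse_openai_format_chunk(chunk: str) -> dict:
    # hypothetical helper mirroring the handler's new control flow
    text, is_finished, finish_reason = "", False, ""
    if "data: [DONE]" in chunk:
        return {"text": "", "is_finished": True, "finish_reason": "stop"}
    elif chunk.startswith("data:"):
        data_json = json.loads(chunk[5:])  # strip the "data:" prefix
        text = data_json["choices"][0]["delta"].get("content", "")
    return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}

print(parse_openai_format_chunk('data: {"choices": [{"delta": {"content": "Hi"}}]}'))
# -> {'text': 'Hi', 'is_finished': False, 'finish_reason': ''}
print(parse_openai_format_chunk("data: [DONE]"))
# -> {'text': '', 'is_finished': True, 'finish_reason': 'stop'}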
@@ -4285,7 +4290,7 @@ class CustomStreamWrapper:
     async def __anext__(self):
         try:
-            if self.custom_llm_provider == "openai":
+            if self.custom_llm_provider == "openai" or self.custom_llm_provider == "azure":
                 async for chunk in self.completion_stream.content:
                     if chunk == "None" or chunk is None:
                         raise Exception
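End to end, the new path is exercised exactly the way the updated test does it. A minimal sketch, assuming a configured Azure deployment named chatgpt-v-2 and the usual Azure credentials in the environment:

import asyncio
from litellm import acompletion

async def main():
    response = await acompletion(
        model="azure/chatgpt-v-2",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        stream=True,
    )
    # with stream=True on the async path, `response` is an async iterable of
    # transformed chunks produced by CustomStreamWrapper
    async for chunk in response:
        print(chunk)

asyncio.run(main())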