refactor(azure.py): enabling async streaming with aiohttp

Krrish Dholakia 2023-11-09 16:41:06 -08:00
parent c053782d96
commit e12bff6d7f
5 changed files with 35 additions and 15 deletions

View file

@@ -1,7 +1,7 @@
 from typing import Optional, Union
 import types, requests
 from .base import BaseLLM
-from litellm.utils import ModelResponse, Choices, Message
+from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper
 from typing import Callable, Optional
 from litellm import OpenAIConfig
 import aiohttp
@@ -143,8 +143,12 @@ class AzureChatCompletion(BaseLLM):
                 "api_base": api_base,
             },
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if acompletion is True:
+            if optional_params.get("stream", False):
+                return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+            else:
+                return self.acompletion(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+        elif "stream" in optional_params and optional_params["stream"] == True:
             response = self._client_session.post(
                 url=api_base,
                 json=data,
@@ -156,8 +160,6 @@ class AzureChatCompletion(BaseLLM):
             ## RESPONSE OBJECT
             return response.iter_lines()
-        elif acompletion is True:
-            return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
         else:
             response = self._client_session.post(
                 url=api_base,
@@ -190,6 +192,22 @@
             ## RESPONSE OBJECT
             return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+
+    async def async_streaming(self,
+                              logging_obj,
+                              api_base: str,
+                              data: dict, headers: dict,
+                              model_response: ModelResponse,
+                              model: str):
+        async with aiohttp.ClientSession() as session:
+            async with session.post(api_base, json=data, headers=headers) as response:
+                # Check if the request was successful (status code 200)
+                if response.status != 200:
+                    raise AzureOpenAIError(status_code=response.status, message=await response.text())
+                # Handle the streamed response
+                stream_wrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure", logging_obj=logging_obj)
+                async for transformed_chunk in stream_wrapper:
+                    yield transformed_chunk
     def embedding(self,
                   model: str,

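With this change, an async call against an Azure deployment with stream=True hands back the async generator produced by AzureChatCompletion.async_streaming instead of a plain response object. A minimal consumption sketch (not part of the commit; the deployment name is the one used in the tests below, and routing through litellm.acompletion is assumed):

import asyncio
import litellm

async def stream_azure_chat():
    # stream=True on the async path returns an async generator of chunks,
    # already wrapped in CustomStreamWrapper by async_streaming()
    response = await litellm.acompletion(
        model="azure/chatgpt-v-2",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        stream=True,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(stream_azure_chat())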
View file

@@ -235,7 +235,6 @@ class OpenAIChatCompletion(BaseLLM):
             ## RESPONSE OBJECT
             return response.iter_lines()
         else:
             response = self._client_session.post(
                 url=api_base,
@@ -304,9 +303,6 @@ class OpenAIChatCompletion(BaseLLM):
                 if response.status != 200:
                     raise OpenAIError(status_code=response.status, message=await response.text())
                 # Handle the streamed response
-                # async for line in response.content:
-                #     print(line)
                 streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai", logging_obj=logging_obj)
                 async for transformed_chunk in streamwrapper:
                     yield transformed_chunk

View file

@@ -442,7 +442,7 @@ def completion(
                 logging_obj=logging,
                 acompletion=acompletion
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if optional_params.get("stream", False) and acompletion is False:
                 response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
                 return response
             ## LOGGING

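The guard in main.py now skips the synchronous CustomStreamWrapper when the request came in through the async path, since async_streaming() already wraps the stream inside the provider class. Roughly, the dispatch behaves like the simplified sketch below (not the actual main.py code; the function name and parameter list are illustrative):

def wrap_streaming_response(response, optional_params, acompletion, model, custom_llm_provider, logging):
    # Only wrap synchronously when the caller did NOT use the async path;
    # async requests already receive an async generator from the provider class.
    if optional_params.get("stream", False) and acompletion is False:
        return CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
    return response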
View file

@@ -45,7 +45,8 @@ def test_get_response_streaming():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages, stream=True)
+            # response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
             print(type(response))
             import inspect
@@ -69,7 +70,7 @@ def test_get_response_streaming():
     asyncio.run(test_async_call())
 
-# test_get_response_streaming()
+test_get_response_streaming()
 
 def test_get_response_non_openai_streaming():
     import asyncio
@@ -100,4 +101,4 @@ def test_get_response_non_openai_streaming():
         return response
     asyncio.run(test_async_call())
 
-test_get_response_non_openai_streaming()
+# test_get_response_non_openai_streaming()

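The test diff only shows fragments; a self-contained version of the async streaming test might look like the sketch below (the isasyncgen assertion is an assumption suggested by the existing print/inspect lines, not taken from the repository):

import asyncio
import inspect
from litellm import acompletion

def test_get_response_streaming():
    async def test_async_call():
        messages = [{"content": "Hello, how are you?", "role": "user"}]
        response = await acompletion(model="azure/chatgpt-v-2", messages=messages, stream=True)
        # the async path should hand back an async generator, not a ModelResponse
        assert inspect.isasyncgen(response)
        async for chunk in response:
            print(chunk)

    asyncio.run(test_async_call())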
View file

@@ -3960,7 +3960,12 @@ class CustomStreamWrapper:
         is_finished = False
         finish_reason = ""
         text = ""
-        if chunk.startswith("data:"):
+        if "data: [DONE]" in chunk:
+            text = ""
+            is_finished = True
+            finish_reason = "stop"
+            return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
+        elif chunk.startswith("data:"):
             data_json = json.loads(chunk[5:]) # chunk.startswith("data:"):
             try:
                 text = data_json["choices"][0]["delta"].get("content", "")
@@ -4285,7 +4290,7 @@ class CustomStreamWrapper:
     async def __anext__(self):
         try:
-            if self.custom_llm_provider == "openai":
+            if self.custom_llm_provider == "openai" or self.custom_llm_provider == "azure":
                 async for chunk in self.completion_stream.content:
                     if chunk == "None" or chunk is None:
                         raise Exception
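For reference, the chunk handling these utils.py hunks implement can be read as a standalone parser: Azure/OpenAI SSE streams emit "data: {json}" lines and terminate with a "data: [DONE]" sentinel. A minimal sketch (the finish_reason extraction beyond what the diff shows is an assumption):

import json

def parse_azure_sse_chunk(chunk: str) -> dict:
    # Termination sentinel sent at the end of the stream
    if "data: [DONE]" in chunk:
        return {"text": "", "is_finished": True, "finish_reason": "stop"}
    elif chunk.startswith("data:"):
        data_json = json.loads(chunk[5:])          # strip the "data:" prefix
        choice = data_json["choices"][0]
        text = choice["delta"].get("content", "")  # token delta, may be empty
        finish_reason = choice.get("finish_reason")
        return {"text": text, "is_finished": finish_reason is not None, "finish_reason": finish_reason or ""}
    return {"text": "", "is_finished": False, "finish_reason": ""}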