refactor(openai.py): support aiohttp streaming

Krrish Dholakia 2023-11-09 16:15:21 -08:00
parent bba62b56d3
commit c053782d96
5 changed files with 108 additions and 42 deletions
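In rough terms, the usage this commit enables looks like the sketch below (adapted from the tests in this commit; the model name and prompt are illustrative). With stream=True, awaiting acompletion now hands back an async generator of OpenAI-style delta chunks rather than a finished response:

    import asyncio
    from litellm import acompletion

    async def main():
        # stream=True: the awaited result is an async generator, not a full response
        response = await acompletion(
            model="gpt-3.5-turbo",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
            stream=True,
        )
        output = ""
        async for chunk in response:
            output += chunk["choices"][0]["delta"].get("content", "")
        print(output)

    asyncio.run(main())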

View file

@@ -1,11 +1,10 @@
from typing import Optional, Union
import types, requests
from .base import BaseLLM
-from litellm.utils import ModelResponse, Choices, Message
+from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper
from typing import Callable, Optional
import aiohttp
class OpenAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@@ -219,7 +218,12 @@ class OpenAIChatCompletion(BaseLLM):
)
try:
-if "stream" in optional_params and optional_params["stream"] == True:
+if acompletion is True:
+    if optional_params.get("stream", False):
+        return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+    else:
+        return self.acompletion(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+elif "stream" in optional_params and optional_params["stream"] == True:
response = self._client_session.post(
url=api_base,
json=data,
@@ -231,8 +235,7 @@ class OpenAIChatCompletion(BaseLLM):
## RESPONSE OBJECT
return response.iter_lines()
-elif acompletion is True:
-    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
else:
response = self._client_session.post(
url=api_base,
@@ -273,7 +276,12 @@ class OpenAIChatCompletion(BaseLLM):
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())
-async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
+async def acompletion(self,
+                      logging_obj,
+                      api_base: str,
+                      data: dict, headers: dict,
+                      model_response: ModelResponse,
+                      model: str):
async with aiohttp.ClientSession() as session:
async with session.post(api_base, json=data, headers=headers) as response:
response_json = await response.json()
@@ -284,6 +292,25 @@ class OpenAIChatCompletion(BaseLLM):
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+async def async_streaming(self,
+                          logging_obj,
+                          api_base: str,
+                          data: dict, headers: dict,
+                          model_response: ModelResponse,
+                          model: str):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(api_base, json=data, headers=headers) as response:
+            # Check if the request was successful (status code 200)
+            if response.status != 200:
+                raise OpenAIError(status_code=response.status, message=await response.text())
+            # Handle the streamed response
+            # async for line in response.content:
+            #     print(line)
+            streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
+            async for transformed_chunk in streamwrapper:
+                yield transformed_chunk
def embedding(self,
model: str,
input: list,

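For reference, the raw aiohttp pattern that the new async_streaming method builds on looks roughly like this (a standalone sketch, not code from the commit; the URL, payload, and headers are placeholders). aiohttp exposes the response body as a StreamReader at response.content, which can be consumed line by line with async for; that is what the CustomStreamWrapper above iterates:

    import aiohttp

    async def stream_raw(api_base: str, data: dict, headers: dict):
        # placeholder request; mirrors the session/post/iterate shape used above
        async with aiohttp.ClientSession() as session:
            async with session.post(api_base, json=data, headers=headers) as response:
                if response.status != 200:
                    raise RuntimeError(await response.text())
                async for line in response.content:  # yields raw bytes, line by line
                    yield line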
View file

@@ -137,9 +137,13 @@ async def acompletion(model: str, messages: List = [], *args, **kwargs):
_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
-if custom_llm_provider == "openai" or custom_llm_provider == "azure": # currently implemented aiohttp calls for just azure and openai, soon all.
-    # Await normally
-    response = await completion(*args, **kwargs)
+if (custom_llm_provider == "openai" or custom_llm_provider == "azure"): # currently implemented aiohttp calls for just azure and openai, soon all.
+    if kwargs.get("stream", False):
+        response = completion(*args, **kwargs)
+    else:
+        # Await normally
+        response = await completion(*args, **kwargs)
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
@@ -147,6 +151,7 @@ async def acompletion(model: str, messages: List = [], *args, **kwargs):
# do not change this
# for stream = True, always return an async generator
# See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
+# return response
return(
line
async for line in response
@@ -515,7 +520,7 @@ def completion(
)
raise e
-if "stream" in optional_params and optional_params["stream"] == True:
+if optional_params.get("stream", False) and acompletion is False:
response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
return response
## LOGGING

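A note on the stream branch above: when stream=True, completion() returns the result of async_streaming, and calling an async generator function returns an async generator object immediately; it is not a coroutine and cannot be awaited. That is why the branch skips await, and why the tests check the result with inspect.isasyncgen. A minimal standalone illustration (not code from the commit):

    import inspect

    async def gen():
        yield 1

    g = gen()                     # no await needed: already an async generator object
    print(inspect.isasyncgen(g))  # True, matching the check used in the tests below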
View file

@@ -37,7 +37,7 @@ def test_async_response():
response = asyncio.run(test_get_response())
# print(response)
-test_async_response()
+# test_async_response()
def test_get_response_streaming():
import asyncio
@@ -45,8 +45,6 @@ def test_get_response_streaming():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
-import litellm
-litellm.set_verbose = True
response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
print(type(response))
@@ -56,11 +54,11 @@ def test_get_response_streaming():
print(is_async_generator)
output = ""
i = 0
async for chunk in response:
token = chunk["choices"][0]["delta"].get("content", "")
output += token
-print(output)
+print(f"output: {output}")
assert output is not None, "output cannot be None."
assert isinstance(output, str), "output needs to be of type str"
assert len(output) > 0, "Length of output needs to be greater than 0."
@@ -71,3 +69,35 @@ def test_get_response_streaming():
asyncio.run(test_async_call())
# test_get_response_streaming()
+def test_get_response_non_openai_streaming():
+    import asyncio
+    async def test_async_call():
+        user_message = "Hello, how are you?"
+        messages = [{"content": user_message, "role": "user"}]
+        try:
+            response = await acompletion(model="command-nightly", messages=messages, stream=True)
+            print(type(response))
+            import inspect
+            is_async_generator = inspect.isasyncgen(response)
+            print(is_async_generator)
+            output = ""
+            i = 0
+            async for chunk in response:
+                token = chunk["choices"][0]["delta"].get("content", "")
+                output += token
+            print(f"output: {output}")
+            assert output is not None, "output cannot be None."
+            assert isinstance(output, str), "output needs to be of type str"
+            assert len(output) > 0, "Length of output needs to be greater than 0."
+        except Exception as e:
+            pytest.fail(f"An exception occurred: {e}")
+        return response
+    asyncio.run(test_async_call())
+test_get_response_non_openai_streaming()

View file

@@ -60,5 +60,4 @@ def test_stream_chunk_builder():
print(role, content, finish_reason)
except Exception as e:
raise Exception("stream_chunk_builder failed to rebuild response", e)
-test_stream_chunk_builder()

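The test above exercises stream_chunk_builder, which reassembles a complete response from streamed chunks. A rough standalone sketch of the idea (illustrative only, not litellm's actual implementation): concatenate each delta's content and keep the role and last finish_reason.

    def rebuild_from_chunks(chunks: list) -> dict:
        # illustrative rebuild: accumulate delta content across streamed chunks
        content, role, finish_reason = "", "assistant", None
        for chunk in chunks:
            choice = chunk["choices"][0]
            delta = choice.get("delta", {})
            role = delta.get("role", role)
            content += delta.get("content", "") or ""
            finish_reason = choice.get("finish_reason") or finish_reason
        return {"role": role, "content": content, "finish_reason": finish_reason}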
View file

@@ -3998,14 +3998,15 @@ class CustomStreamWrapper:
text = ""
is_finished = False
finish_reason = None
-if str_line == "data: [DONE]":
+if "data: [DONE]" in str_line:
# anyscale returns a [DONE] special char for streaming, this cannot be json loaded. This is the end of stream
text = ""
is_finished = True
finish_reason = "stop"
return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
-elif str_line.startswith("data:"):
-    data_json = json.loads(str_line[5:])
+elif str_line.startswith("data:") and len(str_line[5:]) > 0:
+    str_line = str_line[5:]
+    data_json = json.loads(str_line)
print_verbose(f"delta content: {data_json['choices'][0]['delta']}")
text = data_json["choices"][0]["delta"].get("content", "")
if data_json["choices"][0].get("finish_reason", None):
@@ -4104,72 +4105,61 @@ class CustomStreamWrapper:
raise Exception(chunk["error"])
return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
return ""
-## needs to handle the empty string case (even starting chunk can be an empty string)
-def __next__(self):
+def chunk_creator(self, chunk):
model_response = ModelResponse(stream=True, model=self.model)
try:
-while True: # loop until a non-empty string is found
# return this for all models
completion_obj = {"content": ""}
if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
-chunk = next(self.completion_stream)
response_obj = self.handle_anthropic_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.model == "replicate" or self.custom_llm_provider == "replicate":
-chunk = next(self.completion_stream)
response_obj = self.handle_replicate_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif (
self.custom_llm_provider and self.custom_llm_provider == "together_ai"):
-chunk = next(self.completion_stream)
response_obj = self.handle_together_ai_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
-chunk = next(self.completion_stream)
response_obj = self.handle_huggingface_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
-chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_baseten_chunk(chunk)
elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
-chunk = next(self.completion_stream)
response_obj = self.handle_ai21_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "azure":
-chunk = next(self.completion_stream)
response_obj = self.handle_azure_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
-chunk = next(self.completion_stream)
response_obj = self.handle_maritalk_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
-chunk = next(self.completion_stream)
completion_obj["content"] = chunk[0].outputs[0].text
elif self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha": #aleph alpha doesn't provide streaming
-chunk = next(self.completion_stream)
response_obj = self.handle_aleph_alpha_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.model in litellm.nlp_cloud_models or self.custom_llm_provider == "nlp_cloud":
try:
-chunk = next(self.completion_stream)
response_obj = self.handle_nlp_cloud_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
@@ -4184,7 +4174,7 @@ class CustomStreamWrapper:
self.sent_last_chunk = True
elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai":
try:
-chunk = next(self.completion_stream)
completion_obj["content"] = str(chunk)
except StopIteration as e:
if self.sent_last_chunk:
@@ -4193,13 +4183,11 @@ class CustomStreamWrapper:
model_response.choices[0].finish_reason = "stop"
self.sent_last_chunk = True
elif self.custom_llm_provider == "cohere":
-chunk = next(self.completion_stream)
response_obj = self.handle_cohere_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "bedrock":
-chunk = next(self.completion_stream)
response_obj = self.handle_bedrock_stream(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
@@ -4242,19 +4230,16 @@ class CustomStreamWrapper:
self.completion_stream = self.completion_stream[chunk_size:]
time.sleep(0.05)
elif self.custom_llm_provider == "ollama":
-chunk = next(self.completion_stream)
if "error" in chunk:
exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=chunk["error"])
completion_obj = chunk
elif self.custom_llm_provider == "openai":
-chunk = next(self.completion_stream)
response_obj = self.handle_openai_chat_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "text-completion-openai":
-chunk = next(self.completion_stream)
response_obj = self.handle_openai_text_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
@@ -4267,7 +4252,7 @@ class CustomStreamWrapper:
if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
hold, model_response_str = self.check_special_tokens(completion_obj["content"])
if hold is False:
-completion_obj["content"] = model_response_str
+completion_obj["content"] = model_response_str
if self.sent_first_chunk == False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
@@ -4275,11 +4260,15 @@ class CustomStreamWrapper:
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
return model_response
+else:
+    return
elif model_response.choices[0].finish_reason:
model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
return model_response
+else:
+    return
except StopIteration:
raise StopIteration
except Exception as e:
@@ -4288,11 +4277,27 @@ class CustomStreamWrapper:
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
return exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)
+## needs to handle the empty string case (even starting chunk can be an empty string)
+def __next__(self):
+    chunk = next(self.completion_stream)
+    return self.chunk_creator(chunk=chunk)
async def __anext__(self):
try:
-    return next(self)
-except StopIteration:
+    if self.custom_llm_provider == "openai":
+        async for chunk in self.completion_stream.content:
+            if chunk == "None" or chunk is None:
+                raise Exception
+            processed_chunk = self.chunk_creator(chunk=chunk)
+            if processed_chunk is None:
+                continue
+            return processed_chunk
+        raise StopAsyncIteration
+    else: # temporary patch for non-aiohttp async calls
+        return next(self)
+except Exception as e:
+    # Handle any exceptions that might occur during streaming
raise StopAsyncIteration
class TextCompletionStreamWrapper:
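Taken together, the new __anext__ pulls raw lines from the aiohttp body, runs each through chunk_creator, and skips chunks that come back empty instead of ending the stream. A simplified standalone sketch of that skip-None pattern (with a stand-in parser, not the actual chunk_creator):

    class AsyncChunkIterator:
        def __init__(self, content):
            self.content = content          # e.g. an aiohttp StreamReader

        def parse(self, raw: bytes):        # stand-in for chunk_creator
            line = raw.decode().strip()
            return line or None             # empty lines parse to None

        def __aiter__(self):
            return self

        async def __anext__(self):
            async for raw in self.content:
                chunk = self.parse(raw)
                if chunk is None:
                    continue                # keep reading rather than emit a blank
                return chunk
            raise StopAsyncIteration        # body exhausted: end of the stream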