refactor(openai.py): support aiohttp streaming

Krrish Dholakia 2023-11-09 16:15:21 -08:00
parent bba62b56d3
commit c053782d96
5 changed files with 108 additions and 42 deletions
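For orientation, a minimal sketch of the usage this change enables, mirroring the async streaming test added in this commit (the model name and prompt are simply the values the test uses):

import asyncio
from litellm import acompletion

async def stream_example():
    # With this commit, acompletion(..., stream=True) returns an async generator
    # backed by aiohttp instead of a synchronous iterator.
    response = await acompletion(
        model="gpt-3.5-turbo",
        messages=[{"content": "Hello, how are you?", "role": "user"}],
        stream=True,
    )
    output = ""
    async for chunk in response:  # each chunk is an OpenAI-style delta dict
        output += chunk["choices"][0]["delta"].get("content", "")
    print(output)

asyncio.run(stream_example())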


@@ -1,11 +1,10 @@
 from typing import Optional, Union
 import types, requests
 from .base import BaseLLM
-from litellm.utils import ModelResponse, Choices, Message
+from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper
 from typing import Callable, Optional
 import aiohttp
 class OpenAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -219,7 +218,12 @@ class OpenAIChatCompletion(BaseLLM):
            )
        try:
-           if "stream" in optional_params and optional_params["stream"] == True:
+           if acompletion is True:
+               if optional_params.get("stream", False):
+                   return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+               else:
+                   return self.acompletion(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+           elif "stream" in optional_params and optional_params["stream"] == True:
                response = self._client_session.post(
                    url=api_base,
                    json=data,
@@ -231,8 +235,7 @@ class OpenAIChatCompletion(BaseLLM):
                ## RESPONSE OBJECT
                return response.iter_lines()
-           elif acompletion is True:
-               return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
            else:
                response = self._client_session.post(
                    url=api_base,
@@ -273,7 +276,12 @@ class OpenAIChatCompletion(BaseLLM):
            import traceback
            raise OpenAIError(status_code=500, message=traceback.format_exc())
-   async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
+   async def acompletion(self,
+                         logging_obj,
+                         api_base: str,
+                         data: dict, headers: dict,
+                         model_response: ModelResponse,
+                         model: str):
        async with aiohttp.ClientSession() as session:
            async with session.post(api_base, json=data, headers=headers) as response:
                response_json = await response.json()
@@ -284,6 +292,25 @@ class OpenAIChatCompletion(BaseLLM):
                ## RESPONSE OBJECT
                return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+
+   async def async_streaming(self,
+                             logging_obj,
+                             api_base: str,
+                             data: dict, headers: dict,
+                             model_response: ModelResponse,
+                             model: str):
+       async with aiohttp.ClientSession() as session:
+           async with session.post(api_base, json=data, headers=headers) as response:
+               # Check if the request was successful (status code 200)
+               if response.status != 200:
+                   raise OpenAIError(status_code=response.status, message=await response.text())
+
+               # Handle the streamed response
+               # async for line in response.content:
+               # print(line)
+               streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
+               async for transformed_chunk in streamwrapper:
+                   yield transformed_chunk
+
    def embedding(self,
                  model: str,
                  input: list,
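As an aside (illustration only, not part of this commit): the object handed to CustomStreamWrapper above is the raw aiohttp response, whose .content stream yields SSE-style byte lines, which is why the chunk handler in CustomStreamWrapper below tolerates empty payloads and a "data: [DONE]" marker. A rough sketch of what those lines look like:

import aiohttp, asyncio, json

async def raw_openai_sse(api_base: str, data: dict, headers: dict):
    # api_base/data/headers stand in for the values completion() builds; shown only for illustration.
    async with aiohttp.ClientSession() as session:
        async with session.post(api_base, json=data, headers=headers) as response:
            async for line in response.content:          # aiohttp StreamReader yields byte lines
                decoded = line.decode("utf-8").strip()   # e.g. 'data: {"choices": [...]}', 'data: [DONE]', or ''
                if not decoded or decoded == "data: [DONE]":
                    continue
                if decoded.startswith("data:"):
                    chunk = json.loads(decoded[5:])
                    print(chunk["choices"][0]["delta"].get("content", ""))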


@@ -137,9 +137,13 @@ async def acompletion(model: str, messages: List = [], *args, **kwargs):
    _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
-   if custom_llm_provider == "openai" or custom_llm_provider == "azure": # currently implemented aiohttp calls for just azure and openai, soon all.
-       # Await normally
-       response = await completion(*args, **kwargs)
+   if (custom_llm_provider == "openai" or custom_llm_provider == "azure"): # currently implemented aiohttp calls for just azure and openai, soon all.
+       if kwargs.get("stream", False):
+           response = completion(*args, **kwargs)
+       else:
+           # Await normally
+           response = await completion(*args, **kwargs)
    else:
        # Call the synchronous function using run_in_executor
        response = await loop.run_in_executor(None, func_with_context)
@@ -147,6 +151,7 @@ async def acompletion(model: str, messages: List = [], *args, **kwargs):
        # do not change this
        # for stream = True, always return an async generator
        # See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
+       # return response
        return(
            line
            async for line in response
@@ -515,7 +520,7 @@ def completion(
            )
            raise e
-       if "stream" in optional_params and optional_params["stream"] == True:
+       if optional_params.get("stream", False) and acompletion is False:
            response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
            return response
        ## LOGGING


@@ -37,7 +37,7 @@ def test_async_response():
    response = asyncio.run(test_get_response())
    # print(response)
-test_async_response()
+# test_async_response()
 def test_get_response_streaming():
    import asyncio
@@ -45,8 +45,6 @@ def test_get_response_streaming():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
-           import litellm
-           litellm.set_verbose = True
            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
            print(type(response))
@@ -56,11 +54,11 @@ def test_get_response_streaming():
            print(is_async_generator)
            output = ""
+           i = 0
            async for chunk in response:
                token = chunk["choices"][0]["delta"].get("content", "")
                output += token
-           print(output)
+           print(f"output: {output}")
            assert output is not None, "output cannot be None."
            assert isinstance(output, str), "output needs to be of type str"
            assert len(output) > 0, "Length of output needs to be greater than 0."
@@ -71,3 +69,35 @@ def test_get_response_streaming():
    asyncio.run(test_async_call())
+# test_get_response_streaming()
+
+def test_get_response_non_openai_streaming():
+    import asyncio
+    async def test_async_call():
+        user_message = "Hello, how are you?"
+        messages = [{"content": user_message, "role": "user"}]
+        try:
+            response = await acompletion(model="command-nightly", messages=messages, stream=True)
+            print(type(response))
+            import inspect
+            is_async_generator = inspect.isasyncgen(response)
+            print(is_async_generator)
+            output = ""
+            i = 0
+            async for chunk in response:
+                token = chunk["choices"][0]["delta"].get("content", "")
+                output += token
+            print(f"output: {output}")
+            assert output is not None, "output cannot be None."
+            assert isinstance(output, str), "output needs to be of type str"
+            assert len(output) > 0, "Length of output needs to be greater than 0."
+        except Exception as e:
+            pytest.fail(f"An exception occurred: {e}")
+        return response
+    asyncio.run(test_async_call())
+test_get_response_non_openai_streaming()


@@ -60,5 +60,4 @@ def test_stream_chunk_builder():
            print(role, content, finish_reason)
    except Exception as e:
        raise Exception("stream_chunk_builder failed to rebuild response", e)
-test_stream_chunk_builder()


@@ -3998,14 +3998,15 @@ class CustomStreamWrapper:
            text = ""
            is_finished = False
            finish_reason = None
-           if str_line == "data: [DONE]":
+           if "data: [DONE]" in str_line:
                # anyscale returns a [DONE] special char for streaming, this cannot be json loaded. This is the end of stream
                text = ""
                is_finished = True
                finish_reason = "stop"
                return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
-           elif str_line.startswith("data:"):
-               data_json = json.loads(str_line[5:])
+           elif str_line.startswith("data:") and len(str_line[5:]) > 0:
+               str_line = str_line[5:]
+               data_json = json.loads(str_line)
                print_verbose(f"delta content: {data_json['choices'][0]['delta']}")
                text = data_json["choices"][0]["delta"].get("content", "")
                if data_json["choices"][0].get("finish_reason", None):
@@ -4105,71 +4106,60 @@ class CustomStreamWrapper:
                return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
            return ""
-   ## needs to handle the empty string case (even starting chunk can be an empty string)
-   def __next__(self):
+   def chunk_creator(self, chunk):
        model_response = ModelResponse(stream=True, model=self.model)
        try:
            while True: # loop until a non-empty string is found
                # return this for all models
                completion_obj = {"content": ""}
                if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_anthropic_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.model == "replicate" or self.custom_llm_provider == "replicate":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_replicate_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif (
                    self.custom_llm_provider and self.custom_llm_provider == "together_ai"):
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_together_ai_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_huggingface_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
-                   chunk = next(self.completion_stream)
                    completion_obj["content"] = self.handle_baseten_chunk(chunk)
                elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_ai21_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider and self.custom_llm_provider == "azure":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_azure_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_maritalk_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
-                   chunk = next(self.completion_stream)
                    completion_obj["content"] = chunk[0].outputs[0].text
                elif self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha": #aleph alpha doesn't provide streaming
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_aleph_alpha_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.model in litellm.nlp_cloud_models or self.custom_llm_provider == "nlp_cloud":
                    try:
-                       chunk = next(self.completion_stream)
                        response_obj = self.handle_nlp_cloud_chunk(chunk)
                        completion_obj["content"] = response_obj["text"]
                        if response_obj["is_finished"]:
@@ -4184,7 +4174,7 @@ class CustomStreamWrapper:
                        self.sent_last_chunk = True
                elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai":
                    try:
-                       chunk = next(self.completion_stream)
                        completion_obj["content"] = str(chunk)
                    except StopIteration as e:
                        if self.sent_last_chunk:
@@ -4193,13 +4183,11 @@ class CustomStreamWrapper:
                            model_response.choices[0].finish_reason = "stop"
                            self.sent_last_chunk = True
                elif self.custom_llm_provider == "cohere":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_cohere_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider == "bedrock":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_bedrock_stream(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
@@ -4242,19 +4230,16 @@ class CustomStreamWrapper:
                    self.completion_stream = self.completion_stream[chunk_size:]
                    time.sleep(0.05)
                elif self.custom_llm_provider == "ollama":
-                   chunk = next(self.completion_stream)
                    if "error" in chunk:
                        exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=chunk["error"])
                    completion_obj = chunk
                elif self.custom_llm_provider == "openai":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_openai_chat_completion_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    print_verbose(f"completion obj content: {completion_obj['content']}")
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider == "text-completion-openai":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_openai_text_completion_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    print_verbose(f"completion obj content: {completion_obj['content']}")
@@ -4275,11 +4260,15 @@ class CustomStreamWrapper:
                        # LOGGING
                        threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                        return model_response
+                   else:
+                       return
                elif model_response.choices[0].finish_reason:
                    model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
                    # LOGGING
                    threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                    return model_response
+               else:
+                   return
        except StopIteration:
            raise StopIteration
        except Exception as e:
@@ -4289,10 +4278,26 @@ class CustomStreamWrapper:
            threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
            return exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)
+
+   ## needs to handle the empty string case (even starting chunk can be an empty string)
+   def __next__(self):
+       chunk = next(self.completion_stream)
+       return self.chunk_creator(chunk=chunk)
+
    async def __anext__(self):
        try:
-           return next(self)
-       except StopIteration:
+           if self.custom_llm_provider == "openai":
+               async for chunk in self.completion_stream.content:
+                   if chunk == "None" or chunk is None:
+                       raise Exception
+                   processed_chunk = self.chunk_creator(chunk=chunk)
+                   if processed_chunk is None:
+                       continue
+                   return processed_chunk
+               raise StopAsyncIteration
+           else: # temporary patch for non-aiohttp async calls
+               return next(self)
+       except Exception as e:
+           # Handle any exceptions that might occur during streaming
            raise StopAsyncIteration

 class TextCompletionStreamWrapper: