From c053782d961f9f088000e24f1b5c72b329bf44f6 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 9 Nov 2023 16:15:21 -0800
Subject: [PATCH] refactor(openai.py): support aiohttp streaming

---
 litellm/llms/openai.py                     | 39 ++++++++++++---
 litellm/main.py                            | 13 +++--
 litellm/tests/test_async_fn.py             | 40 +++++++++++++--
 litellm/tests/test_stream_chunk_builder.py |  1 -
 litellm/utils.py                           | 57 ++++++++++++----------
 5 files changed, 108 insertions(+), 42 deletions(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 52801d2f0..06b063273 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -1,11 +1,10 @@
 from typing import Optional, Union
 import types, requests
 from .base import BaseLLM
-from litellm.utils import ModelResponse, Choices, Message
+from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper
 from typing import Callable, Optional
 import aiohttp
-
 class OpenAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -219,7 +218,12 @@ class OpenAIChatCompletion(BaseLLM):
             )
 
         try:
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if acompletion is True:
+                if optional_params.get("stream", False):
+                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                else:
+                    return self.acompletion(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+            elif "stream" in optional_params and optional_params["stream"] == True:
                 response = self._client_session.post(
                     url=api_base,
                     json=data,
@@ -231,8 +235,7 @@ class OpenAIChatCompletion(BaseLLM):
                 ## RESPONSE OBJECT
                 return response.iter_lines()
-            elif acompletion is True:
-                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
+
             else:
                 response = self._client_session.post(
                     url=api_base,
@@ -273,7 +276,12 @@ class OpenAIChatCompletion(BaseLLM):
             import traceback
             raise OpenAIError(status_code=500, message=traceback.format_exc())
-    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
+    async def acompletion(self,
+                          logging_obj,
+                          api_base: str,
+                          data: dict, headers: dict,
+                          model_response: ModelResponse,
+                          model: str):
         async with aiohttp.ClientSession() as session:
             async with session.post(api_base, json=data, headers=headers) as response:
                 response_json = await response.json()
@@ -284,6 +292,25 @@ class OpenAIChatCompletion(BaseLLM):
                 ## RESPONSE OBJECT
                 return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+    async def async_streaming(self,
+                          logging_obj,
+                          api_base: str,
+                          data: dict, headers: dict,
+                          model_response: ModelResponse,
+                          model: str):
+        async with aiohttp.ClientSession() as session:
+            async with session.post(api_base, json=data, headers=headers) as response:
+                # Check if the request was successful (status code 200)
+                if response.status != 200:
+                    raise OpenAIError(status_code=response.status, message=await response.text())
+
+                # Handle the streamed response
+                # async for line in response.content:
+                #     print(line)
+                streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
+                async for transformed_chunk in streamwrapper:
+                    yield transformed_chunk
+
     def embedding(self,
                 model: str,
                 input: list,
diff --git a/litellm/main.py b/litellm/main.py
index 4f2ab291e..d00ee8103 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -137,9 +137,13 @@ async def acompletion(model: str, messages: List = [], *args, **kwargs):
 
     _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
 
-    if custom_llm_provider == "openai" or custom_llm_provider == "azure": # currently implemented aiohttp calls for just azure and openai, soon all.
-        # Await normally
-        response = await completion(*args, **kwargs)
+
+    if (custom_llm_provider == "openai" or custom_llm_provider == "azure"): # currently implemented aiohttp calls for just azure and openai, soon all.
+        if kwargs.get("stream", False):
+            response = completion(*args, **kwargs)
+        else:
+            # Await normally
+            response = await completion(*args, **kwargs)
     else:
         # Call the synchronous function using run_in_executor
         response = await loop.run_in_executor(None, func_with_context)
@@ -147,6 +151,7 @@ async def acompletion(model: str, messages: List = [], *args, **kwargs):
         # do not change this
         # for stream = True, always return an async generator
         # See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
+        # return response
         return(
             line
             async for line in response
@@ -515,7 +520,7 @@ def completion(
             )
             raise e
 
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if optional_params.get("stream", False) and acompletion is False:
             response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
             return response
         ## LOGGING
diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py
index c86bf89a9..4d33dda8c 100644
--- a/litellm/tests/test_async_fn.py
+++ b/litellm/tests/test_async_fn.py
@@ -37,7 +37,7 @@ def test_async_response():
 
     response = asyncio.run(test_get_response())
     # print(response)
-test_async_response()
+# test_async_response()
 
 def test_get_response_streaming():
     import asyncio
@@ -45,8 +45,6 @@ def test_get_response_streaming():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            import litellm
-            litellm.set_verbose = True
             response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
             print(type(response))
 
@@ -56,11 +54,11 @@ def test_get_response_streaming():
             print(is_async_generator)
 
             output = ""
+            i = 0
             async for chunk in response:
                 token = chunk["choices"][0]["delta"].get("content", "")
                 output += token
-            print(output)
-
+            print(f"output: {output}")
             assert output is not None, "output cannot be None."
             assert isinstance(output, str), "output needs to be of type str"
             assert len(output) > 0, "Length of output needs to be greater than 0."
@@ -71,3 +69,35 @@ def test_get_response_streaming():
 
     asyncio.run(test_async_call())
 
+# test_get_response_streaming()
+
+def test_get_response_non_openai_streaming():
+    import asyncio
+    async def test_async_call():
+        user_message = "Hello, how are you?"
+        messages = [{"content": user_message, "role": "user"}]
+        try:
+            response = await acompletion(model="command-nightly", messages=messages, stream=True)
+            print(type(response))
+
+            import inspect
+
+            is_async_generator = inspect.isasyncgen(response)
+            print(is_async_generator)
+
+            output = ""
+            i = 0
+            async for chunk in response:
+                token = chunk["choices"][0]["delta"].get("content", "")
+                output += token
+            print(f"output: {output}")
+            assert output is not None, "output cannot be None."
+            assert isinstance(output, str), "output needs to be of type str"
+            assert len(output) > 0, "Length of output needs to be greater than 0."
+
+        except Exception as e:
+            pytest.fail(f"An exception occurred: {e}")
+        return response
+    asyncio.run(test_async_call())
+
+test_get_response_non_openai_streaming()
diff --git a/litellm/tests/test_stream_chunk_builder.py b/litellm/tests/test_stream_chunk_builder.py
index a96900333..8a3586826 100644
--- a/litellm/tests/test_stream_chunk_builder.py
+++ b/litellm/tests/test_stream_chunk_builder.py
@@ -60,5 +60,4 @@ def test_stream_chunk_builder():
         print(role, content, finish_reason)
     except Exception as e:
         raise Exception("stream_chunk_builder failed to rebuild response", e)
-test_stream_chunk_builder()
 
diff --git a/litellm/utils.py b/litellm/utils.py
index 16ae291da..d0130475e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3998,14 +3998,15 @@ class CustomStreamWrapper:
             text = ""
             is_finished = False
             finish_reason = None
-            if str_line == "data: [DONE]":
+            if "data: [DONE]" in str_line:
                 # anyscale returns a [DONE] special char for streaming, this cannot be json loaded. This is the end of stream
                 text = ""
                 is_finished = True
                 finish_reason = "stop"
                 return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
-            elif str_line.startswith("data:"):
-                data_json = json.loads(str_line[5:])
+            elif str_line.startswith("data:") and len(str_line[5:]) > 0:
+                str_line = str_line[5:]
+                data_json = json.loads(str_line)
                 print_verbose(f"delta content: {data_json['choices'][0]['delta']}")
                 text = data_json["choices"][0]["delta"].get("content", "")
                 if data_json["choices"][0].get("finish_reason", None):
@@ -4104,72 +4105,61 @@ class CustomStreamWrapper:
                     raise Exception(chunk["error"])
                 return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
             return ""
-
-    ## needs to handle the empty string case (even starting chunk can be an empty string)
-    def __next__(self):
+
+    def chunk_creator(self, chunk):
         model_response = ModelResponse(stream=True, model=self.model)
         try:
             while True: # loop until a non-empty string is found
                 # return this for all models
                 completion_obj = {"content": ""}
                 if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_anthropic_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.model == "replicate" or self.custom_llm_provider == "replicate":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_replicate_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif (
                     self.custom_llm_provider and self.custom_llm_provider == "together_ai"):
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_together_ai_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_huggingface_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
-                    chunk = next(self.completion_stream)
                     completion_obj["content"] = self.handle_baseten_chunk(chunk)
                 elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_ai21_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider and self.custom_llm_provider == "azure":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_azure_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_maritalk_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
-                    chunk = next(self.completion_stream)
                     completion_obj["content"] = chunk[0].outputs[0].text
                 elif self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha": #aleph alpha doesn't provide streaming
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_aleph_alpha_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.model in litellm.nlp_cloud_models or self.custom_llm_provider == "nlp_cloud":
                     try:
-                        chunk = next(self.completion_stream)
+
                         response_obj = self.handle_nlp_cloud_chunk(chunk)
                         completion_obj["content"] = response_obj["text"]
                         if response_obj["is_finished"]:
@@ -4184,7 +4174,7 @@ class CustomStreamWrapper:
                             self.sent_last_chunk = True
                 elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai":
                     try:
-                        chunk = next(self.completion_stream)
+
                         completion_obj["content"] = str(chunk)
                     except StopIteration as e:
                         if self.sent_last_chunk:
@@ -4193,13 +4183,11 @@ class CustomStreamWrapper:
                             model_response.choices[0].finish_reason = "stop"
                             self.sent_last_chunk = True
                 elif self.custom_llm_provider == "cohere":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_cohere_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider == "bedrock":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_bedrock_stream(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
@@ -4242,19 +4230,16 @@ class CustomStreamWrapper:
                         self.completion_stream = self.completion_stream[chunk_size:]
                         time.sleep(0.05)
                 elif self.custom_llm_provider == "ollama":
-                    chunk = next(self.completion_stream)
                     if "error" in chunk:
                         exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=chunk["error"])
                     completion_obj = chunk
                 elif self.custom_llm_provider == "openai":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_openai_chat_completion_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     print_verbose(f"completion obj content: {completion_obj['content']}")
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider == "text-completion-openai":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_openai_text_completion_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     print_verbose(f"completion obj content: {completion_obj['content']}")
@@ -4267,7 +4252,7 @@ class CustomStreamWrapper:
                 if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
                     hold, model_response_str = self.check_special_tokens(completion_obj["content"])
                     if hold is False:
-                        completion_obj["content"] = model_response_str
+                        completion_obj["content"] = model_response_str 
                         if self.sent_first_chunk == False:
                             completion_obj["role"] = "assistant"
                             self.sent_first_chunk = True
@@ -4275,11 +4260,15 @@ class CustomStreamWrapper:
                         # LOGGING
                         threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                         return model_response
+                    else:
+                        return
                 elif model_response.choices[0].finish_reason:
                     model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
                     # LOGGING
                     threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                     return model_response
+                else:
+                    return
         except StopIteration:
             raise StopIteration
         except Exception as e:
@@ -4288,11 +4277,27 @@ class CustomStreamWrapper:
             # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
             threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
             return exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)
+
+    ## needs to handle the empty string case (even starting chunk can be an empty string)
+    def __next__(self):
+        chunk = next(self.completion_stream)
+        return self.chunk_creator(chunk=chunk)
 
     async def __anext__(self):
         try:
-            return next(self)
-        except StopIteration:
+            if self.custom_llm_provider == "openai":
+                async for chunk in self.completion_stream.content:
+                    if chunk == "None" or chunk is None:
+                        raise Exception
+                    processed_chunk = self.chunk_creator(chunk=chunk)
+                    if processed_chunk is None:
+                        continue
+                    return processed_chunk
+                raise StopAsyncIteration
+            else: # temporary patch for non-aiohttp async calls
+                return next(self)
+        except Exception as e:
+            # Handle any exceptions that might occur during streaming
             raise StopAsyncIteration
 
 class TextCompletionStreamWrapper: