From 7ac3a9cb838293890134bf7cd09304e4a2352c86 Mon Sep 17 00:00:00 2001
From: Jaswanth Karani
Date: Fri, 28 Feb 2025 20:15:11 +0530
Subject: [PATCH 1/2] added streaming support for aiohttp_openai

---
 .../aiohttp_openai/chat/transformation.py    |  15 +-
 litellm/llms/aiohttp_openai/common_utils.py  | 169 ++++++++++++++++++
 litellm/llms/custom_httpx/aiohttp_handler.py | 156 +++++++++++++---
 tests/llm_translation/test_aiohttp_openai.py |  38 ++++
 4 files changed, 355 insertions(+), 23 deletions(-)
 create mode 100644 litellm/llms/aiohttp_openai/common_utils.py

diff --git a/litellm/llms/aiohttp_openai/chat/transformation.py b/litellm/llms/aiohttp_openai/chat/transformation.py
index 625704dbea..6b1289e1da 100644
--- a/litellm/llms/aiohttp_openai/chat/transformation.py
+++ b/litellm/llms/aiohttp_openai/chat/transformation.py
@@ -7,13 +7,14 @@ https://github.com/BerriAI/litellm/issues/6592
 
 New config to ensure we introduce this without causing breaking changes for users
 """
-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, AsyncIterator, Union, Iterator, List, Optional
 
 from aiohttp import ClientResponse
 
 from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, ModelResponse
+from ..common_utils import ModelResponseIterator as AiohttpOpenAIResponseIterator
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -77,3 +78,15 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
         model_response.object = _json_response.get("object")
         model_response.system_fingerprint = _json_response.get("system_fingerprint")
         return model_response
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return AiohttpOpenAIResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
\ No newline at end of file
diff --git a/litellm/llms/aiohttp_openai/common_utils.py b/litellm/llms/aiohttp_openai/common_utils.py
new file mode 100644
index 0000000000..76d18c367f
--- /dev/null
+++ b/litellm/llms/aiohttp_openai/common_utils.py
@@ -0,0 +1,169 @@
+import json
+from typing import List, Optional, Union
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+    ModelResponseStream
+)
+
+
+class AioHttpOpenAIError(BaseLLMException):
+    def __init__(self, status_code, message):
+        super().__init__(status_code=status_code, message=message)
+
+
+def validate_environment(
+    headers: dict,
+    model: str,
+    messages: List[AllMessageValues],
+    optional_params: dict,
+    api_key: Optional[str] = None,
+) -> dict:
+    """
+    Return headers to use for aiohttp_openai chat completion request
+    """
+    headers.update(
+        {
+            "Request-Source": "unspecified:litellm",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+    )
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
+    return headers
+
+class ModelResponseIterator:
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
+        self.streaming_response = streaming_response
+        
self.response_iterator = self.streaming_response + self.json_mode = json_mode + + def chunk_parser(self, chunk: dict) -> Union[GenericStreamingChunk, ModelResponseStream]: + try: + # Initialize default values + text = "" + tool_use: Optional[ChatCompletionToolCallChunk] = None + is_finished = False + finish_reason = "" + usage: Optional[ChatCompletionUsageBlock] = None + provider_specific_fields = None + + # Extract the index from the chunk + index = int(chunk.get("choices", [{}])[0].get("index", 0)) + + # Extract the text or delta content from the first choice + delta = chunk.get("choices", [{}])[0].get("delta", {}) + if "content" in delta: + text = delta["content"] + + # Check for finish_reason + finish_reason = chunk.get("choices", [{}])[0].get("finish_reason", "") + + # Determine if the stream has finished + is_finished = finish_reason in ("length", "stop") + + # Create and return the parsed chunk + returned_chunk = GenericStreamingChunk( + text=text, + tool_use=tool_use, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + index=index, + provider_specific_fields=provider_specific_fields, + ) + + return returned_chunk + + except json.JSONDecodeError: + raise ValueError(f"Failed to decode JSON from chunk: {chunk}") + + + # Sync iterator + def __iter__(self): + return self + + def _handle_string_chunk( + self, str_line: str + ) -> Union[GenericStreamingChunk, ModelResponseStream]: + # chunk is a str at this point + if "[DONE]" in str_line: + return GenericStreamingChunk( + text="", + is_finished=True, + finish_reason="stop", + usage=None, + index=0, + tool_use=None, + ) + elif str_line.startswith("data:"): + data_json = json.loads(str_line[5:]) + return self.chunk_parser(chunk=data_json) + else: + return GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) + + def __next__(self): + try: + chunk = self.response_iterator.__next__() + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + str_line = chunk + if isinstance(chunk, bytes): # Handle binary data + str_line = chunk.decode("utf-8") # Convert bytes to string + index = str_line.find("data:") + if index != -1: + str_line = str_line[index:] + # chunk is a str at this point + return self._handle_string_chunk(str_line=str_line) + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") + + # Async iterator + def __aiter__(self): + self.async_response_iterator = self.streaming_response.__aiter__() + return self + + async def __anext__(self): + try: + chunk = await self.async_response_iterator.__anext__() + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + str_line = chunk + if isinstance(chunk, bytes): # Handle binary data + str_line = chunk.decode("utf-8") # Convert bytes to string + index = str_line.find("data:") + if index != -1: + str_line = str_line[index:] + + # chunk is a str at this point + return self._handle_string_chunk(str_line=str_line) + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") + \ No newline at end of file diff --git a/litellm/llms/custom_httpx/aiohttp_handler.py b/litellm/llms/custom_httpx/aiohttp_handler.py index 4a9e07016f..e57ab0e737 100644 
--- a/litellm/llms/custom_httpx/aiohttp_handler.py +++ b/litellm/llms/custom_httpx/aiohttp_handler.py @@ -174,6 +174,7 @@ class BaseLLMAIOHTTPHandler: api_key: Optional[str] = None, client: Optional[ClientSession] = None, ): + data.pop("max_retries", None) #added this as this was extra param which is not needed for openai _response = await self._make_common_async_call( async_client_session=client, provider_config=provider_config, @@ -257,27 +258,50 @@ class BaseLLMAIOHTTPHandler: ) if acompletion is True: - return self.async_completion( - custom_llm_provider=custom_llm_provider, - provider_config=provider_config, - api_base=api_base, - headers=headers, - data=data, - timeout=timeout, - model=model, - model_response=model_response, - logging_obj=logging_obj, - api_key=api_key, - messages=messages, - optional_params=optional_params, - litellm_params=litellm_params, - encoding=encoding, - client=( - client - if client is not None and isinstance(client, ClientSession) - else None - ), - ) + if stream is True: + if fake_stream is not True: + data["stream"] = stream + return self.acompletion_stream_function( + model=model, + messages=messages, + api_base=api_base, + headers=headers, + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + timeout=timeout, + logging_obj=logging_obj, + data=data, + fake_stream=fake_stream, + client=( + client + if client is not None and isinstance(client, ClientSession) + else None + ), + litellm_params=litellm_params, + ) + + else: + return self.async_completion( + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + model=model, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + client=( + client + if client is not None and isinstance(client, ClientSession) + else None + ), + ) if stream is True: if fake_stream is not True: @@ -332,7 +356,95 @@ class BaseLLMAIOHTTPHandler: litellm_params=litellm_params, encoding=encoding, ) + + async def acompletion_stream_function( + self, + model: str, + messages: list, + api_base: str, + custom_llm_provider: str, + headers: dict, + provider_config: BaseConfig, + timeout: Union[float, httpx.Timeout], + logging_obj: LiteLLMLoggingObj, + data: dict, + litellm_params: dict, + fake_stream: bool = False, + client: Optional[ClientSession] = None, + ): + completion_stream, _response_headers = await self.make_async_call( + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + messages=messages, + logging_obj=logging_obj, + timeout=timeout, + fake_stream=fake_stream, + client=client, + litellm_params=litellm_params, + ) + + streamwrapper = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider=custom_llm_provider, + logging_obj=logging_obj, + ) + return streamwrapper + async def make_async_call( + self, + custom_llm_provider: str, + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: dict, + messages: list, + logging_obj: LiteLLMLoggingObj, + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + fake_stream: bool = False, + client: Optional[Union[AsyncHTTPHandler, ClientSession]] = None, + ) -> Tuple[Any, httpx.Headers]: + if client is None or not isinstance(client, ClientSession): + async_client_session = 
self._get_async_client_session() + + stream = True + if fake_stream is True: + stream = False + data.pop("max_retries", None) + response = await self._make_common_async_call( + async_client_session=async_client_session, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + litellm_params=litellm_params, + stream=stream, + ) + + if fake_stream is True: + json_response = await response.json() + completion_stream = provider_config.get_model_response_iterator( + streaming_response=json_response, sync_stream=False + ) + else: + completion_stream = provider_config.get_model_response_iterator( + streaming_response=response.content, sync_stream=False + ) + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response="first stream response received:: ", + additional_args={"complete_input_dict": data}, + ) + + return completion_stream, response.headers + def make_sync_call( self, provider_config: BaseConfig, @@ -372,7 +484,7 @@ class BaseLLMAIOHTTPHandler: ) else: completion_stream = provider_config.get_model_response_iterator( - streaming_response=response.iter_lines(), sync_stream=True + streaming_response=response.content, sync_stream=True ) # LOGGING diff --git a/tests/llm_translation/test_aiohttp_openai.py b/tests/llm_translation/test_aiohttp_openai.py index 5b92c924ec..bd58f16d8c 100644 --- a/tests/llm_translation/test_aiohttp_openai.py +++ b/tests/llm_translation/test_aiohttp_openai.py @@ -10,6 +10,8 @@ sys.path.insert( import litellm +from local_testing.test_streaming import streaming_format_tests + @pytest.mark.asyncio() async def test_aiohttp_openai(): @@ -31,3 +33,39 @@ async def test_aiohttp_openai_gpt_4o(): messages=[{"role": "user", "content": "Hello, world!"}], ) print(response) + + +@pytest.mark.asyncio() +async def test_completion_model_stream(): + litellm.set_verbose = True + api_key = os.getenv("OPENAI_API_KEY") + assert api_key is not None, "API key is not set in environment variables" + + try: + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "how does a court case get to the Supreme Court?", + }, + ] + response = await litellm.acompletion( + api_key=api_key, model="aiohttp_openai/gpt-4o", messages=messages, stream=True, max_tokens=50 + ) + + complete_response = "" + idx = 0 # Initialize index manually + async for chunk in response: # Use async for to handle async iterator + chunk, finished = streaming_format_tests(idx, chunk) # Await if streaming_format_tests is async + print(f"outside chunk: {chunk}") + if finished: + break + complete_response += chunk + idx += 1 # Increment index manually + + if complete_response.strip() == "": + raise Exception("Empty response received") + print(f"complete response: {complete_response}") + + except Exception as e: + pytest.fail(f"Error occurred: {e}") \ No newline at end of file From 8c3338f36806aa3db23c5c8ff639b66440c8b7f6 Mon Sep 17 00:00:00 2001 From: Jaswanth Karani Date: Fri, 28 Feb 2025 20:31:06 +0530 Subject: [PATCH 2/2] made minor optimisations --- litellm/llms/custom_httpx/aiohttp_handler.py | 117 +++++++------------ 1 file changed, 41 insertions(+), 76 deletions(-) diff --git a/litellm/llms/custom_httpx/aiohttp_handler.py b/litellm/llms/custom_httpx/aiohttp_handler.py index e57ab0e737..13e47784a5 100644 --- a/litellm/llms/custom_httpx/aiohttp_handler.py +++ b/litellm/llms/custom_httpx/aiohttp_handler.py @@ -41,12 +41,10 @@ class BaseLLMAIOHTTPHandler: ) -> ClientSession: if 
dynamic_client_session: return dynamic_client_session - elif self.client_session: - return self.client_session - else: - # init client session, and then return new session - self.client_session = aiohttp.ClientSession() + if self.client_session: return self.client_session + self.client_session = aiohttp.ClientSession() + return self.client_session async def _make_common_async_call( self, @@ -70,7 +68,7 @@ class BaseLLMAIOHTTPHandler: dynamic_client_session=async_client_session ) - for i in range(max(max_retry_on_unprocessable_entity_error, 1)): + for _ in range(max(max_retry_on_unprocessable_entity_error, 1)): try: response = await async_client_session.post( url=api_base, @@ -141,8 +139,7 @@ class BaseLLMAIOHTTPHandler: ) ) continue - else: - raise self._handle_error(e=e, provider_config=provider_config) + raise self._handle_error(e=e, provider_config=provider_config) except Exception as e: raise self._handle_error(e=e, provider_config=provider_config) break @@ -257,9 +254,9 @@ class BaseLLMAIOHTTPHandler: }, ) - if acompletion is True: - if stream is True: - if fake_stream is not True: + if acompletion: + if stream: + if not fake_stream: data["stream"] = stream return self.acompletion_stream_function( model=model, @@ -272,39 +269,29 @@ class BaseLLMAIOHTTPHandler: logging_obj=logging_obj, data=data, fake_stream=fake_stream, - client=( - client - if client is not None and isinstance(client, ClientSession) - else None - ), + client=client if isinstance(client, ClientSession) else None, litellm_params=litellm_params, ) - - else: - return self.async_completion( - custom_llm_provider=custom_llm_provider, - provider_config=provider_config, - api_base=api_base, - headers=headers, - data=data, - timeout=timeout, - model=model, - model_response=model_response, - logging_obj=logging_obj, - api_key=api_key, - messages=messages, - optional_params=optional_params, - litellm_params=litellm_params, - encoding=encoding, - client=( - client - if client is not None and isinstance(client, ClientSession) - else None - ), - ) + return self.async_completion( + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + model=model, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + client=client if isinstance(client, ClientSession) else None, + ) - if stream is True: - if fake_stream is not True: + if stream: + if not fake_stream: data["stream"] = stream completion_stream, headers = self.make_sync_call( provider_config=provider_config, @@ -316,11 +303,7 @@ class BaseLLMAIOHTTPHandler: logging_obj=logging_obj, timeout=timeout, fake_stream=fake_stream, - client=( - client - if client is not None and isinstance(client, HTTPHandler) - else None - ), + client=client if isinstance(client, HTTPHandler) else None, litellm_params=litellm_params, ) return CustomStreamWrapper( @@ -330,11 +313,7 @@ class BaseLLMAIOHTTPHandler: logging_obj=logging_obj, ) - if client is None or not isinstance(client, HTTPHandler): - sync_httpx_client = _get_httpx_client() - else: - sync_httpx_client = client - + sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client() response = self._make_common_sync_call( sync_httpx_client=sync_httpx_client, provider_config=provider_config, @@ -356,7 +335,7 @@ class BaseLLMAIOHTTPHandler: litellm_params=litellm_params, encoding=encoding, ) - + async 
def acompletion_stream_function( self, model: str, @@ -385,14 +364,13 @@ class BaseLLMAIOHTTPHandler: client=client, litellm_params=litellm_params, ) - - streamwrapper = CustomStreamWrapper( + + return CustomStreamWrapper( completion_stream=completion_stream, model=model, custom_llm_provider=custom_llm_provider, logging_obj=logging_obj, ) - return streamwrapper async def make_async_call( self, @@ -408,12 +386,8 @@ class BaseLLMAIOHTTPHandler: fake_stream: bool = False, client: Optional[Union[AsyncHTTPHandler, ClientSession]] = None, ) -> Tuple[Any, httpx.Headers]: - if client is None or not isinstance(client, ClientSession): - async_client_session = self._get_async_client_session() - - stream = True - if fake_stream is True: - stream = False + async_client_session = self._get_async_client_session() if client is None or not isinstance(client, ClientSession) else client + stream = not fake_stream data.pop("max_retries", None) response = await self._make_common_async_call( async_client_session=async_client_session, @@ -426,7 +400,7 @@ class BaseLLMAIOHTTPHandler: stream=stream, ) - if fake_stream is True: + if fake_stream: json_response = await response.json() completion_stream = provider_config.get_model_response_iterator( streaming_response=json_response, sync_stream=False @@ -459,13 +433,8 @@ class BaseLLMAIOHTTPHandler: fake_stream: bool = False, client: Optional[HTTPHandler] = None, ) -> Tuple[Any, dict]: - if client is None or not isinstance(client, HTTPHandler): - sync_httpx_client = _get_httpx_client() - else: - sync_httpx_client = client - stream = True - if fake_stream is True: - stream = False + sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client() + stream = not fake_stream response = self._make_common_sync_call( sync_httpx_client=sync_httpx_client, @@ -478,7 +447,7 @@ class BaseLLMAIOHTTPHandler: stream=stream, ) - if fake_stream is True: + if fake_stream: completion_stream = provider_config.get_model_response_iterator( streaming_response=response.json(), sync_stream=True ) @@ -640,13 +609,9 @@ class BaseLLMAIOHTTPHandler: litellm_params=litellm_params, image=image, provider_config=provider_config, - ) # type: ignore - - if client is None or not isinstance(client, HTTPHandler): - sync_httpx_client = _get_httpx_client() - else: - sync_httpx_client = client + ) + sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client() response = self._make_common_sync_call( sync_httpx_client=sync_httpx_client, provider_config=provider_config,
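
Usage sketch: the test in patch 1 already exercises the new streaming path end to end; the snippet below is only a minimal illustration of how a caller would consume it once these patches are applied. The model name, prompt, and reliance on the OPENAI_API_KEY environment variable are illustrative assumptions, not requirements introduced by the patches.

import asyncio
import os

import litellm


async def main() -> None:
    # stream=True routes the request through acompletion_stream_function ->
    # make_async_call -> ModelResponseIterator (added in these patches) and
    # returns a CustomStreamWrapper, which is an async iterator of chunks.
    response = await litellm.acompletion(
        model="aiohttp_openai/gpt-4o",        # illustrative model name
        api_key=os.getenv("OPENAI_API_KEY"),  # assumes the key is exported
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        stream=True,
        max_tokens=50,
    )

    # Each chunk is OpenAI-shaped; the delta content may be None on the
    # final chunk, so fall back to an empty string.
    async for chunk in response:
        delta = chunk.choices[0].delta.content or ""
        print(delta, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())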