diff --git a/litellm/llms/aiohttp_openai/chat/transformation.py b/litellm/llms/aiohttp_openai/chat/transformation.py
index c2d4e5adcd..439ee30a9b 100644
--- a/litellm/llms/aiohttp_openai/chat/transformation.py
+++ b/litellm/llms/aiohttp_openai/chat/transformation.py
@@ -7,13 +7,14 @@ https://github.com/BerriAI/litellm/issues/6592
 New config to ensure we introduce this without causing breaking changes for users
 """
 
-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
 
 from aiohttp import ClientResponse
 
 from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, ModelResponse
+from ..common_utils import ModelResponseIterator as AiohttpOpenAIResponseIterator
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -80,3 +81,15 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
         model_response.object = _json_response.get("object")
         model_response.system_fingerprint = _json_response.get("system_fingerprint")
         return model_response
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return AiohttpOpenAIResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
\ No newline at end of file
diff --git a/litellm/llms/aiohttp_openai/common_utils.py b/litellm/llms/aiohttp_openai/common_utils.py
new file mode 100644
index 0000000000..76d18c367f
--- /dev/null
+++ b/litellm/llms/aiohttp_openai/common_utils.py
@@ -0,0 +1,169 @@
+import json
+from typing import List, Optional, Union
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+    ModelResponseStream,
+)
+
+
+class AioHttpOpenAIError(BaseLLMException):
+    def __init__(self, status_code, message):
+        super().__init__(status_code=status_code, message=message)
+
+
+def validate_environment(
+    headers: dict,
+    model: str,
+    messages: List[AllMessageValues],
+    optional_params: dict,
+    api_key: Optional[str] = None,
+) -> dict:
+    """
+    Return headers to use for the aiohttp_openai chat completion request
+    """
+    headers.update(
+        {
+            "Request-Source": "unspecified:litellm",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+    )
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
+    return headers
+
+
+class ModelResponseIterator:
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
+        self.streaming_response = streaming_response
+        self.response_iterator = self.streaming_response
+        self.json_mode = json_mode
+
+    def chunk_parser(
+        self, chunk: dict
+    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
+        try:
+            # Initialize default values
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            # Extract the index from the chunk
+            index = int(chunk.get("choices", [{}])[0].get("index", 0))
+
+            # Extract the text or delta content from the first choice
+            delta = chunk.get("choices", [{}])[0].get("delta", {})
+            if "content" in delta:
+                text = delta["content"]
+
+            # Check for finish_reason
+            finish_reason = chunk.get("choices", [{}])[0].get("finish_reason", "")
+
+            # Determine if the stream has finished
+            is_finished = finish_reason in ("length", "stop")
+
+            # Create and return the parsed chunk
+            returned_chunk = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+
+            return returned_chunk
+
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def _handle_string_chunk(
+        self, str_line: str
+    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
+        # chunk is a str at this point
+        if "[DONE]" in str_line:
+            return GenericStreamingChunk(
+                text="",
+                is_finished=True,
+                finish_reason="stop",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
+        elif str_line.startswith("data:"):
+            data_json = json.loads(str_line[5:])
+            return self.chunk_parser(chunk=data_json)
+        else:
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
+
+    def __next__(self):
+        try:
+            chunk = self.response_iterator.__next__()
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+            index = str_line.find("data:")
+            if index != -1:
+                str_line = str_line[index:]
+            # chunk is a str at this point
+            return self._handle_string_chunk(str_line=str_line)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+            index = str_line.find("data:")
+            if index != -1:
+                str_line = str_line[index:]
+
+            # chunk is a str at this point
+            return self._handle_string_chunk(str_line=str_line)
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
\ No newline at end of file
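# --- Illustration (not part of the patch) ---
# A minimal sketch of how a single OpenAI-style SSE line is expected to flow
# through the new ModelResponseIterator above, assuming this patch is applied.
# The payload below is a made-up example chunk, not captured output.
from litellm.llms.aiohttp_openai.common_utils import ModelResponseIterator

sse_lines = [
    'data: {"choices": [{"index": 0, "delta": {"content": "Hello"}}]}',
    "data: [DONE]",
]
iterator = ModelResponseIterator(streaming_response=iter(sse_lines), sync_stream=True)

first = next(iterator)  # parsed via chunk_parser -> text "Hello", is_finished False
done = next(iterator)   # "[DONE]" sentinel -> empty text, finish_reason "stop"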
diff --git a/litellm/llms/custom_httpx/aiohttp_handler.py b/litellm/llms/custom_httpx/aiohttp_handler.py
index 13141fc19a..6ac9b5985c 100644
--- a/litellm/llms/custom_httpx/aiohttp_handler.py
+++ b/litellm/llms/custom_httpx/aiohttp_handler.py
@@ -40,12 +40,10 @@ class BaseLLMAIOHTTPHandler:
     ) -> ClientSession:
         if dynamic_client_session:
             return dynamic_client_session
-        elif self.client_session:
-            return self.client_session
-        else:
-            # init client session, and then return new session
-            self.client_session = aiohttp.ClientSession()
+        if self.client_session:
             return self.client_session
+        self.client_session = aiohttp.ClientSession()
+        return self.client_session
 
     async def _make_common_async_call(
         self,
@@ -69,7 +67,7 @@ class BaseLLMAIOHTTPHandler:
             dynamic_client_session=async_client_session
         )
 
-        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+        for _ in range(max(max_retry_on_unprocessable_entity_error, 1)):
             try:
                 response = await async_client_session.post(
                     url=api_base,
@@ -139,8 +137,7 @@ class BaseLLMAIOHTTPHandler:
                         )
                     )
                     continue
-                else:
-                    raise self._handle_error(e=e, provider_config=provider_config)
+                raise self._handle_error(e=e, provider_config=provider_config)
             except Exception as e:
                 raise self._handle_error(e=e, provider_config=provider_config)
             break
@@ -172,6 +169,7 @@ class BaseLLMAIOHTTPHandler:
         api_key: Optional[str] = None,
         client: Optional[ClientSession] = None,
     ):
+        data.pop("max_retries", None)  # litellm-internal param; not part of the OpenAI request body
         _response = await self._make_common_async_call(
             async_client_session=client,
             provider_config=provider_config,
@@ -261,7 +259,24 @@ class BaseLLMAIOHTTPHandler:
             },
         )
 
-        if acompletion is True:
+        if acompletion:
+            if stream:
+                if not fake_stream:
+                    data["stream"] = stream
+                return self.acompletion_stream_function(
+                    model=model,
+                    messages=messages,
+                    api_base=api_base,
+                    headers=headers,
+                    custom_llm_provider=custom_llm_provider,
+                    provider_config=provider_config,
+                    timeout=timeout,
+                    logging_obj=logging_obj,
+                    data=data,
+                    fake_stream=fake_stream,
+                    client=client if isinstance(client, ClientSession) else None,
+                    litellm_params=litellm_params,
+                )
             return self.async_completion(
                 custom_llm_provider=custom_llm_provider,
                 provider_config=provider_config,
@@ -277,15 +292,11 @@ class BaseLLMAIOHTTPHandler:
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 encoding=encoding,
-                client=(
-                    client
-                    if client is not None and isinstance(client, ClientSession)
-                    else None
-                ),
+                client=client if isinstance(client, ClientSession) else None,
             )
 
-        if stream is True:
-            if fake_stream is not True:
+        if stream:
+            if not fake_stream:
                 data["stream"] = stream
             completion_stream, headers = self.make_sync_call(
                 provider_config=provider_config,
@@ -297,11 +308,7 @@ class BaseLLMAIOHTTPHandler:
                 logging_obj=logging_obj,
                 timeout=timeout,
                 fake_stream=fake_stream,
-                client=(
-                    client
-                    if client is not None and isinstance(client, HTTPHandler)
-                    else None
-                ),
+                client=client if isinstance(client, HTTPHandler) else None,
                 litellm_params=litellm_params,
             )
             return CustomStreamWrapper(
@@ -311,11 +318,7 @@ class BaseLLMAIOHTTPHandler:
                 logging_obj=logging_obj,
             )
 
-        if client is None or not isinstance(client, HTTPHandler):
-            sync_httpx_client = _get_httpx_client()
-        else:
-            sync_httpx_client = client
-
+        sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
         response = self._make_common_sync_call(
             sync_httpx_client=sync_httpx_client,
             provider_config=provider_config,
@@ -338,6 +341,89 @@ class BaseLLMAIOHTTPHandler:
             encoding=encoding,
         )
 
+    async def acompletion_stream_function(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_llm_provider: str,
+        headers: dict,
+        provider_config: BaseConfig,
+        timeout: Union[float, httpx.Timeout],
+        logging_obj: LiteLLMLoggingObj,
+        data: dict,
+        litellm_params: dict,
+        fake_stream: bool = False,
+        client: Optional[ClientSession] = None,
+    ):
+        completion_stream, _response_headers = await self.make_async_call(
+            custom_llm_provider=custom_llm_provider,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            messages=messages,
+            logging_obj=logging_obj,
+            timeout=timeout,
+            fake_stream=fake_stream,
+            client=client,
+            litellm_params=litellm_params,
+        )
+
+        return CustomStreamWrapper(
+            completion_stream=completion_stream,
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            logging_obj=logging_obj,
+        )
+
+    async def make_async_call(
+        self,
+        custom_llm_provider: str,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        messages: list,
+        logging_obj: LiteLLMLoggingObj,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        fake_stream: bool = False,
+        client: Optional[Union[AsyncHTTPHandler, ClientSession]] = None,
+    ) -> Tuple[Any, httpx.Headers]:
+        async_client_session = (
+            client
+            if isinstance(client, ClientSession)
+            else self._get_async_client_session()
+        )
+        stream = not fake_stream
+        data.pop("max_retries", None)
+        response = await self._make_common_async_call(
+            async_client_session=async_client_session,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=stream,
+        )
+
+        if fake_stream:
+            json_response = await response.json()
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=json_response, sync_stream=False
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.content, sync_stream=False
+            )
+
+        # LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response="first stream response received",
+            additional_args={"complete_input_dict": data},
+        )
+
+        return completion_stream, response.headers
+
     def make_sync_call(
         self,
         provider_config: BaseConfig,
@@ -352,13 +438,8 @@ class BaseLLMAIOHTTPHandler:
         fake_stream: bool = False,
         client: Optional[HTTPHandler] = None,
     ) -> Tuple[Any, dict]:
-        if client is None or not isinstance(client, HTTPHandler):
-            sync_httpx_client = _get_httpx_client()
-        else:
-            sync_httpx_client = client
-        stream = True
-        if fake_stream is True:
-            stream = False
+        sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
+        stream = not fake_stream
 
         response = self._make_common_sync_call(
             sync_httpx_client=sync_httpx_client,
@@ -371,13 +452,13 @@ class BaseLLMAIOHTTPHandler:
             stream=stream,
         )
 
-        if fake_stream is True:
+        if fake_stream:
             completion_stream = provider_config.get_model_response_iterator(
                 streaming_response=response.json(), sync_stream=True
             )
         else:
             completion_stream = provider_config.get_model_response_iterator(
                 streaming_response=response.iter_lines(), sync_stream=True
             )
 
         # LOGGING
@@ -535,13 +616,9 @@ class BaseLLMAIOHTTPHandler:
             litellm_params=litellm_params,
             image=image,
             provider_config=provider_config,
-        )  # type: ignore
-
-        if client is None or not isinstance(client, HTTPHandler):
-            sync_httpx_client = _get_httpx_client()
-        else:
-            sync_httpx_client = client
+        )
 
+        sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
         response = self._make_common_sync_call(
             sync_httpx_client=sync_httpx_client,
             provider_config=provider_config,
diff --git a/tests/llm_translation/test_aiohttp_openai.py b/tests/llm_translation/test_aiohttp_openai.py
index 5b92c924ec..bd58f16d8c 100644
--- a/tests/llm_translation/test_aiohttp_openai.py
+++ b/tests/llm_translation/test_aiohttp_openai.py
@@ -10,6 +10,8 @@ sys.path.insert(
 
 import litellm
 
+from local_testing.test_streaming import streaming_format_tests
+
 
 @pytest.mark.asyncio()
 async def test_aiohttp_openai():
@@ -31,3 +33,39 @@ async def test_aiohttp_openai_gpt_4o():
         messages=[{"role": "user", "content": "Hello, world!"}],
     )
     print(response)
+
+
+@pytest.mark.asyncio()
+async def test_completion_model_stream():
+    litellm.set_verbose = True
+    api_key = os.getenv("OPENAI_API_KEY")
+    assert api_key is not None, "API key is not set in environment variables"
+
+    try:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        response = await litellm.acompletion(
+            api_key=api_key,
+            model="aiohttp_openai/gpt-4o",
+            messages=messages,
+            stream=True,
+            max_tokens=50,
+        )
+
+        complete_response = ""
+        idx = 0
+        async for chunk in response:
+            # validate the OpenAI-style format of each streamed chunk
+            chunk, finished = streaming_format_tests(idx, chunk)
+            print(f"outside chunk: {chunk}")
+            if finished:
+                break
+            complete_response += chunk
+            idx += 1
+
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"complete response: {complete_response}")
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
\ No newline at end of file
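For reference, a minimal usage sketch (assuming this patch is applied and OPENAI_API_KEY is set in the environment) of the new async streaming path that the test above exercises:

import asyncio

import litellm


async def main() -> None:
    response = await litellm.acompletion(
        model="aiohttp_openai/gpt-4o",
        messages=[{"role": "user", "content": "Hello, world!"}],
        stream=True,
        max_tokens=50,
    )
    # CustomStreamWrapper yields OpenAI-style chunks; delta content may be None.
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())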