From 8c3338f36806aa3db23c5c8ff639b66440c8b7f6 Mon Sep 17 00:00:00 2001 From: Jaswanth Karani Date: Fri, 28 Feb 2025 20:31:06 +0530 Subject: [PATCH] made minor optimisations --- litellm/llms/custom_httpx/aiohttp_handler.py | 117 +++++++------------ 1 file changed, 41 insertions(+), 76 deletions(-) diff --git a/litellm/llms/custom_httpx/aiohttp_handler.py b/litellm/llms/custom_httpx/aiohttp_handler.py index e57ab0e737..13e47784a5 100644 --- a/litellm/llms/custom_httpx/aiohttp_handler.py +++ b/litellm/llms/custom_httpx/aiohttp_handler.py @@ -41,12 +41,10 @@ class BaseLLMAIOHTTPHandler: ) -> ClientSession: if dynamic_client_session: return dynamic_client_session - elif self.client_session: - return self.client_session - else: - # init client session, and then return new session - self.client_session = aiohttp.ClientSession() + if self.client_session: return self.client_session + self.client_session = aiohttp.ClientSession() + return self.client_session async def _make_common_async_call( self, @@ -70,7 +68,7 @@ class BaseLLMAIOHTTPHandler: dynamic_client_session=async_client_session ) - for i in range(max(max_retry_on_unprocessable_entity_error, 1)): + for _ in range(max(max_retry_on_unprocessable_entity_error, 1)): try: response = await async_client_session.post( url=api_base, @@ -141,8 +139,7 @@ class BaseLLMAIOHTTPHandler: ) ) continue - else: - raise self._handle_error(e=e, provider_config=provider_config) + raise self._handle_error(e=e, provider_config=provider_config) except Exception as e: raise self._handle_error(e=e, provider_config=provider_config) break @@ -257,9 +254,9 @@ class BaseLLMAIOHTTPHandler: }, ) - if acompletion is True: - if stream is True: - if fake_stream is not True: + if acompletion: + if stream: + if not fake_stream: data["stream"] = stream return self.acompletion_stream_function( model=model, @@ -272,39 +269,29 @@ class BaseLLMAIOHTTPHandler: logging_obj=logging_obj, data=data, fake_stream=fake_stream, - client=( - client - if client is not None and isinstance(client, ClientSession) - else None - ), + client=client if isinstance(client, ClientSession) else None, litellm_params=litellm_params, ) - - else: - return self.async_completion( - custom_llm_provider=custom_llm_provider, - provider_config=provider_config, - api_base=api_base, - headers=headers, - data=data, - timeout=timeout, - model=model, - model_response=model_response, - logging_obj=logging_obj, - api_key=api_key, - messages=messages, - optional_params=optional_params, - litellm_params=litellm_params, - encoding=encoding, - client=( - client - if client is not None and isinstance(client, ClientSession) - else None - ), - ) + return self.async_completion( + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + model=model, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + client=client if isinstance(client, ClientSession) else None, + ) - if stream is True: - if fake_stream is not True: + if stream: + if not fake_stream: data["stream"] = stream completion_stream, headers = self.make_sync_call( provider_config=provider_config, @@ -316,11 +303,7 @@ class BaseLLMAIOHTTPHandler: logging_obj=logging_obj, timeout=timeout, fake_stream=fake_stream, - client=( - client - if client is not None and isinstance(client, HTTPHandler) - else None - ), + client=client if isinstance(client, HTTPHandler) else None, litellm_params=litellm_params, ) return CustomStreamWrapper( @@ -330,11 +313,7 @@ class BaseLLMAIOHTTPHandler: logging_obj=logging_obj, ) - if client is None or not isinstance(client, HTTPHandler): - sync_httpx_client = _get_httpx_client() - else: - sync_httpx_client = client - + sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client() response = self._make_common_sync_call( sync_httpx_client=sync_httpx_client, provider_config=provider_config, @@ -356,7 +335,7 @@ class BaseLLMAIOHTTPHandler: litellm_params=litellm_params, encoding=encoding, ) - + async def acompletion_stream_function( self, model: str, @@ -385,14 +364,13 @@ class BaseLLMAIOHTTPHandler: client=client, litellm_params=litellm_params, ) - - streamwrapper = CustomStreamWrapper( + + return CustomStreamWrapper( completion_stream=completion_stream, model=model, custom_llm_provider=custom_llm_provider, logging_obj=logging_obj, ) - return streamwrapper async def make_async_call( self, @@ -408,12 +386,8 @@ class BaseLLMAIOHTTPHandler: fake_stream: bool = False, client: Optional[Union[AsyncHTTPHandler, ClientSession]] = None, ) -> Tuple[Any, httpx.Headers]: - if client is None or not isinstance(client, ClientSession): - async_client_session = self._get_async_client_session() - - stream = True - if fake_stream is True: - stream = False + async_client_session = self._get_async_client_session() if client is None or not isinstance(client, ClientSession) else client + stream = not fake_stream data.pop("max_retries", None) response = await self._make_common_async_call( async_client_session=async_client_session, @@ -426,7 +400,7 @@ class BaseLLMAIOHTTPHandler: stream=stream, ) - if fake_stream is True: + if fake_stream: json_response = await response.json() completion_stream = provider_config.get_model_response_iterator( streaming_response=json_response, sync_stream=False @@ -459,13 +433,8 @@ class BaseLLMAIOHTTPHandler: fake_stream: bool = False, client: Optional[HTTPHandler] = None, ) -> Tuple[Any, dict]: - if client is None or not isinstance(client, HTTPHandler): - sync_httpx_client = _get_httpx_client() - else: - sync_httpx_client = client - stream = True - if fake_stream is True: - stream = False + sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client() + stream = not fake_stream response = self._make_common_sync_call( sync_httpx_client=sync_httpx_client, @@ -478,7 +447,7 @@ class BaseLLMAIOHTTPHandler: stream=stream, ) - if fake_stream is True: + if fake_stream: completion_stream = provider_config.get_model_response_iterator( streaming_response=response.json(), sync_stream=True ) @@ -640,13 +609,9 @@ class BaseLLMAIOHTTPHandler: litellm_params=litellm_params, image=image, provider_config=provider_config, - ) # type: ignore - - if client is None or not isinstance(client, HTTPHandler): - sync_httpx_client = _get_httpx_client() - else: - sync_httpx_client = client + ) + sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client() response = self._make_common_sync_call( sync_httpx_client=sync_httpx_client, provider_config=provider_config,