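Made minor optimisations: simplified boolean checks, collapsed if/else assignments into conditional expressions, and removed redundant else branches in BaseLLMAIOHTTPHandler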

This commit is contained in:
Jaswanth Karani 2025-02-28 20:31:06 +05:30 committed by GitHub
parent 7ac3a9cb83
commit 8c3338f368
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -41,12 +41,10 @@ class BaseLLMAIOHTTPHandler:
) -> ClientSession:
if dynamic_client_session:
return dynamic_client_session
elif self.client_session:
return self.client_session
else:
# init client session, and then return new session
self.client_session = aiohttp.ClientSession()
if self.client_session:
return self.client_session
self.client_session = aiohttp.ClientSession()
return self.client_session
async def _make_common_async_call(
self,
@ -70,7 +68,7 @@ class BaseLLMAIOHTTPHandler:
dynamic_client_session=async_client_session
)
for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
for _ in range(max(max_retry_on_unprocessable_entity_error, 1)):
try:
response = await async_client_session.post(
url=api_base,
@ -141,8 +139,7 @@ class BaseLLMAIOHTTPHandler:
)
)
continue
else:
raise self._handle_error(e=e, provider_config=provider_config)
raise self._handle_error(e=e, provider_config=provider_config)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
break
@ -257,9 +254,9 @@ class BaseLLMAIOHTTPHandler:
},
)
if acompletion is True:
if stream is True:
if fake_stream is not True:
if acompletion:
if stream:
if not fake_stream:
data["stream"] = stream
return self.acompletion_stream_function(
model=model,
@ -272,39 +269,29 @@ class BaseLLMAIOHTTPHandler:
logging_obj=logging_obj,
data=data,
fake_stream=fake_stream,
client=(
client
if client is not None and isinstance(client, ClientSession)
else None
),
client=client if isinstance(client, ClientSession) else None,
litellm_params=litellm_params,
)
else:
return self.async_completion(
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
model=model,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
client=(
client
if client is not None and isinstance(client, ClientSession)
else None
),
)
return self.async_completion(
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
model=model,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
client=client if isinstance(client, ClientSession) else None,
)
if stream is True:
if fake_stream is not True:
if stream:
if not fake_stream:
data["stream"] = stream
completion_stream, headers = self.make_sync_call(
provider_config=provider_config,
@ -316,11 +303,7 @@ class BaseLLMAIOHTTPHandler:
logging_obj=logging_obj,
timeout=timeout,
fake_stream=fake_stream,
client=(
client
if client is not None and isinstance(client, HTTPHandler)
else None
),
client=client if isinstance(client, HTTPHandler) else None,
litellm_params=litellm_params,
)
return CustomStreamWrapper(
@ -330,11 +313,7 @@ class BaseLLMAIOHTTPHandler:
logging_obj=logging_obj,
)
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
provider_config=provider_config,
@ -356,7 +335,7 @@ class BaseLLMAIOHTTPHandler:
litellm_params=litellm_params,
encoding=encoding,
)
async def acompletion_stream_function(
self,
model: str,
@ -385,14 +364,13 @@ class BaseLLMAIOHTTPHandler:
client=client,
litellm_params=litellm_params,
)
streamwrapper = CustomStreamWrapper(
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
return streamwrapper
async def make_async_call(
self,
@ -408,12 +386,8 @@ class BaseLLMAIOHTTPHandler:
fake_stream: bool = False,
client: Optional[Union[AsyncHTTPHandler, ClientSession]] = None,
) -> Tuple[Any, httpx.Headers]:
if client is None or not isinstance(client, ClientSession):
async_client_session = self._get_async_client_session()
stream = True
if fake_stream is True:
stream = False
async_client_session = self._get_async_client_session() if client is None or not isinstance(client, ClientSession) else client
stream = not fake_stream
data.pop("max_retries", None)
response = await self._make_common_async_call(
async_client_session=async_client_session,
@ -426,7 +400,7 @@ class BaseLLMAIOHTTPHandler:
stream=stream,
)
if fake_stream is True:
if fake_stream:
json_response = await response.json()
completion_stream = provider_config.get_model_response_iterator(
streaming_response=json_response, sync_stream=False
@ -459,13 +433,8 @@ class BaseLLMAIOHTTPHandler:
fake_stream: bool = False,
client: Optional[HTTPHandler] = None,
) -> Tuple[Any, dict]:
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
stream = True
if fake_stream is True:
stream = False
sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
stream = not fake_stream
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
@ -478,7 +447,7 @@ class BaseLLMAIOHTTPHandler:
stream=stream,
)
if fake_stream is True:
if fake_stream:
completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.json(), sync_stream=True
)
@ -640,13 +609,9 @@ class BaseLLMAIOHTTPHandler:
litellm_params=litellm_params,
image=image,
provider_config=provider_config,
) # type: ignore
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
)
sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
provider_config=provider_config,