mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 10:14:26 +00:00)
made minor optimisations
commit 8c3338f368 (parent 7ac3a9cb83)
1 changed file with 41 additions and 76 deletions
@@ -41,12 +41,10 @@ class BaseLLMAIOHTTPHandler:
     ) -> ClientSession:
         if dynamic_client_session:
             return dynamic_client_session
-        elif self.client_session:
-            return self.client_session
-        else:
-            # init client session, and then return new session
-            self.client_session = aiohttp.ClientSession()
-            return self.client_session
+        if self.client_session:
+            return self.client_session
+        self.client_session = aiohttp.ClientSession()
+        return self.client_session
 
     async def _make_common_async_call(
         self,
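The rewritten lookup is a guard-clause version of lazy initialisation: prefer an explicitly supplied session, then the cached one, and only construct a new aiohttp.ClientSession as a last resort, caching it for later calls. A minimal self-contained sketch of the same pattern (the surrounding handler is stubbed out; only the attribute and parameter names are taken from the diff):

import aiohttp
from typing import Optional

class SessionCache:
    """Sketch of the lazy ClientSession caching shown in the hunk above."""

    def __init__(self) -> None:
        self.client_session: Optional[aiohttp.ClientSession] = None

    def _get_async_client_session(
        self, dynamic_client_session: Optional[aiohttp.ClientSession] = None
    ) -> aiohttp.ClientSession:
        # A caller-supplied session always wins.
        if dynamic_client_session:
            return dynamic_client_session
        # Otherwise reuse the cached session when one exists.
        if self.client_session:
            return self.client_session
        # Last resort: create a session lazily and cache it for later calls.
        self.client_session = aiohttp.ClientSession()
        return self.client_session

aiohttp expects sessions to be created and closed from within a running event loop, so a cached session like this should only be touched from async code.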
@@ -70,7 +68,7 @@ class BaseLLMAIOHTTPHandler:
             dynamic_client_session=async_client_session
         )
 
-        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+        for _ in range(max(max_retry_on_unprocessable_entity_error, 1)):
             try:
                 response = await async_client_session.post(
                     url=api_base,
@@ -141,8 +139,7 @@ class BaseLLMAIOHTTPHandler:
                         )
                     )
                     continue
-                else:
-                    raise self._handle_error(e=e, provider_config=provider_config)
+                raise self._handle_error(e=e, provider_config=provider_config)
             except Exception as e:
                 raise self._handle_error(e=e, provider_config=provider_config)
             break
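Two separate cleanups meet in these retry hunks: max(n, 1) in the loop header guarantees at least one attempt even when the retry budget is zero, and the else: after continue is unnecessary because continue already ends that branch. A hedged sketch of the same control flow; do_request and RetriableError are placeholders for the handler's POST call and its 422 handling, not litellm APIs:

from typing import Optional

class RetriableError(Exception):
    """Stands in for an HTTP 422 (unprocessable entity) from a provider."""

_attempts = 0

def do_request() -> str:
    # Placeholder for the real aiohttp POST inside the handler:
    # fails once, then succeeds, so the retry path is exercised.
    global _attempts
    _attempts += 1
    if _attempts == 1:
        raise RetriableError("unprocessable entity")
    return "ok"

def call_with_retries(max_retry_on_unprocessable_entity_error: int = 0) -> str:
    response: Optional[str] = None
    # max(..., 1) guarantees at least one attempt even when retries are disabled.
    for _ in range(max(max_retry_on_unprocessable_entity_error, 1)):
        try:
            response = do_request()
        except RetriableError:
            continue  # retry-eligible: go around again
        except Exception as e:
            raise RuntimeError("request failed") from e  # non-retriable
        break  # success; no 'else:' needed after 'continue'/'raise'
    if response is None:
        raise RuntimeError("retries exhausted")
    return response

print(call_with_retries(3))  # -> "ok" on the second attempt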
@@ -257,9 +254,9 @@ class BaseLLMAIOHTTPHandler:
             },
         )
 
-        if acompletion is True:
-            if stream is True:
-                if fake_stream is not True:
+        if acompletion:
+            if stream:
+                if not fake_stream:
                     data["stream"] = stream
                 return self.acompletion_stream_function(
                     model=model,
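Swapping `is True` / `is not True` for plain truthiness is behaviour-preserving only when the flags are genuine booleans; if a truthy non-bool can ever reach these parameters, the two spellings diverge. A two-line illustration:

stream = "yes"  # a truthy value that is not the bool True

if stream is True:
    print("identity check fires")    # does NOT run: "yes" is not the object True
if stream:
    print("truthiness check fires")  # runs: non-empty strings are truthy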
@@ -272,39 +269,29 @@ class BaseLLMAIOHTTPHandler:
                     logging_obj=logging_obj,
                     data=data,
                     fake_stream=fake_stream,
-                    client=(
-                        client
-                        if client is not None and isinstance(client, ClientSession)
-                        else None
-                    ),
+                    client=client if isinstance(client, ClientSession) else None,
                     litellm_params=litellm_params,
                 )
 
-            else:
-                return self.async_completion(
-                    custom_llm_provider=custom_llm_provider,
-                    provider_config=provider_config,
-                    api_base=api_base,
-                    headers=headers,
-                    data=data,
-                    timeout=timeout,
-                    model=model,
-                    model_response=model_response,
-                    logging_obj=logging_obj,
-                    api_key=api_key,
-                    messages=messages,
-                    optional_params=optional_params,
-                    litellm_params=litellm_params,
-                    encoding=encoding,
-                    client=(
-                        client
-                        if client is not None and isinstance(client, ClientSession)
-                        else None
-                    ),
-                )
-
+            return self.async_completion(
+                custom_llm_provider=custom_llm_provider,
+                provider_config=provider_config,
+                api_base=api_base,
+                headers=headers,
+                data=data,
+                timeout=timeout,
+                model=model,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+                client=client if isinstance(client, ClientSession) else None,
+            )
-        if stream is True:
-            if fake_stream is not True:
+        if stream:
+            if not fake_stream:
                 data["stream"] = stream
             completion_stream, headers = self.make_sync_call(
                 provider_config=provider_config,
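The collapsed keyword arguments lean on the fact that isinstance already rejects None, so the extra `client is not None` test carried no information. A small self-contained demonstration (the ClientSession here is a stand-in class, not aiohttp's):

from typing import Optional

class ClientSession:
    """Stand-in for aiohttp.ClientSession; the pattern is type-agnostic."""

def narrow(client: object) -> Optional[ClientSession]:
    # isinstance(None, ClientSession) is already False, so the removed
    # 'client is not None and ...' guard was redundant.
    return client if isinstance(client, ClientSession) else None

session = ClientSession()
assert narrow(session) is session    # a real session passes through
assert narrow(None) is None          # None narrows to None
assert narrow("wrong type") is None  # any other type narrows to None too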
@@ -316,11 +303,7 @@ class BaseLLMAIOHTTPHandler:
                 logging_obj=logging_obj,
                 timeout=timeout,
                 fake_stream=fake_stream,
-                client=(
-                    client
-                    if client is not None and isinstance(client, HTTPHandler)
-                    else None
-                ),
+                client=client if isinstance(client, HTTPHandler) else None,
                 litellm_params=litellm_params,
             )
             return CustomStreamWrapper(
@@ -330,11 +313,7 @@ class BaseLLMAIOHTTPHandler:
                 logging_obj=logging_obj,
             )
 
-        if client is None or not isinstance(client, HTTPHandler):
-            sync_httpx_client = _get_httpx_client()
-        else:
-            sync_httpx_client = client
-
+        sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
         response = self._make_common_sync_call(
             sync_httpx_client=sync_httpx_client,
             provider_config=provider_config,
@@ -356,7 +335,7 @@ class BaseLLMAIOHTTPHandler:
             litellm_params=litellm_params,
             encoding=encoding,
         )
-
+
     async def acompletion_stream_function(
         self,
         model: str,
@@ -385,14 +364,13 @@ class BaseLLMAIOHTTPHandler:
             client=client,
             litellm_params=litellm_params,
         )
 
-        streamwrapper = CustomStreamWrapper(
+        return CustomStreamWrapper(
             completion_stream=completion_stream,
             model=model,
             custom_llm_provider=custom_llm_provider,
             logging_obj=logging_obj,
         )
-        return streamwrapper
 
     async def make_async_call(
         self,
@@ -408,12 +386,8 @@ class BaseLLMAIOHTTPHandler:
         fake_stream: bool = False,
         client: Optional[Union[AsyncHTTPHandler, ClientSession]] = None,
     ) -> Tuple[Any, httpx.Headers]:
-        if client is None or not isinstance(client, ClientSession):
-            async_client_session = self._get_async_client_session()
-
-        stream = True
-        if fake_stream is True:
-            stream = False
+        async_client_session = self._get_async_client_session() if client is None or not isinstance(client, ClientSession) else client
+        stream = not fake_stream
         data.pop("max_retries", None)
         response = await self._make_common_async_call(
             async_client_session=async_client_session,
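Besides replacing the three-line flag dance with `stream = not fake_stream`, the new conditional expression also covers a case the removed code missed: when a valid ClientSession was passed in, the old branch never assigned async_client_session at all. A compact sketch of the corrected fallback shape (all names here are stand-ins, not litellm APIs):

from typing import Optional, Tuple

class ClientSession:
    """Stand-in for aiohttp.ClientSession."""

_shared: Optional[ClientSession] = None

def get_default_session() -> ClientSession:
    # Placeholder for self._get_async_client_session(): lazily build and cache.
    global _shared
    if _shared is None:
        _shared = ClientSession()
    return _shared

def resolve(client: object, fake_stream: bool = False) -> Tuple[ClientSession, bool]:
    # One expression covers every case, including a caller-supplied session.
    session = get_default_session() if client is None or not isinstance(client, ClientSession) else client
    stream = not fake_stream  # real streaming is the negation of fake streaming
    return session, stream

own = ClientSession()
assert resolve(own)[0] is own                     # caller session wins
assert resolve(None)[0] is get_default_session()  # otherwise the shared one
assert resolve(None, fake_stream=True)[1] is False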
@@ -426,7 +400,7 @@ class BaseLLMAIOHTTPHandler:
             stream=stream,
         )
 
-        if fake_stream is True:
+        if fake_stream:
             json_response = await response.json()
             completion_stream = provider_config.get_model_response_iterator(
                 streaming_response=json_response, sync_stream=False
@@ -459,13 +433,8 @@ class BaseLLMAIOHTTPHandler:
         fake_stream: bool = False,
         client: Optional[HTTPHandler] = None,
     ) -> Tuple[Any, dict]:
-        if client is None or not isinstance(client, HTTPHandler):
-            sync_httpx_client = _get_httpx_client()
-        else:
-            sync_httpx_client = client
-        stream = True
-        if fake_stream is True:
-            stream = False
+        sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
+        stream = not fake_stream
 
         response = self._make_common_sync_call(
             sync_httpx_client=sync_httpx_client,
@@ -478,7 +447,7 @@ class BaseLLMAIOHTTPHandler:
             stream=stream,
         )
 
-        if fake_stream is True:
+        if fake_stream:
             completion_stream = provider_config.get_model_response_iterator(
                 streaming_response=response.json(), sync_stream=True
             )
@@ -640,13 +609,9 @@ class BaseLLMAIOHTTPHandler:
             litellm_params=litellm_params,
             image=image,
             provider_config=provider_config,
         )  # type: ignore
 
-        if client is None or not isinstance(client, HTTPHandler):
-            sync_httpx_client = _get_httpx_client()
-        else:
-            sync_httpx_client = client
-
+        sync_httpx_client = client if isinstance(client, HTTPHandler) else _get_httpx_client()
         response = self._make_common_sync_call(
             sync_httpx_client=sync_httpx_client,
             provider_config=provider_config,