mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
fix for #7605, lazy load httpx client
There are a few of these clients instantiated at startup, and they start adding up. Each one took around 10ms to 30ms to create, so this shaves off ~50ms of startup time.
This commit is contained in:
parent
ac4f32fb1e
commit
be3ddac8b8
2 changed files with 61 additions and 16 deletions
|
@ -99,13 +99,22 @@ class AsyncHTTPHandler:
|
|||
):
|
||||
self.timeout = timeout
|
||||
self.event_hooks = event_hooks
|
||||
self.client = self.create_client(
|
||||
timeout=timeout,
|
||||
concurrent_limit=concurrent_limit,
|
||||
event_hooks=event_hooks,
|
||||
ssl_verify=ssl_verify,
|
||||
)
|
||||
self.concurrent_limit = concurrent_limit
|
||||
self.client_alias = client_alias
|
||||
self.ssl_verify = ssl_verify
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
@property
|
||||
def client(self) -> httpx.AsyncClient:
|
||||
# Optimization: lazy-load the client. Creating it is not cheap (tens of ms)
|
||||
if self._client is None:
|
||||
self._client = self.create_client(
|
||||
timeout=self.timeout,
|
||||
concurrent_limit=self.concurrent_limit,
|
||||
event_hooks=self.event_hooks,
|
||||
ssl_verify=self.ssl_verify,
|
||||
)
|
||||
return self._client
|
||||
|
||||
def create_client(
|
||||
self,
|
||||
|
@ -485,27 +494,36 @@ class HTTPHandler:
|
|||
# /path/to/client.pem
|
||||
cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
|
||||
|
||||
if client is None:
|
||||
self.timeout = timeout
|
||||
self.concurrent_limit = concurrent_limit
|
||||
self.ssl_verify = ssl_verify
|
||||
self.cert = cert
|
||||
self._client = client
|
||||
|
||||
@property
|
||||
def client(self) -> httpx.Client:
|
||||
# Optimization--lazy load the client
|
||||
if self._client is None:
|
||||
transport = self._create_sync_transport()
|
||||
|
||||
# Create a client with a connection pool
|
||||
self.client = httpx.Client(
|
||||
self._client = httpx.Client(
|
||||
transport=transport,
|
||||
timeout=timeout,
|
||||
timeout=self.timeout,
|
||||
limits=httpx.Limits(
|
||||
max_connections=concurrent_limit,
|
||||
max_keepalive_connections=concurrent_limit,
|
||||
max_connections=self.concurrent_limit,
|
||||
max_keepalive_connections=self.concurrent_limit,
|
||||
),
|
||||
verify=ssl_verify,
|
||||
cert=cert,
|
||||
verify=self.ssl_verify,
|
||||
cert=self.cert,
|
||||
headers=headers,
|
||||
)
|
||||
else:
|
||||
self.client = client
|
||||
return self._client
|
||||
|
||||
def close(self):
|
||||
# Close the client when you're done with it
|
||||
self.client.close()
|
||||
if self._client is not None:
|
||||
self._client.close()
|
||||
|
||||
def get(
|
||||
self,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import io
|
||||
import os
|
||||
import pathlib
|
||||
import respx
|
||||
import ssl
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
@ -27,3 +28,29 @@ async def test_ssl_security_level(monkeypatch):
|
|||
|
||||
# Verify that the SSL context exists and has the correct cipher string
|
||||
assert isinstance(ssl_context, ssl.SSLContext)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_async_request(respx_mock: respx.MockRouter):
|
||||
respx_mock.get("https://api.example.com").respond(
|
||||
json={"message": "Hello, world!"}
|
||||
)
|
||||
|
||||
# Create async client with SSL verification disabled to isolate SSL context testing
|
||||
client = AsyncHTTPHandler(ssl_verify=False)
|
||||
|
||||
response = await client.get("https://api.example.com")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Hello, world!"}
|
||||
|
||||
|
||||
def test_basic_sync_request(respx_mock: respx.MockRouter):
|
||||
respx_mock.get("https://api.example.com").respond(
|
||||
json={"message": "Hello, world!"}
|
||||
)
|
||||
|
||||
# Create sync client with SSL verification disabled to isolate SSL context testing
|
||||
client = HTTPHandler(ssl_verify=False)
|
||||
|
||||
response = client.get("https://api.example.com")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Hello, world!"}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue