This commit is contained in:
Adrian Lyjak 2025-04-24 00:55:44 -07:00 committed by GitHub
commit 20aba45757
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 61 additions and 16 deletions

View file

@ -99,13 +99,22 @@ class AsyncHTTPHandler:
): ):
self.timeout = timeout self.timeout = timeout
self.event_hooks = event_hooks self.event_hooks = event_hooks
self.client = self.create_client( self.concurrent_limit = concurrent_limit
timeout=timeout,
concurrent_limit=concurrent_limit,
event_hooks=event_hooks,
ssl_verify=ssl_verify,
)
self.client_alias = client_alias self.client_alias = client_alias
self.ssl_verify = ssl_verify
self._client: Optional[httpx.AsyncClient] = None
@property
def client(self) -> httpx.AsyncClient:
# Optimization--lazy load the client. This is actually not super fast to create (10s of ms)
if self._client is None:
self._client = self.create_client(
timeout=self.timeout,
concurrent_limit=self.concurrent_limit,
event_hooks=self.event_hooks,
ssl_verify=self.ssl_verify,
)
return self._client
def create_client( def create_client(
self, self,
@ -485,27 +494,36 @@ class HTTPHandler:
# /path/to/client.pem # /path/to/client.pem
cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate) cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
if client is None: self.timeout = timeout
self.concurrent_limit = concurrent_limit
self.ssl_verify = ssl_verify
self.cert = cert
self._client = client
@property
def client(self) -> httpx.Client:
# Optimization--lazy load the client
if self._client is None:
transport = self._create_sync_transport() transport = self._create_sync_transport()
# Create a client with a connection pool # Create a client with a connection pool
self.client = httpx.Client( self._client = httpx.Client(
transport=transport, transport=transport,
timeout=timeout, timeout=self.timeout,
limits=httpx.Limits( limits=httpx.Limits(
max_connections=concurrent_limit, max_connections=self.concurrent_limit,
max_keepalive_connections=concurrent_limit, max_keepalive_connections=self.concurrent_limit,
), ),
verify=ssl_verify, verify=self.ssl_verify,
cert=cert, cert=self.cert,
headers=headers, headers=headers,
) )
else: return self._client
self.client = client
def close(self): def close(self):
# Close the client when you're done with it # Close the client when you're done with it
self.client.close() if self._client is not None:
self._client.close()
def get( def get(
self, self,

View file

@ -1,6 +1,7 @@
import io import io
import os import os
import pathlib import pathlib
import respx
import ssl import ssl
import sys import sys
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -27,3 +28,29 @@ async def test_ssl_security_level(monkeypatch):
# Verify that the SSL context exists and has the correct cipher string # Verify that the SSL context exists and has the correct cipher string
assert isinstance(ssl_context, ssl.SSLContext) assert isinstance(ssl_context, ssl.SSLContext)
@pytest.mark.asyncio
async def test_basic_async_request(respx_mock: respx.MockRouter):
respx_mock.get("https://api.example.com").respond(
json={"message": "Hello, world!"}
)
# Create async client with SSL verification disabled to isolate SSL context testing
client = AsyncHTTPHandler(ssl_verify=False)
response = await client.get("https://api.example.com")
assert response.status_code == 200
assert response.json() == {"message": "Hello, world!"}
def test_basic_sync_request(respx_mock: respx.MockRouter):
respx_mock.get("https://api.example.com").respond(
json={"message": "Hello, world!"}
)
# Create async client with SSL verification disabled to isolate SSL context testing
client = HTTPHandler(ssl_verify=False)
response = client.get("https://api.example.com")
assert response.status_code == 200
assert response.json() == {"message": "Hello, world!"}