mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge be3ddac8b8
into b82af5b826
This commit is contained in:
commit
20aba45757
2 changed files with 61 additions and 16 deletions
|
@ -99,13 +99,22 @@ class AsyncHTTPHandler:
|
||||||
):
|
):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.event_hooks = event_hooks
|
self.event_hooks = event_hooks
|
||||||
self.client = self.create_client(
|
self.concurrent_limit = concurrent_limit
|
||||||
timeout=timeout,
|
|
||||||
concurrent_limit=concurrent_limit,
|
|
||||||
event_hooks=event_hooks,
|
|
||||||
ssl_verify=ssl_verify,
|
|
||||||
)
|
|
||||||
self.client_alias = client_alias
|
self.client_alias = client_alias
|
||||||
|
self.ssl_verify = ssl_verify
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> httpx.AsyncClient:
|
||||||
|
# Optimization--lazy load the client. This is actually not super fast to create (10s of ms)
|
||||||
|
if self._client is None:
|
||||||
|
self._client = self.create_client(
|
||||||
|
timeout=self.timeout,
|
||||||
|
concurrent_limit=self.concurrent_limit,
|
||||||
|
event_hooks=self.event_hooks,
|
||||||
|
ssl_verify=self.ssl_verify,
|
||||||
|
)
|
||||||
|
return self._client
|
||||||
|
|
||||||
def create_client(
|
def create_client(
|
||||||
self,
|
self,
|
||||||
|
@ -485,27 +494,36 @@ class HTTPHandler:
|
||||||
# /path/to/client.pem
|
# /path/to/client.pem
|
||||||
cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
|
cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
|
||||||
|
|
||||||
if client is None:
|
self.timeout = timeout
|
||||||
|
self.concurrent_limit = concurrent_limit
|
||||||
|
self.ssl_verify = ssl_verify
|
||||||
|
self.cert = cert
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> httpx.Client:
|
||||||
|
# Optimization--lazy load the client
|
||||||
|
if self._client is None:
|
||||||
transport = self._create_sync_transport()
|
transport = self._create_sync_transport()
|
||||||
|
|
||||||
# Create a client with a connection pool
|
# Create a client with a connection pool
|
||||||
self.client = httpx.Client(
|
self._client = httpx.Client(
|
||||||
transport=transport,
|
transport=transport,
|
||||||
timeout=timeout,
|
timeout=self.timeout,
|
||||||
limits=httpx.Limits(
|
limits=httpx.Limits(
|
||||||
max_connections=concurrent_limit,
|
max_connections=self.concurrent_limit,
|
||||||
max_keepalive_connections=concurrent_limit,
|
max_keepalive_connections=self.concurrent_limit,
|
||||||
),
|
),
|
||||||
verify=ssl_verify,
|
verify=self.ssl_verify,
|
||||||
cert=cert,
|
cert=self.cert,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
else:
|
return self._client
|
||||||
self.client = client
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
# Close the client when you're done with it
|
# Close the client when you're done with it
|
||||||
self.client.close()
|
if self._client is not None:
|
||||||
|
self._client.close()
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import respx
|
||||||
import ssl
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
@ -27,3 +28,29 @@ async def test_ssl_security_level(monkeypatch):
|
||||||
|
|
||||||
# Verify that the SSL context exists and has the correct cipher string
|
# Verify that the SSL context exists and has the correct cipher string
|
||||||
assert isinstance(ssl_context, ssl.SSLContext)
|
assert isinstance(ssl_context, ssl.SSLContext)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_basic_async_request(respx_mock: respx.MockRouter):
|
||||||
|
respx_mock.get("https://api.example.com").respond(
|
||||||
|
json={"message": "Hello, world!"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create async client with SSL verification disabled to isolate SSL context testing
|
||||||
|
client = AsyncHTTPHandler(ssl_verify=False)
|
||||||
|
|
||||||
|
response = await client.get("https://api.example.com")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"message": "Hello, world!"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_sync_request(respx_mock: respx.MockRouter):
|
||||||
|
respx_mock.get("https://api.example.com").respond(
|
||||||
|
json={"message": "Hello, world!"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create async client with SSL verification disabled to isolate SSL context testing
|
||||||
|
client = HTTPHandler(ssl_verify=False)
|
||||||
|
|
||||||
|
response = client.get("https://api.example.com")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"message": "Hello, world!"}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue