diff --git a/litellm/__init__.py b/litellm/__init__.py
index b9d9891ca..75a6751b0 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -77,6 +77,7 @@ baseten_key: Optional[str] = None
 aleph_alpha_key: Optional[str] = None
 nlp_cloud_key: Optional[str] = None
 use_client: bool = False
+ssl_verify: bool = True
 disable_streaming_logging: bool = False
 ### GUARDRAILS ###
 llamaguard_model_name: Optional[str] = None
diff --git a/litellm/router.py b/litellm/router.py
index 371d8e8eb..8cb0f3ed2 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -2052,9 +2052,11 @@ class Router:
                     timeout=timeout,
                     max_retries=max_retries,
                     http_client=httpx.AsyncClient(
-                        transport=AsyncCustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=AsyncCustomHTTPTransport(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
                         ),
                         mounts=async_proxy_mounts,
                     ),  # type: ignore
@@ -2074,9 +2076,11 @@ class Router:
                     timeout=timeout,
                     max_retries=max_retries,
                     http_client=httpx.Client(
-                        transport=CustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=CustomHTTPTransport(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
                         ),
                         mounts=sync_proxy_mounts,
                     ),  # type: ignore
@@ -2096,9 +2100,11 @@ class Router:
                     timeout=stream_timeout,
                     max_retries=max_retries,
                     http_client=httpx.AsyncClient(
-                        transport=AsyncCustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=AsyncCustomHTTPTransport(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
                         ),
                         mounts=async_proxy_mounts,
                     ),  # type: ignore
@@ -2118,9 +2124,11 @@ class Router:
                     timeout=stream_timeout,
                     max_retries=max_retries,
                     http_client=httpx.Client(
-                        transport=CustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=CustomHTTPTransport(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
                         ),
                         mounts=sync_proxy_mounts,
                     ),  # type: ignore
@@ -2158,9 +2166,11 @@ class Router:
                     timeout=timeout,
                     max_retries=max_retries,
                     http_client=httpx.AsyncClient(
-                        transport=AsyncCustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=AsyncCustomHTTPTransport(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
                         ),
                         mounts=async_proxy_mounts,
                     ),  # type: ignore
@@ -2178,9 +2188,11 @@ class Router:
                     timeout=timeout,
                     max_retries=max_retries,
                     http_client=httpx.Client(
-                        transport=CustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=CustomHTTPTransport(
+                            verify=litellm.ssl_verify,
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
                         ),
                         mounts=sync_proxy_mounts,
                     ),  # type: ignore
@@ -2199,9 +2211,11 @@ class Router:
                     timeout=stream_timeout,
                     max_retries=max_retries,
                     http_client=httpx.AsyncClient(
-                        transport=AsyncCustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=AsyncCustomHTTPTransport(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
                         ),
                         mounts=async_proxy_mounts,
                     ),
@@ -2219,9 +2233,11 @@ class Router:
                     timeout=stream_timeout,
                     max_retries=max_retries,
                     http_client=httpx.Client(
-                        transport=CustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=CustomHTTPTransport(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
                         ),
                         mounts=sync_proxy_mounts,
                     ),
@@ -2249,9 +2265,11 @@ class Router:
                     max_retries=max_retries,
                     organization=organization,
                     http_client=httpx.AsyncClient(
-                        transport=AsyncCustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=AsyncCustomHTTPTransport(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
                         ),
                         mounts=async_proxy_mounts,
                     ),  # type: ignore
@@ -2271,9 +2289,11 @@ class Router:
                     max_retries=max_retries,
                     organization=organization,
                     http_client=httpx.Client(
-                        transport=CustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=CustomHTTPTransport(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
                         ),
                         mounts=sync_proxy_mounts,
                     ),  # type: ignore
@@ -2294,9 +2314,11 @@ class Router:
                     max_retries=max_retries,
                     organization=organization,
                     http_client=httpx.AsyncClient(
-                        transport=AsyncCustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=AsyncCustomHTTPTransport(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
                         ),
                         mounts=async_proxy_mounts,
                     ),  # type: ignore
@@ -2317,9 +2339,11 @@ class Router:
                     max_retries=max_retries,
                     organization=organization,
                     http_client=httpx.Client(
-                        transport=CustomHTTPTransport(),
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
+                        transport=CustomHTTPTransport(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
                         ),
                         mounts=sync_proxy_mounts,
                     ),  # type: ignore
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 26843b50b..fd95083b7 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -22,12 +22,14 @@ load_dotenv()
 @pytest.mark.parametrize(
     "timeout", [10, 1.0, httpx.Timeout(timeout=300.0, connect=20.0)]
 )
-def test_router_timeout_init(timeout):
+@pytest.mark.parametrize("ssl_verify", [True, False])
+def test_router_timeout_init(timeout, ssl_verify):
     """
     Allow user to pass httpx.Timeout

     related issue - https://github.com/BerriAI/litellm/issues/3162
     """
+    litellm.ssl_verify = ssl_verify

     router = Router(
         model_list=[
@@ -40,14 +42,28 @@ def test_router_timeout_init(timeout):
                     "api_version": os.getenv("AZURE_API_VERSION"),
                     "timeout": timeout,
                 },
+                "model_info": {"id": 1234},
             }
         ]
     )
-    router.completion(
-        model="test-model", messages=[{"role": "user", "content": "Hey!"}]
+    model_client = router._get_client(
+        deployment={"model_info": {"id": 1234}}, client_type="sync_client", kwargs={}
     )
+    assert getattr(model_client, "timeout") == timeout
+
+    print(f"vars model_client: {vars(model_client)}")
+    http_client = getattr(model_client, "_client")
+    print(f"http client: {vars(http_client)}, ssl_Verify={ssl_verify}")
+    if ssl_verify == False:
+        assert http_client._transport._pool._ssl_context.verify_mode.name == "CERT_NONE"
+    else:
+        assert (
+            http_client._transport._pool._ssl_context.verify_mode.name
+            == "CERT_REQUIRED"
+        )
+


 def test_exception_raising():
     # this tests if the router raises an exception when invalid params are set
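Taken together, the change exposes SSL verification as a module-level flag: `litellm.ssl_verify` is read when the Router constructs its cached httpx transports, so it must be set before `Router(...)` is created, exactly as the updated test does. A minimal usage sketch follows; the Azure deployment values are illustrative placeholders, not taken from this diff:

```python
import os

import litellm
from litellm import Router

# Disable certificate verification for all Router-managed httpx clients,
# e.g. behind a TLS-intercepting proxy. This must run before Router() is
# created, because the transports read litellm.ssl_verify at construction.
litellm.ssl_verify = False

router = Router(
    model_list=[
        {
            "model_name": "test-model",
            "litellm_params": {
                # Placeholder Azure deployment settings (assumed, not from the diff)
                "model": "azure/my-deployment",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_version": os.getenv("AZURE_API_VERSION"),
            },
        }
    ]
)
```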