From f43768d617ab8c09020162c04d82a1ae06f0107b Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 18 Nov 2024 12:22:51 -0800
Subject: [PATCH] (fix) httpx handler - bind to ipv4 for httpx handler (#6785)

* bind to ipv4 on httpx handler

* add force_ipv4

* use helper for _create_async_transport

* fix circular import

* document force_ipv4

* test_async_http_handler_force_ipv4
---
 docs/my-website/docs/proxy/configs.md     |  3 ++
 litellm/__init__.py                       |  7 ++++
 litellm/llms/custom_httpx/http_handler.py | 31 ++++++++++++++-
 tests/local_testing/test_utils.py         | 47 +++++++++++++++++++++++
 4 files changed, 87 insertions(+), 1 deletion(-)

diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md
index 888f424b4..1609e16ae 100644
--- a/docs/my-website/docs/proxy/configs.md
+++ b/docs/my-website/docs/proxy/configs.md
@@ -625,7 +625,9 @@ litellm_settings:
   redact_user_api_key_info: boolean # Redact information about the user api key (hashed token, user_id, team id, etc.), from logs. Currently supported for Langfuse, OpenTelemetry, Logfire, ArizeAI logging.
   langfuse_default_tags: ["cache_hit", "cache_key", "proxy_base_url", "user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"] # default tags for Langfuse Logging
 
+  # Networking settings
   request_timeout: 10 # (int) llm requesttimeout in seconds. Raise Timeout error if call takes longer than 10s. Sets litellm.request_timeout
+  force_ipv4: boolean # If true, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6 + Anthropic API
   set_verbose: boolean # sets litellm.set_verbose=True to view verbose debug logs. DO NOT LEAVE THIS ON IN PRODUCTION
   json_logs: boolean # if true, logs will be in json format
 
@@ -727,6 +729,7 @@ general_settings:
 | json_logs | boolean | If true, logs will be in json format. If you need to store the logs as JSON, just set the `litellm.json_logs = True`. We currently just log the raw POST request from litellm as a JSON [Further docs](./debugging) |
 | default_fallbacks | array of strings | List of fallback models to use if a specific model group is misconfigured / bad. [Further docs](./reliability#default-fallbacks) |
 | request_timeout | integer | The timeout for requests in seconds. If not set, the default value is `6000 seconds`. [For reference OpenAI Python SDK defaults to `600 seconds`.](https://github.com/openai/openai-python/blob/main/src/openai/_constants.py) |
+| force_ipv4 | boolean | If true, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6 + Anthropic API |
 | content_policy_fallbacks | array of objects | Fallbacks to use when a ContentPolicyViolationError is encountered. [Further docs](./reliability#content-policy-fallbacks) |
 | context_window_fallbacks | array of objects | Fallbacks to use when a ContextWindowExceededError is encountered. [Further docs](./reliability#context-window-fallbacks) |
 | cache | boolean | If true, enables caching. [Further docs](./caching) |
diff --git a/litellm/__init__.py b/litellm/__init__.py
index edfe1a336..04b594ca1 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -284,11 +284,18 @@ max_end_user_budget: Optional[float] = None
 priority_reservation: Optional[Dict[str, float]] = None
 #### RELIABILITY ####
 REPEATED_STREAMING_CHUNK_LIMIT = 100  # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
+
+#### Networking settings ####
 request_timeout: float = 6000  # time in seconds
+force_ipv4: bool = (
+    False  # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
+)
 module_level_aclient = AsyncHTTPHandler(
     timeout=request_timeout, client_alias="module level aclient"
 )
 module_level_client = HTTPHandler(timeout=request_timeout)
+
+#### RETRIES ####
 num_retries: Optional[int] = None  # per model endpoint
 max_fallbacks: Optional[int] = None
 default_fallbacks: Optional[List] = None
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 9e5ed782e..020af7e90 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -4,7 +4,7 @@ import traceback
 from typing import TYPE_CHECKING, Any, Callable, List, Mapping, Optional, Union
 
 import httpx
-from httpx import USE_CLIENT_DEFAULT
+from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport
 
 import litellm
 
@@ -60,8 +60,10 @@ class AsyncHTTPHandler:
         if timeout is None:
             timeout = _DEFAULT_TIMEOUT
         # Create a client with a connection pool
+        transport = self._create_async_transport()
 
         return httpx.AsyncClient(
+            transport=transport,
             event_hooks=event_hooks,
             timeout=timeout,
             limits=httpx.Limits(
@@ -297,6 +299,18 @@ class AsyncHTTPHandler:
         except Exception:
             pass
 
+    def _create_async_transport(self) -> Optional[AsyncHTTPTransport]:
+        """
+        Create an async transport with IPv4 only if litellm.force_ipv4 is True.
+        Otherwise, return None.
+
+        Some users have seen httpx ConnectionError when using ipv6 - forcing ipv4 resolves the issue for them
+        """
+        if litellm.force_ipv4:
+            return AsyncHTTPTransport(local_address="0.0.0.0")
+        else:
+            return None
+
 
 class HTTPHandler:
     def __init__(
@@ -316,8 +330,11 @@ class HTTPHandler:
         cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
 
         if client is None:
+            transport = self._create_sync_transport()
+
             # Create a client with a connection pool
             self.client = httpx.Client(
+                transport=transport,
                 timeout=timeout,
                 limits=httpx.Limits(
                     max_connections=concurrent_limit,
@@ -427,6 +444,18 @@ class HTTPHandler:
         except Exception:
             pass
 
+    def _create_sync_transport(self) -> Optional[HTTPTransport]:
+        """
+        Create an HTTP transport with IPv4 only if litellm.force_ipv4 is True.
+        Otherwise, return None.
+
+        Some users have seen httpx ConnectionError when using ipv6 - forcing ipv4 resolves the issue for them
+        """
+        if litellm.force_ipv4:
+            return HTTPTransport(local_address="0.0.0.0")
+        else:
+            return None
+
 
 def get_async_httpx_client(
     llm_provider: Union[LlmProviders, httpxSpecialProvider],
diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py
index 31f17eed9..6e7b0ff05 100644
--- a/tests/local_testing/test_utils.py
+++ b/tests/local_testing/test_utils.py
@@ -855,6 +855,7 @@ def test_async_http_handler(mock_async_client):
 
     mock_async_client.assert_called_with(
         cert="/client.pem",
+        transport=None,
         event_hooks=event_hooks,
         headers=headers,
         limits=httpx.Limits(
@@ -866,6 +867,52 @@
     )
 
 
+@mock.patch("httpx.AsyncClient")
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_async_http_handler_force_ipv4(mock_async_client):
+    """
+    Test AsyncHTTPHandler when litellm.force_ipv4 is True
+
+    This is prod test - we need to ensure that httpx always uses ipv4 when litellm.force_ipv4 is True
+    """
+    import httpx
+    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+    # Set force_ipv4 to True
+    litellm.force_ipv4 = True
+
+    try:
+        timeout = 120
+        event_hooks = {"request": [lambda r: r]}
+        concurrent_limit = 2
+
+        AsyncHTTPHandler(timeout, event_hooks, concurrent_limit)
+
+        # Get the call arguments
+        call_args = mock_async_client.call_args[1]
+
+        ############# IMPORTANT ASSERTION #################
+        # Assert transport exists and is configured correctly for using ipv4
+        assert isinstance(call_args["transport"], httpx.AsyncHTTPTransport)
+        print(call_args["transport"])
+        assert call_args["transport"]._pool._local_address == "0.0.0.0"
+        ####################################
+
+        # Assert other parameters match
+        assert call_args["event_hooks"] == event_hooks
+        assert call_args["headers"] == headers
+        assert isinstance(call_args["limits"], httpx.Limits)
+        assert call_args["limits"].max_connections == concurrent_limit
+        assert call_args["limits"].max_keepalive_connections == concurrent_limit
+        assert call_args["timeout"] == timeout
+        assert call_args["verify"] is True
+        assert call_args["cert"] is None
+
+    finally:
+        # Reset force_ipv4 to default
+        litellm.force_ipv4 = False
+
+
 @pytest.mark.parametrize(
     "model, expected_bool", [("gpt-3.5-turbo", False), ("gpt-4o-audio-preview", True)]
 )
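
For reference, the technique this patch relies on can be reproduced with httpx alone: binding the
client's local address to "0.0.0.0" (the IPv4 wildcard) restricts the outgoing socket to the
AF_INET family, so IPv6 (AAAA) DNS answers are never dialed. A minimal, self-contained sketch
follows; the target URL is illustrative and not part of the patch:

    import asyncio

    import httpx


    async def main() -> None:
        # local_address="0.0.0.0" binds the socket to the IPv4 wildcard address,
        # which is exactly what the new _create_async_transport() helper does
        # when litellm.force_ipv4 is True.
        transport = httpx.AsyncHTTPTransport(local_address="0.0.0.0")
        async with httpx.AsyncClient(transport=transport) as client:
            resp = await client.get("https://www.example.com")  # illustrative URL
            print(resp.status_code)


    asyncio.run(main())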
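To opt in from application code, set the module-level flag before a handler (and therefore its
underlying httpx client) is constructed; on the proxy the equivalent is `force_ipv4: true` under
`litellm_settings`, per the configs.md hunk above. A short usage sketch against the API added here:

    import litellm
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    # The flag is read by _create_sync_transport() at client-creation time,
    # so it must be set before the handler is built.
    litellm.force_ipv4 = True

    handler = HTTPHandler(timeout=10)  # this client now uses an IPv4-bound HTTPTransport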