diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index ddffe9ad8..02d494f9f 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -80,12 +80,18 @@ class AsyncHTTPHandler: json: Optional[dict] = None, params: Optional[dict] = None, headers: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, stream: bool = False, ): try: - req = self.client.build_request( - "POST", url, data=data, json=json, params=params, headers=headers # type: ignore - ) + if timeout is not None: + req = self.client.build_request( + "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore + ) + else: + req = self.client.build_request( + "POST", url, data=data, json=json, params=params, headers=headers # type: ignore + ) response = await self.client.send(req, stream=stream) response.raise_for_status() return response @@ -104,6 +110,14 @@ class AsyncHTTPHandler: ) finally: await new_client.aclose() + except httpx.ConnectTimeout: + if data is None: + data = {} + raise litellm.Timeout( + message=f"Connection timed out after {timeout} seconds.", + model=data.get("model"), + llm_provider="litellm-httpx-handler", + ) except httpx.HTTPStatusError as e: setattr(e, "status_code", e.response.status_code) if stream is True: @@ -192,13 +206,30 @@ class HTTPHandler: params: Optional[dict] = None, headers: Optional[dict] = None, stream: bool = False, + timeout: Optional[Union[float, httpx.Timeout]] = None, ): + try: - req = self.client.build_request( - "POST", url, data=data, json=json, params=params, headers=headers # type: ignore - ) - response = self.client.send(req, stream=stream) - return response + if timeout is not None: + req = self.client.build_request( + "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore + ) + else: + req = self.client.build_request( + "POST", url, data=data, json=json, params=params, headers=headers # type: ignore + ) + response = self.client.send(req, stream=stream) + return response + except httpx.ConnectTimeout: + if data is None: + data = {} + raise litellm.Timeout( + message=f"Connection timed out after {timeout} seconds.", + model=data.get("model"), + llm_provider="litellm-httpx-handler", + ) + except Exception as e: + raise e def __del__(self) -> None: try: