diff --git a/litellm/llms/triton.py b/litellm/llms/triton.py
index be4179ccc..efd0d0a2d 100644
--- a/litellm/llms/triton.py
+++ b/litellm/llms/triton.py
@@ -8,7 +8,11 @@
 import httpx  # type: ignore
 import requests  # type: ignore
 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.utils import (
     Choices,
     CustomStreamWrapper,
@@ -50,8 +54,8 @@ class TritonChatCompletion(BaseLLM):
         logging_obj: Any,
         api_key: Optional[str] = None,
     ) -> EmbeddingResponse:
-        async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
         )
 
         response = await async_handler.post(url=api_base, data=json.dumps(data))
@@ -261,7 +265,9 @@ class TritonChatCompletion(BaseLLM):
         model_response,
         type_of_model,
     ) -> ModelResponse:
-        handler = AsyncHTTPHandler()
+        handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
+        )
         if stream:
             return self._ahandle_stream(  # type: ignore
                 handler, api_base, data_for_triton, model, logging_obj
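
The diff swaps ad-hoc `AsyncHTTPHandler(...)` construction for the shared `get_async_httpx_client` factory, so Triton's embedding and completion paths go through one central place for client creation rather than building a new handler per call. Below is a minimal sketch of that factory pattern, assuming a cache keyed by provider and params; the names are illustrative, not litellm's actual implementation, which lives in litellm/llms/custom_httpx/http_handler.py:

```python
from typing import Dict, Optional

import httpx

# Illustrative module-level cache; the real helper layers event-loop
# handling and provider-specific defaults on top of this basic idea.
_client_cache: Dict[str, httpx.AsyncClient] = {}


def get_cached_async_client(
    provider: str, params: Optional[dict] = None
) -> httpx.AsyncClient:
    """Return one shared AsyncClient per (provider, params) combination
    instead of building a fresh handler, and thus a fresh connection
    pool, on every request."""
    params = params or {}
    cache_key = f"{provider}:{sorted(params.items())}"
    if cache_key not in _client_cache:
        _client_cache[cache_key] = httpx.AsyncClient(
            timeout=httpx.Timeout(params.get("timeout", 600.0), connect=5.0)
        )
    return _client_cache[cache_key]
```

Centralizing construction this way presumably lets requests reuse an existing connection pool and keeps per-provider client settings (like the 600s Triton timeout in the diff) in one place instead of scattered across call sites.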