fix triton

This commit is contained in:
Ishaan Jaff 2024-11-21 09:39:48 -08:00
parent ddfe687b13
commit 0420b07c13

View file

@@ -8,7 +8,11 @@ import httpx  # type: ignore
 import requests  # type: ignore
 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.utils import (
     Choices,
     CustomStreamWrapper,
@@ -50,8 +54,8 @@ class TritonChatCompletion(BaseLLM):
         logging_obj: Any,
         api_key: Optional[str] = None,
     ) -> EmbeddingResponse:
-        async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
         )
         response = await async_handler.post(url=api_base, data=json.dumps(data))
@@ -261,7 +265,9 @@ class TritonChatCompletion(BaseLLM):
         model_response,
         type_of_model,
     ) -> ModelResponse:
-        handler = AsyncHTTPHandler()
+        handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
+        )
         if stream:
             return self._ahandle_stream(  # type: ignore
                 handler, api_base, data_for_triton, model, logging_obj