fix triton

commit 0420b07c13
parent ddfe687b13
Author: Ishaan Jaff
Date: 2024-11-21 09:39:48 -08:00


@@ -8,7 +8,11 @@ import httpx # type: ignore
 import requests # type: ignore
 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.utils import (
     Choices,
     CustomStreamWrapper,
@@ -50,8 +54,8 @@ class TritonChatCompletion(BaseLLM):
         logging_obj: Any,
         api_key: Optional[str] = None,
     ) -> EmbeddingResponse:
-        async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
+        )
 
         response = await async_handler.post(url=api_base, data=json.dumps(data))
@@ -261,7 +265,9 @@ class TritonChatCompletion(BaseLLM):
         model_response,
         type_of_model,
     ) -> ModelResponse:
-        handler = AsyncHTTPHandler()
+        handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
+        )
         if stream:
             return self._ahandle_stream(  # type: ignore
                 handler, api_base, data_for_triton, model, logging_obj
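
For context on the pattern this commit introduces: instead of constructing a bare AsyncHTTPHandler at each call site, the Triton code paths now obtain a provider-scoped client from litellm's get_async_httpx_client factory. A minimal sketch of that usage, assuming only the identifiers visible in the diff above (call_triton and its arguments are hypothetical placeholders):

```python
import json

import litellm
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client


async def call_triton(api_base: str, data: dict) -> dict:
    # Fetch the async client scoped to the Triton provider instead of
    # instantiating AsyncHTTPHandler directly; the 600s timeout mirrors
    # the params used in the diff above.
    handler = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
    )
    # POST the serialized payload to the Triton endpoint, as the
    # embedding path does after this change.
    response = await handler.post(url=api_base, data=json.dumps(data))
    return response.json()
```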