diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 0fe5c4e7e5..e7af9d43b6 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -151,7 +151,7 @@ class AzureChatCompletion(BaseLLM):
         api_type: str,
         azure_ad_token: str,
         print_verbose: Callable,
-        timeout,
+        timeout: Union[float, httpx.Timeout],
         logging_obj,
         optional_params,
         litellm_params,
diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index 235c13c59c..7ce544c96f 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -533,7 +533,7 @@ def init_bedrock_client(
     aws_session_name: Optional[str] = None,
     aws_profile_name: Optional[str] = None,
     aws_role_name: Optional[str] = None,
-    timeout: Optional[int] = None,
+    timeout: Optional[Union[float, httpx.Timeout]] = None,
 ):
     # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
     litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
@@ -592,7 +592,15 @@ def init_bedrock_client(

     import boto3

-    config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
+    if isinstance(timeout, (int, float)):
+        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
+    elif isinstance(timeout, httpx.Timeout):
+        config = boto3.session.Config(
+            connect_timeout=timeout.connect, read_timeout=timeout.read
+        )
+    else:
+        # no timeout passed in - keep botocore defaults, so `config` is always bound
+        config = boto3.session.Config()

     ### CHECK STS ###
     if aws_role_name is not None and aws_session_name is not None:
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index f68ab235e6..5a76605b3a 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -246,7 +246,7 @@ class OpenAIChatCompletion(BaseLLM):
     def completion(
         self,
         model_response: ModelResponse,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         model: Optional[str] = None,
         messages: Optional[list] = None,
         print_verbose: Optional[Callable] = None,
@@ -271,9 +271,10 @@ class OpenAIChatCompletion(BaseLLM):
             if model is None or messages is None:
                 raise OpenAIError(status_code=422, message=f"Missing model or messages")

-            if not isinstance(timeout, float):
+            if not isinstance(timeout, (float, httpx.Timeout)):
                 raise OpenAIError(
-                    status_code=422, message=f"Timeout needs to be a float"
+                    status_code=422,
+                    message="Timeout needs to be a float or httpx.Timeout",
                 )

             if custom_llm_provider != "openai":
@@ -425,7 +426,7 @@ class OpenAIChatCompletion(BaseLLM):
         self,
         data: dict,
         model_response: ModelResponse,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         organization: Optional[str] = None,
@@ -480,7 +481,7 @@ class OpenAIChatCompletion(BaseLLM):
     def streaming(
         self,
         logging_obj,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         data: dict,
         model: str,
         api_key: Optional[str] = None,
@@ -524,7 +525,7 @@ class OpenAIChatCompletion(BaseLLM):
     async def async_streaming(
         self,
         logging_obj,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         data: dict,
         model: str,
         api_key: Optional[str] = None,
diff --git a/litellm/main.py b/litellm/main.py
index 9765669fe1..bbcdef0de0 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -39,6 +39,7 @@ from litellm.utils import (
     Usage,
     get_optional_params_embeddings,
     get_optional_params_image_gen,
+    supports_httpx_timeout,
 )
 from .llms import (
     anthropic_text,
@@ -450,7 +451,7 @@ def completion(
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
     messages: List = [],
-    timeout: Optional[Union[float, int]] = None,
+    timeout: Optional[Union[float, str, httpx.Timeout]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     n: Optional[int] = None,
@@ -648,11 +649,19 @@ def completion(
     non_default_params = {
         k: v for k, v in kwargs.items() if k not in default_params
     }  # model-specific params - pass them straight to the model/provider
-    if timeout is None:
-        timeout = (
-            kwargs.get("request_timeout", None) or 600
-        )  # set timeout for 10 minutes by default
-    timeout = float(timeout)
+
+    ### TIMEOUT LOGIC ###
+    timeout = timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout(
+        custom_llm_provider
+    ):
+        # provider integration can't take an httpx.Timeout - pass its read timeout
+        timeout = timeout.read or 600  # default 10 min timeout
+    elif not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+
     try:
         if base_url is not None:
             api_base = base_url
@@ -873,7 +882,7 @@ def completion(
             logger_fn=logger_fn,
             logging_obj=logging,
             acompletion=acompletion,
-            timeout=timeout,
+            timeout=timeout,  # type: ignore
             client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
         )

@@ -1014,7 +1023,7 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
-            timeout=timeout,
+            timeout=timeout,  # type: ignore
             custom_prompt_dict=custom_prompt_dict,
             client=client,  # pass AsyncOpenAI, OpenAI client
             organization=organization,
@@ -1099,7 +1108,7 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
-            timeout=timeout,
+            timeout=timeout,  # type: ignore
         )

         if (
@@ -1473,7 +1482,7 @@ def completion(
             acompletion=acompletion,
             logging_obj=logging,
             custom_prompt_dict=custom_prompt_dict,
-            timeout=timeout,
+            timeout=timeout,  # type: ignore
         )
         if (
             "stream" in optional_params
@@ -1566,7 +1575,7 @@ def completion(
             logger_fn=logger_fn,
             logging_obj=logging,
             acompletion=acompletion,
-            timeout=timeout,
+            timeout=timeout,  # type: ignore
         )
         ## LOGGING
         logging.post_call(
@@ -1893,7 +1902,7 @@ def completion(
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging,
-            timeout=timeout,
+            timeout=timeout,  # type: ignore
         )
         if (
             "stream" in optional_params
diff --git a/litellm/router.py b/litellm/router.py
index 9638db548e..d64deecec1 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -375,7 +375,9 @@ class Router:
         except Exception as e:
             raise e

-    def _completion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+    def _completion(
+        self, model: str, messages: List[Dict[str, str]], **kwargs
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         model_name = None
         try:
             # pick the one that is available (lowest TPM/RPM)
@@ -438,7 +440,9 @@ class Router:
             )
             raise e

-    async def acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], **kwargs
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         try:
             kwargs["model"] = model
             kwargs["messages"] = messages
@@ -454,7 +458,9 @@ class Router:
         except Exception as e:
             raise e

-    async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+    async def _acompletion(
+        self, model: str, messages: List[Dict[str, str]], **kwargs
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         """
         - Get an available deployment
         - call it with a semaphore over the call
diff --git a/litellm/tests/test_timeout.py b/litellm/tests/test_timeout.py
index d38da52e51..f24b26a0cf 100644
--- a/litellm/tests/test_timeout.py
+++ b/litellm/tests/test_timeout.py
@@ -10,7 +10,37 @@ sys.path.insert(
 import time
 import litellm
 import openai
-import pytest, uuid
+import pytest, uuid, httpx
+
+
+@pytest.mark.parametrize(
+    "model, provider",
+    [
+        ("gpt-3.5-turbo", "openai"),
+        ("anthropic.claude-instant-v1", "bedrock"),
+        ("azure/chatgpt-v-2", "azure"),
+    ],
+)
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_httpx_timeout(model, provider, sync_mode):
+    """
+    Test if setting httpx.Timeout works for completion calls
+    """
+    timeout_val = httpx.Timeout(10.0, connect=60.0)
+
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+
+    if sync_mode:
+        response = litellm.completion(
+            model=model, messages=messages, timeout=timeout_val
+        )
+    else:
+        response = await litellm.acompletion(
+            model=model, messages=messages, timeout=timeout_val
+        )
+
+    print(f"response: {response}")


 def test_timeout():
diff --git a/litellm/utils.py b/litellm/utils.py
index 80d26f58b9..89d814e324 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4442,7 +4442,19 @@ def completion_cost(
         raise e


-def supports_function_calling(model: str):
+def supports_httpx_timeout(custom_llm_provider: str) -> bool:
+    """
+    Helper function to check if a provider's implementation supports httpx.Timeout
+    """
+    supported_providers = ["openai", "azure", "bedrock"]
+
+    if custom_llm_provider in supported_providers:
+        return True
+
+    return False
+
+
+def supports_function_calling(model: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.

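A quick usage sketch of what this diff enables (the model name, message, and timeout values below are illustrative, borrowed from the new test; nothing here assumes API surface beyond what the diff itself adds):

    import httpx
    import litellm

    # 10s limit for read/write/pool, but allow up to 60s to establish a connection
    timeout = httpx.Timeout(10.0, connect=60.0)

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        timeout=timeout,
    )

For providers in supports_httpx_timeout() ("openai", "azure", "bedrock"), the granular limits are passed through to the underlying client; for any other provider, completion() falls back to timeout.read (600 seconds if unset) and forwards it as a plain float, so existing integrations behave as before.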