fix(bedrock.py): convert httpx.timeout to boto3 valid timeout

Closes https://github.com/BerriAI/litellm/issues/3398
Krrish Dholakia 2024-05-03 16:24:21 -07:00
parent b2a0502383
commit a732d8772a
7 changed files with 93 additions and 26 deletions


@@ -151,7 +151,7 @@ class AzureChatCompletion(BaseLLM):
         api_type: str,
         azure_ad_token: str,
         print_verbose: Callable,
-        timeout,
+        timeout: Union[float, httpx.Timeout],
         logging_obj,
         optional_params,
         litellm_params,


@@ -533,7 +533,7 @@ def init_bedrock_client(
     aws_session_name: Optional[str] = None,
     aws_profile_name: Optional[str] = None,
     aws_role_name: Optional[str] = None,
-    timeout: Optional[int] = None,
+    timeout: Optional[Union[float, httpx.Timeout]] = None,
 ):
     # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
     litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)

@@ -592,7 +592,12 @@ def init_bedrock_client(
     import boto3

-    config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
+    if isinstance(timeout, float):
+        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
+    elif isinstance(timeout, httpx.Timeout):
+        config = boto3.session.Config(
+            connect_timeout=timeout.connect, read_timeout=timeout.read
+        )

     ### CHECK STS ###
     if aws_role_name is not None and aws_session_name is not None:
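
For reference, a minimal sketch (not part of the patch) of how an httpx.Timeout maps onto the boto3 client config under this change; build_boto3_config is an illustrative name, only the branch bodies mirror the diff above:

import httpx
import boto3

def build_boto3_config(timeout):
    # Mirrors the branching added to init_bedrock_client.
    if isinstance(timeout, float):
        # A bare float is applied to both the connect and read phases.
        return boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
    elif isinstance(timeout, httpx.Timeout):
        # httpx.Timeout carries per-phase values; only connect/read map onto boto3.
        return boto3.session.Config(
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    return None  # no explicit timeout provided in this sketch

# httpx.Timeout(10.0, connect=60.0) sets read/write/pool to 10.0 and connect to 60.0,
# so the resulting config carries connect_timeout=60.0 and read_timeout=10.0.
config = build_boto3_config(httpx.Timeout(10.0, connect=60.0))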


@@ -246,7 +246,7 @@ class OpenAIChatCompletion(BaseLLM):
     def completion(
         self,
         model_response: ModelResponse,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         model: Optional[str] = None,
         messages: Optional[list] = None,
         print_verbose: Optional[Callable] = None,

@@ -271,9 +271,12 @@ class OpenAIChatCompletion(BaseLLM):
         if model is None or messages is None:
             raise OpenAIError(status_code=422, message=f"Missing model or messages")

-        if not isinstance(timeout, float):
+        if not isinstance(timeout, float) and not isinstance(
+            timeout, httpx.Timeout
+        ):
             raise OpenAIError(
-                status_code=422, message=f"Timeout needs to be a float"
+                status_code=422,
+                message=f"Timeout needs to be a float or httpx.Timeout",
             )

         if custom_llm_provider != "openai":
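
As a small illustration (not in the diff), the relaxed check now accepts either form and still rejects everything else with a 422 OpenAIError; the values below are arbitrary:

import httpx

candidates = [
    30.0,                               # plain float: applied to the whole request
    httpx.Timeout(600.0, connect=5.0),  # per-phase: 5s connect, 10 min read/write/pool
    "30",                               # still rejected by this check
]
for t in candidates:
    accepted = isinstance(t, float) or isinstance(t, httpx.Timeout)
    print(type(t).__name__, "accepted" if accepted else "rejected with OpenAIError(422)")

Per the TIMEOUT LOGIC hunk in main.py below, litellm.completion coerces non-httpx values to float before this handler runs, so the string rejection above mainly matters for direct calls into OpenAIChatCompletion.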
@@ -425,7 +428,7 @@ class OpenAIChatCompletion(BaseLLM):
         self,
         data: dict,
         model_response: ModelResponse,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         organization: Optional[str] = None,

@@ -480,7 +483,7 @@ class OpenAIChatCompletion(BaseLLM):
     def streaming(
         self,
         logging_obj,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         data: dict,
         model: str,
         api_key: Optional[str] = None,

@@ -524,7 +527,7 @@ class OpenAIChatCompletion(BaseLLM):
     async def async_streaming(
         self,
         logging_obj,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         data: dict,
         model: str,
         api_key: Optional[str] = None,


@@ -39,6 +39,7 @@ from litellm.utils import (
     Usage,
     get_optional_params_embeddings,
     get_optional_params_image_gen,
+    supports_httpx_timeout,
 )
 from .llms import (
     anthropic_text,

@@ -450,7 +451,7 @@ def completion(
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
     messages: List = [],
-    timeout: Optional[Union[float, int]] = None,
+    timeout: Optional[Union[float, str, httpx.Timeout]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     n: Optional[int] = None,
@@ -648,11 +649,21 @@ def completion(
     non_default_params = {
         k: v for k, v in kwargs.items() if k not in default_params
     }  # model-specific params - pass them straight to the model/provider
-    if timeout is None:
-        timeout = (
-            kwargs.get("request_timeout", None) or 600
-        )  # set timeout for 10 minutes by default
-    timeout = float(timeout)
+
+    ### TIMEOUT LOGIC ###
+    timeout = timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if (
+        timeout is not None
+        and isinstance(timeout, httpx.Timeout)
+        and supports_httpx_timeout(custom_llm_provider) == False
+    ):
+        read_timeout = timeout.read or 600
+        timeout = read_timeout  # default 10 min timeout
+    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+
     try:
         if base_url is not None:
             api_base = base_url
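
A minimal sketch of what this normalization does; resolve_timeout is an illustrative name, not a litellm function. Providers listed by supports_httpx_timeout receive the httpx.Timeout untouched, everyone else gets a plain float taken from its read phase:

import httpx

def resolve_timeout(timeout, custom_llm_provider, supported=("openai", "azure", "bedrock")):
    # Mirrors the TIMEOUT LOGIC block above.
    timeout = timeout or 600  # default 10 minutes
    if isinstance(timeout, httpx.Timeout) and custom_llm_provider not in supported:
        return timeout.read or 600  # flatten to the read timeout
    if not isinstance(timeout, httpx.Timeout):
        return float(timeout)  # ints / numeric strings are coerced
    return timeout

print(resolve_timeout(httpx.Timeout(10.0, connect=60.0), "ollama"))   # 10.0
print(resolve_timeout(httpx.Timeout(10.0, connect=60.0), "bedrock"))  # Timeout passed through
print(resolve_timeout("20", "anthropic"))                             # 20.0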
@@ -873,7 +884,7 @@ def completion(
                 logger_fn=logger_fn,
                 logging_obj=logging,
                 acompletion=acompletion,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
                 client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
             )

@@ -1014,7 +1025,7 @@ def completion(
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
                 custom_prompt_dict=custom_prompt_dict,
                 client=client,  # pass AsyncOpenAI, OpenAI client
                 organization=organization,

@@ -1099,7 +1110,7 @@ def completion(
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )

             if (

@@ -1473,7 +1484,7 @@ def completion(
                 acompletion=acompletion,
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )

             if (
                 "stream" in optional_params

@@ -1566,7 +1577,7 @@ def completion(
                 logger_fn=logger_fn,
                 logging_obj=logging,
                 acompletion=acompletion,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             ## LOGGING
             logging.post_call(

@@ -1893,7 +1904,7 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )

             if (
                 "stream" in optional_params


@ -375,7 +375,9 @@ class Router:
except Exception as e: except Exception as e:
raise e raise e
def _completion(self, model: str, messages: List[Dict[str, str]], **kwargs): def _completion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
model_name = None model_name = None
try: try:
# pick the one that is available (lowest TPM/RPM) # pick the one that is available (lowest TPM/RPM)
@ -438,7 +440,9 @@ class Router:
) )
raise e raise e
async def acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs): async def acompletion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["messages"] = messages kwargs["messages"] = messages
@ -454,7 +458,9 @@ class Router:
except Exception as e: except Exception as e:
raise e raise e
async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs): async def _acompletion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
""" """
- Get an available deployment - Get an available deployment
- call it with a semaphore over the call - call it with a semaphore over the call


@@ -10,7 +10,37 @@ sys.path.insert(
 import time
 import litellm
 import openai
-import pytest, uuid
+import pytest, uuid, httpx
+
+
+@pytest.mark.parametrize(
+    "model, provider",
+    [
+        ("gpt-3.5-turbo", "openai"),
+        ("anthropic.claude-instant-v1", "bedrock"),
+        ("azure/chatgpt-v-2", "azure"),
+    ],
+)
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_httpx_timeout(model, provider, sync_mode):
+    """
+    Test if setting httpx.timeout works for completion calls
+    """
+    timeout_val = httpx.Timeout(10.0, connect=60.0)
+
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+
+    if sync_mode:
+        response = litellm.completion(
+            model=model, messages=messages, timeout=timeout_val
+        )
+    else:
+        response = await litellm.acompletion(
+            model=model, messages=messages, timeout=timeout_val
+        )
+
+    print(f"response: {response}")


 def test_timeout():


@@ -4442,7 +4442,19 @@ def completion_cost(
         raise e


-def supports_function_calling(model: str):
+def supports_httpx_timeout(custom_llm_provider: str) -> bool:
+    """
+    Helper function to know if a provider implementation supports httpx timeout
+    """
+    supported_providers = ["openai", "azure", "bedrock"]
+
+    if custom_llm_provider in supported_providers:
+        return True
+
+    return False
+
+
+def supports_function_calling(model: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
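
For completeness, a hedged usage sketch of the new helper as main.py consults it; any provider string outside supported_providers returns False:

from litellm.utils import supports_httpx_timeout

print(supports_httpx_timeout("bedrock"))    # True: httpx.Timeout is passed through
print(supports_httpx_timeout("azure"))      # True
print(supports_httpx_timeout("anthropic"))  # False: main.py flattens to the read timeout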