fix(bedrock.py): convert httpx.timeout to boto3 valid timeout

Closes https://github.com/BerriAI/litellm/issues/3398
This commit is contained in:
Krrish Dholakia 2024-05-03 16:24:21 -07:00
parent b2a0502383
commit a732d8772a
7 changed files with 93 additions and 26 deletions

View file

@ -39,6 +39,7 @@ from litellm.utils import (
Usage,
get_optional_params_embeddings,
get_optional_params_image_gen,
supports_httpx_timeout,
)
from .llms import (
anthropic_text,
@ -450,7 +451,7 @@ def completion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
timeout: Optional[Union[float, int]] = None,
timeout: Optional[Union[float, str, httpx.Timeout]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
@ -648,11 +649,21 @@ def completion(
non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
if timeout is None:
timeout = (
kwargs.get("request_timeout", None) or 600
) # set timeout for 10 minutes by default
timeout = float(timeout)
### TIMEOUT LOGIC ###
timeout = timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
try:
if base_url is not None:
api_base = base_url
@ -873,7 +884,7 @@ def completion(
logger_fn=logger_fn,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
timeout=timeout, # type: ignore
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
)
@ -1014,7 +1025,7 @@ def completion(
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
@ -1099,7 +1110,7 @@ def completion(
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout,
timeout=timeout, # type: ignore
)
if (
@ -1473,7 +1484,7 @@ def completion(
acompletion=acompletion,
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
timeout=timeout,
timeout=timeout, # type: ignore
)
if (
"stream" in optional_params
@ -1566,7 +1577,7 @@ def completion(
logger_fn=logger_fn,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
timeout=timeout, # type: ignore
)
## LOGGING
logging.post_call(
@ -1893,7 +1904,7 @@ def completion(
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
timeout=timeout,
timeout=timeout, # type: ignore
)
if (
"stream" in optional_params