mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
fix(bedrock.py): convert httpx.timeout to boto3 valid timeout
Closes https://github.com/BerriAI/litellm/issues/3398
This commit is contained in: parent b2a0502383, commit a732d8772a
7 changed files with 93 additions and 26 deletions
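The heart of the fix, before the per-file diffs: boto3 has no concept of an httpx.Timeout object; it takes plain numeric connect_timeout/read_timeout values on a botocore Config. Since litellm callers can now pass an httpx.Timeout end to end, the bedrock client init has to translate one into the other. A minimal standalone sketch of that mapping (the helper name to_boto3_config is illustrative, not part of this commit):

import httpx
from botocore.config import Config  # the class boto3.session.Config aliases

def to_boto3_config(timeout):
    # A plain number applies to both the connect and read phases.
    if isinstance(timeout, (int, float)):
        return Config(connect_timeout=timeout, read_timeout=timeout)
    # httpx.Timeout carries per-phase values; map connect/read onto boto3's fields.
    if isinstance(timeout, httpx.Timeout):
        return Config(connect_timeout=timeout.connect, read_timeout=timeout.read)
    raise ValueError(f"unsupported timeout type: {type(timeout)}")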
@@ -151,7 +151,7 @@ class AzureChatCompletion(BaseLLM):
         api_type: str,
         azure_ad_token: str,
         print_verbose: Callable,
-        timeout,
+        timeout: Union[float, httpx.Timeout],
         logging_obj,
         optional_params,
         litellm_params,
@@ -533,7 +533,7 @@ def init_bedrock_client(
     aws_session_name: Optional[str] = None,
     aws_profile_name: Optional[str] = None,
     aws_role_name: Optional[str] = None,
-    timeout: Optional[int] = None,
+    timeout: Optional[Union[float, httpx.Timeout]] = None,
 ):
     # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
     litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
@@ -592,7 +592,12 @@ def init_bedrock_client(

     import boto3

-    config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
+    if isinstance(timeout, float):
+        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
+    elif isinstance(timeout, httpx.Timeout):
+        config = boto3.session.Config(
+            connect_timeout=timeout.connect, read_timeout=timeout.read
+        )

     ### CHECK STS ###
     if aws_role_name is not None and aws_session_name is not None:
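For reference on the httpx.Timeout branch above: httpx.Timeout(10.0, connect=60.0) sets connect to 60.0 while read keeps the 10.0 default, so this code would build Config(connect_timeout=60.0, read_timeout=10.0). A quick illustration:

import httpx

t = httpx.Timeout(10.0, connect=60.0)
print(t.connect, t.read)  # 60.0 10.0 - the positional default fills every phase not overridden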
@@ -246,7 +246,7 @@ class OpenAIChatCompletion(BaseLLM):
     def completion(
         self,
         model_response: ModelResponse,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         model: Optional[str] = None,
         messages: Optional[list] = None,
         print_verbose: Optional[Callable] = None,
@@ -271,9 +271,12 @@ class OpenAIChatCompletion(BaseLLM):
         if model is None or messages is None:
             raise OpenAIError(status_code=422, message=f"Missing model or messages")

-        if not isinstance(timeout, float):
+        if not isinstance(timeout, float) and not isinstance(
+            timeout, httpx.Timeout
+        ):
             raise OpenAIError(
-                status_code=422, message=f"Timeout needs to be a float"
+                status_code=422,
+                message=f"Timeout needs to be a float or httpx.Timeout",
             )

         if custom_llm_provider != "openai":
@@ -425,7 +428,7 @@ class OpenAIChatCompletion(BaseLLM):
         self,
         data: dict,
         model_response: ModelResponse,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         organization: Optional[str] = None,
@@ -480,7 +483,7 @@ class OpenAIChatCompletion(BaseLLM):
     def streaming(
         self,
         logging_obj,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         data: dict,
         model: str,
         api_key: Optional[str] = None,
@@ -524,7 +527,7 @@ class OpenAIChatCompletion(BaseLLM):
     async def async_streaming(
         self,
         logging_obj,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         data: dict,
         model: str,
         api_key: Optional[str] = None,
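The OpenAI-side changes are type-level only: the official openai v1 Python SDK is itself built on httpx and accepts either a float or an httpx.Timeout for its timeout parameter, so the widened annotations let the value flow through unchanged. Illustrative sketch (assumes OPENAI_API_KEY is set in the environment):

import httpx
from openai import OpenAI

# The SDK accepts an httpx.Timeout directly; no conversion needed on this path.
client = OpenAI(timeout=httpx.Timeout(10.0, connect=60.0))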
@@ -39,6 +39,7 @@ from litellm.utils import (
     Usage,
     get_optional_params_embeddings,
     get_optional_params_image_gen,
+    supports_httpx_timeout,
 )
 from .llms import (
     anthropic_text,
@@ -450,7 +451,7 @@ def completion(
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
     messages: List = [],
-    timeout: Optional[Union[float, int]] = None,
+    timeout: Optional[Union[float, str, httpx.Timeout]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     n: Optional[int] = None,
@@ -648,11 +649,21 @@ def completion(
     non_default_params = {
         k: v for k, v in kwargs.items() if k not in default_params
     }  # model-specific params - pass them straight to the model/provider
-    if timeout is None:
-        timeout = (
-            kwargs.get("request_timeout", None) or 600
-        )  # set timeout for 10 minutes by default
-    timeout = float(timeout)
+
+    ### TIMEOUT LOGIC ###
+    timeout = timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if (
+        timeout is not None
+        and isinstance(timeout, httpx.Timeout)
+        and supports_httpx_timeout(custom_llm_provider) == False
+    ):
+        read_timeout = timeout.read or 600
+        timeout = read_timeout  # default 10 min timeout
+    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+
     try:
         if base_url is not None:
             api_base = base_url
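The new timeout logic downgrades gracefully: if the caller passed an httpx.Timeout but the chosen provider's implementation cannot consume one, only the read timeout survives, as a plain float. A hedged sketch of that branch in isolation (normalize_timeout is an illustrative name; supports_httpx_timeout is the real helper this commit adds to utils.py, shown further below):

import httpx
from litellm.utils import supports_httpx_timeout

def normalize_timeout(timeout, custom_llm_provider: str):
    timeout = timeout or 600  # 10-minute default
    # Providers without httpx.Timeout support get a plain float read timeout.
    if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout(custom_llm_provider):
        return timeout.read or 600
    if not isinstance(timeout, httpx.Timeout):
        return float(timeout)
    return timeout  # provider understands httpx.Timeout natively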
@@ -873,7 +884,7 @@ def completion(
                 logger_fn=logger_fn,
                 logging_obj=logging,
                 acompletion=acompletion,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
                 client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
             )
@@ -1014,7 +1025,7 @@ def completion(
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
                 custom_prompt_dict=custom_prompt_dict,
                 client=client,  # pass AsyncOpenAI, OpenAI client
                 organization=organization,
@@ -1099,7 +1110,7 @@ def completion(
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )

             if (
@@ -1473,7 +1484,7 @@ def completion(
                 acompletion=acompletion,
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             if (
                 "stream" in optional_params
@@ -1566,7 +1577,7 @@ def completion(
                 logger_fn=logger_fn,
                 logging_obj=logging,
                 acompletion=acompletion,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             ## LOGGING
             logging.post_call(
@@ -1893,7 +1904,7 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             if (
                 "stream" in optional_params
@@ -375,7 +375,9 @@ class Router:
         except Exception as e:
             raise e

-    def _completion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+    def _completion(
+        self, model: str, messages: List[Dict[str, str]], **kwargs
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         model_name = None
         try:
             # pick the one that is available (lowest TPM/RPM)
@@ -438,7 +440,9 @@ class Router:
             )
             raise e

-    async def acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], **kwargs
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         try:
             kwargs["model"] = model
             kwargs["messages"] = messages
@@ -454,7 +458,9 @@ class Router:
         except Exception as e:
             raise e

-    async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+    async def _acompletion(
+        self, model: str, messages: List[Dict[str, str]], **kwargs
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         """
         - Get an available deployment
         - call it with a semaphore over the call
@@ -10,7 +10,37 @@ sys.path.insert(
 import time
 import litellm
 import openai
-import pytest, uuid
+import pytest, uuid, httpx
+
+
+@pytest.mark.parametrize(
+    "model, provider",
+    [
+        ("gpt-3.5-turbo", "openai"),
+        ("anthropic.claude-instant-v1", "bedrock"),
+        ("azure/chatgpt-v-2", "azure"),
+    ],
+)
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_httpx_timeout(model, provider, sync_mode):
+    """
+    Test if setting httpx.timeout works for completion calls
+    """
+    timeout_val = httpx.Timeout(10.0, connect=60.0)
+
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+
+    if sync_mode:
+        response = litellm.completion(
+            model=model, messages=messages, timeout=timeout_val
+        )
+    else:
+        response = await litellm.acompletion(
+            model=model, messages=messages, timeout=timeout_val
+        )
+
+    print(f"response: {response}")


 def test_timeout():
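Caller-facing usage matches the new test: an httpx.Timeout can be passed straight to completion/acompletion for the providers this commit covers (openai, azure, bedrock). For example:

import httpx
import litellm

# 10s read timeout, but allow up to 60s to establish the connection
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    timeout=httpx.Timeout(10.0, connect=60.0),
)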
@@ -4442,7 +4442,19 @@ def completion_cost(
         raise e


-def supports_function_calling(model: str):
+def supports_httpx_timeout(custom_llm_provider: str) -> bool:
+    """
+    Helper function to know if a provider implementation supports httpx timeout
+    """
+    supported_providers = ["openai", "azure", "bedrock"]
+
+    if custom_llm_provider in supported_providers:
+        return True
+
+    return False
+
+
+def supports_function_calling(model: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
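This helper gates the timeout conversion in the TIMEOUT LOGIC block above; any provider outside its list falls back to the float read-timeout path. Usage sketch:

from litellm.utils import supports_httpx_timeout

assert supports_httpx_timeout("bedrock") is True   # converted to a boto3 Config
assert supports_httpx_timeout("ollama") is False   # collapsed to a float read timeout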