fix(bedrock.py): convert httpx.timeout to boto3 valid timeout

Closes https://github.com/BerriAI/litellm/issues/3398

commit a732d8772a (parent b2a0502383)
7 changed files with 93 additions and 26 deletions
@@ -151,7 +151,7 @@ class AzureChatCompletion(BaseLLM):
         api_type: str,
         azure_ad_token: str,
         print_verbose: Callable,
-        timeout,
+        timeout: Union[float, httpx.Timeout],
         logging_obj,
         optional_params,
         litellm_params,

@@ -533,7 +533,7 @@ def init_bedrock_client(
     aws_session_name: Optional[str] = None,
     aws_profile_name: Optional[str] = None,
     aws_role_name: Optional[str] = None,
-    timeout: Optional[int] = None,
+    timeout: Optional[Union[float, httpx.Timeout]] = None,
 ):
     # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
     litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)

@@ -592,7 +592,12 @@ def init_bedrock_client(
 
     import boto3
 
-    config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
+    if isinstance(timeout, float):
+        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
+    elif isinstance(timeout, httpx.Timeout):
+        config = boto3.session.Config(
+            connect_timeout=timeout.connect, read_timeout=timeout.read
+        )
 
     ### CHECK STS ###
     if aws_role_name is not None and aws_session_name is not None:

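Note: the hunk above is the core of the fix. boto3 knows nothing about
httpx.Timeout, so its connect/read fields are unpacked into botocore Config
arguments. A minimal standalone sketch of the mapping (the client setup line
is illustrative, not part of the commit):

import boto3
import httpx

# httpx.Timeout(10.0, connect=60.0) means: 60s to connect, 10s to read
timeout = httpx.Timeout(10.0, connect=60.0)

# boto3's Config takes flat floats, so the httpx fields map across directly
config = boto3.session.Config(
    connect_timeout=timeout.connect,  # 60.0
    read_timeout=timeout.read,  # 10.0
)

client = boto3.client("bedrock-runtime", config=config)  # illustrative setup
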
@@ -246,7 +246,7 @@ class OpenAIChatCompletion(BaseLLM):
     def completion(
         self,
         model_response: ModelResponse,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         model: Optional[str] = None,
         messages: Optional[list] = None,
         print_verbose: Optional[Callable] = None,

@@ -271,9 +271,12 @@ class OpenAIChatCompletion(BaseLLM):
         if model is None or messages is None:
             raise OpenAIError(status_code=422, message=f"Missing model or messages")
 
-        if not isinstance(timeout, float):
+        if not isinstance(timeout, float) and not isinstance(
+            timeout, httpx.Timeout
+        ):
             raise OpenAIError(
-                status_code=422, message=f"Timeout needs to be a float"
+                status_code=422,
+                message=f"Timeout needs to be a float or httpx.Timeout",
             )
 
         if custom_llm_provider != "openai":

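The widened guard accepts both supported types and still fails fast on
anything else. A hedged re-statement of the check in isolation (the helper
name and ValueError are stand-ins; the real code raises OpenAIError with
status 422):

import httpx

def validate_timeout(timeout) -> None:
    # mirrors the guard above: float and httpx.Timeout pass, anything else raises
    if not isinstance(timeout, float) and not isinstance(timeout, httpx.Timeout):
        raise ValueError("Timeout needs to be a float or httpx.Timeout")

validate_timeout(600.0)  # ok
validate_timeout(httpx.Timeout(10.0, connect=60.0))  # ok
# validate_timeout("600")  # would raise ValueError
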
@@ -425,7 +428,7 @@ class OpenAIChatCompletion(BaseLLM):
         self,
         data: dict,
         model_response: ModelResponse,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         organization: Optional[str] = None,

@@ -480,7 +483,7 @@ class OpenAIChatCompletion(BaseLLM):
     def streaming(
         self,
         logging_obj,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         data: dict,
         model: str,
         api_key: Optional[str] = None,

@@ -524,7 +527,7 @@ class OpenAIChatCompletion(BaseLLM):
     async def async_streaming(
         self,
         logging_obj,
-        timeout: float,
+        timeout: Union[float, httpx.Timeout],
         data: dict,
         model: str,
         api_key: Optional[str] = None,

@@ -39,6 +39,7 @@ from litellm.utils import (
     Usage,
     get_optional_params_embeddings,
     get_optional_params_image_gen,
+    supports_httpx_timeout,
 )
 from .llms import (
     anthropic_text,

@@ -450,7 +451,7 @@ def completion(
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
     messages: List = [],
-    timeout: Optional[Union[float, int]] = None,
+    timeout: Optional[Union[float, str, httpx.Timeout]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     n: Optional[int] = None,

@@ -648,11 +649,21 @@ def completion(
     non_default_params = {
         k: v for k, v in kwargs.items() if k not in default_params
     }  # model-specific params - pass them straight to the model/provider
-    if timeout is None:
-        timeout = (
-            kwargs.get("request_timeout", None) or 600
-        )  # set timeout for 10 minutes by default
-    timeout = float(timeout)
+    ### TIMEOUT LOGIC ###
+    timeout = timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if (
+        timeout is not None
+        and isinstance(timeout, httpx.Timeout)
+        and supports_httpx_timeout(custom_llm_provider) == False
+    ):
+        read_timeout = timeout.read or 600
+        timeout = read_timeout  # default 10 min timeout
+    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+
     try:
         if base_url is not None:
             api_base = base_url

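This normalization decides what each provider implementation ultimately
receives: providers whose integrations understand httpx.Timeout get the
object passed through untouched, and everyone else gets the read timeout
collapsed to a plain float. A standalone sketch of that decision, with the
supported-provider list hard-coded for illustration (the real code calls
supports_httpx_timeout() from litellm.utils):

from typing import Optional, Union

import httpx

def normalize_timeout(
    timeout: Optional[Union[float, str, httpx.Timeout]],
    custom_llm_provider: str,
) -> Union[float, httpx.Timeout]:
    # illustrative re-statement of the hunk above
    timeout = timeout or 600  # default: 10 minutes
    supported = ("openai", "azure", "bedrock")
    if isinstance(timeout, httpx.Timeout) and custom_llm_provider not in supported:
        return timeout.read or 600  # unsupported provider: keep only the read timeout
    if not isinstance(timeout, httpx.Timeout):
        return float(timeout)  # 600, "600", and 600.0 all become 600.0
    return timeout

print(normalize_timeout(httpx.Timeout(10.0, connect=60.0), "ollama"))  # 10.0
print(normalize_timeout("30", "openai"))  # 30.0
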
@@ -873,7 +884,7 @@ def completion(
                 logger_fn=logger_fn,
                 logging_obj=logging,
                 acompletion=acompletion,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
                 client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
             )
 

@@ -1014,7 +1025,7 @@ def completion(
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
                 custom_prompt_dict=custom_prompt_dict,
                 client=client,  # pass AsyncOpenAI, OpenAI client
                 organization=organization,

@@ -1099,7 +1110,7 @@ def completion(
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
 
             if (

@@ -1473,7 +1484,7 @@ def completion(
                 acompletion=acompletion,
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             if (
                 "stream" in optional_params

@@ -1566,7 +1577,7 @@ def completion(
                 logger_fn=logger_fn,
                 logging_obj=logging,
                 acompletion=acompletion,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             ## LOGGING
             logging.post_call(

@@ -1893,7 +1904,7 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             if (
                 "stream" in optional_params

@@ -375,7 +375,9 @@ class Router:
         except Exception as e:
             raise e
 
-    def _completion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+    def _completion(
+        self, model: str, messages: List[Dict[str, str]], **kwargs
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         model_name = None
         try:
             # pick the one that is available (lowest TPM/RPM)

@@ -438,7 +440,9 @@ class Router:
             )
             raise e
 
-    async def acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], **kwargs
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         try:
             kwargs["model"] = model
             kwargs["messages"] = messages

@@ -454,7 +458,9 @@ class Router:
         except Exception as e:
             raise e
 
-    async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+    async def _acompletion(
+        self, model: str, messages: List[Dict[str, str]], **kwargs
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         """
         - Get an available deployment
         - call it with a semaphore over the call

@@ -10,7 +10,37 @@ sys.path.insert(
 import time
 import litellm
 import openai
-import pytest, uuid
+import pytest, uuid, httpx
+
+
+@pytest.mark.parametrize(
+    "model, provider",
+    [
+        ("gpt-3.5-turbo", "openai"),
+        ("anthropic.claude-instant-v1", "bedrock"),
+        ("azure/chatgpt-v-2", "azure"),
+    ],
+)
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_httpx_timeout(model, provider, sync_mode):
+    """
+    Test if setting httpx.timeout works for completion calls
+    """
+    timeout_val = httpx.Timeout(10.0, connect=60.0)
+
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+
+    if sync_mode:
+        response = litellm.completion(
+            model=model, messages=messages, timeout=timeout_val
+        )
+    else:
+        response = await litellm.acompletion(
+            model=model, messages=messages, timeout=timeout_val
+        )
+
+    print(f"response: {response}")
+
+
 def test_timeout():

@@ -4442,7 +4442,19 @@ def completion_cost(
         raise e
 
 
-def supports_function_calling(model: str):
+def supports_httpx_timeout(custom_llm_provider: str) -> bool:
+    """
+    Helper function to know if a provider implementation supports httpx timeout
+    """
+    supported_providers = ["openai", "azure", "bedrock"]
+
+    if custom_llm_provider in supported_providers:
+        return True
+
+    return False
+
+
+def supports_function_calling(model: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
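
With the helper in place, call sites can branch on provider capability before
deciding how to hand the timeout down. A small usage sketch ("ollama" here
stands in for any provider outside the supported list):

import httpx

from litellm.utils import supports_httpx_timeout

timeout = httpx.Timeout(10.0, connect=60.0)

for provider in ["openai", "bedrock", "ollama"]:
    if supports_httpx_timeout(provider):
        print(f"{provider}: pass httpx.Timeout through unchanged")
    else:
        # collapse to the read timeout, as completion() does above
        print(f"{provider}: fall back to a float ({timeout.read}s)")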