Merge branch 'main' into main

This commit is contained in:
Lucca Zenóbio 2024-05-06 09:40:23 -03:00 committed by GitHub
commit 146a49103f
98 changed files with 3926 additions and 997 deletions

View file

@ -4,7 +4,13 @@ from enum import Enum
import time, uuid
from typing import Callable, Optional, Any, Union, List
import litellm
from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
from litellm.utils import (
ModelResponse,
get_secret,
Usage,
ImageResponse,
map_finish_reason,
)
from .prompt_templates.factory import (
prompt_factory,
custom_prompt,
@ -545,7 +551,7 @@ def init_bedrock_client(
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
extra_headers: Optional[dict] = None,
timeout: Optional[int] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
@ -603,7 +609,14 @@ def init_bedrock_client(
import boto3
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
if isinstance(timeout, float):
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
elif isinstance(timeout, httpx.Timeout):
config = boto3.session.Config(
connect_timeout=timeout.connect, read_timeout=timeout.read
)
else:
config = boto3.session.Config()
### CHECK STS ###
if aws_role_name is not None and aws_session_name is not None:
@ -1058,7 +1071,9 @@ def completion(
logging_obj=logging_obj,
)
model_response["finish_reason"] = response_body["stop_reason"]
model_response["finish_reason"] = map_finish_reason(
response_body["stop_reason"]
)
_usage = litellm.Usage(
prompt_tokens=response_body["usage"]["input_tokens"],
completion_tokens=response_body["usage"]["output_tokens"],