forked from phoenix/litellm-mirror
Litellm stable dev (#5711)
* feat(aws_base_llm.py): prevents recreating boto3 credentials during high traffic Leads to 100ms perf boost in local testing * fix(base_aws_llm.py): fix credential caching check to see if token is set * refactor(bedrock/chat): separate converse api and invoke api + isolate converse api transformation logic Make it easier to see how requests are transformed for /converse * fix: fix imports * fix(bedrock/embed): fix reordering of headers * fix(base_aws_llm.py): fix get credential logic * fix(converse_handler.py): fix ai21 streaming response
This commit is contained in:
parent
2efdd2a6a4
commit
da77706c26
14 changed files with 1073 additions and 1039 deletions
|
@ -914,7 +914,7 @@ from .llms.sagemaker.sagemaker import SagemakerConfig
|
||||||
from .llms.ollama import OllamaConfig
|
from .llms.ollama import OllamaConfig
|
||||||
from .llms.ollama_chat import OllamaChatConfig
|
from .llms.ollama_chat import OllamaChatConfig
|
||||||
from .llms.maritalk import MaritTalkConfig
|
from .llms.maritalk import MaritTalkConfig
|
||||||
from .llms.bedrock.chat import (
|
from .llms.bedrock.chat.invoke_handler import (
|
||||||
AmazonCohereChatConfig,
|
AmazonCohereChatConfig,
|
||||||
AmazonConverseConfig,
|
AmazonConverseConfig,
|
||||||
BEDROCK_CONVERSE_MODELS,
|
BEDROCK_CONVERSE_MODELS,
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from typing import List, Optional
|
import os
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -28,6 +30,14 @@ class BaseAWSLLM(BaseLLM):
|
||||||
self.iam_cache = DualCache()
|
self.iam_cache = DualCache()
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:
|
||||||
|
"""
|
||||||
|
Generate a unique cache key based on the credential arguments.
|
||||||
|
"""
|
||||||
|
# Convert credential arguments to a JSON string and hash it to create a unique key
|
||||||
|
credential_str = json.dumps(credential_args, sort_keys=True)
|
||||||
|
return hashlib.sha256(credential_str.encode()).hexdigest()
|
||||||
|
|
||||||
def get_credentials(
|
def get_credentials(
|
||||||
self,
|
self,
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
|
@ -43,9 +53,22 @@ class BaseAWSLLM(BaseLLM):
|
||||||
"""
|
"""
|
||||||
Return a boto3.Credentials object
|
Return a boto3.Credentials object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
|
||||||
## CHECK IS 'os.environ/' passed in
|
## CHECK IS 'os.environ/' passed in
|
||||||
|
param_names = [
|
||||||
|
"aws_access_key_id",
|
||||||
|
"aws_secret_access_key",
|
||||||
|
"aws_session_token",
|
||||||
|
"aws_region_name",
|
||||||
|
"aws_session_name",
|
||||||
|
"aws_profile_name",
|
||||||
|
"aws_role_name",
|
||||||
|
"aws_web_identity_token",
|
||||||
|
"aws_sts_endpoint",
|
||||||
|
]
|
||||||
params_to_check: List[Optional[str]] = [
|
params_to_check: List[Optional[str]] = [
|
||||||
aws_access_key_id,
|
aws_access_key_id,
|
||||||
aws_secret_access_key,
|
aws_secret_access_key,
|
||||||
|
@ -64,6 +87,11 @@ class BaseAWSLLM(BaseLLM):
|
||||||
_v = get_secret(param)
|
_v = get_secret(param)
|
||||||
if _v is not None and isinstance(_v, str):
|
if _v is not None and isinstance(_v, str):
|
||||||
params_to_check[i] = _v
|
params_to_check[i] = _v
|
||||||
|
elif param is None: # check if uppercase value in env
|
||||||
|
key = param_names[i]
|
||||||
|
if key.upper() in os.environ:
|
||||||
|
params_to_check[i] = os.getenv(key)
|
||||||
|
|
||||||
# Assign updated values back to parameters
|
# Assign updated values back to parameters
|
||||||
(
|
(
|
||||||
aws_access_key_id,
|
aws_access_key_id,
|
||||||
|
@ -77,6 +105,10 @@ class BaseAWSLLM(BaseLLM):
|
||||||
aws_sts_endpoint,
|
aws_sts_endpoint,
|
||||||
) = params_to_check
|
) = params_to_check
|
||||||
|
|
||||||
|
# create cache key for non-expiring auth flows
|
||||||
|
args = {k: v for k, v in locals().items() if k.startswith("aws_")}
|
||||||
|
cache_key = self.get_cache_key(args)
|
||||||
|
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
"in get credentials\n"
|
"in get credentials\n"
|
||||||
"aws_access_key_id=%s\n"
|
"aws_access_key_id=%s\n"
|
||||||
|
@ -186,7 +218,6 @@ class BaseAWSLLM(BaseLLM):
|
||||||
|
|
||||||
# Extract the credentials from the response and convert to Session Credentials
|
# Extract the credentials from the response and convert to Session Credentials
|
||||||
sts_credentials = sts_response["Credentials"]
|
sts_credentials = sts_response["Credentials"]
|
||||||
from botocore.credentials import Credentials
|
|
||||||
|
|
||||||
credentials = Credentials(
|
credentials = Credentials(
|
||||||
access_key=sts_credentials["AccessKeyId"],
|
access_key=sts_credentials["AccessKeyId"],
|
||||||
|
@ -211,12 +242,72 @@ class BaseAWSLLM(BaseLLM):
|
||||||
secret_key=aws_secret_access_key,
|
secret_key=aws_secret_access_key,
|
||||||
token=aws_session_token,
|
token=aws_session_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
return credentials
|
return credentials
|
||||||
else:
|
elif (
|
||||||
|
aws_access_key_id is not None
|
||||||
|
and aws_secret_access_key is not None
|
||||||
|
and aws_region_name is not None
|
||||||
|
):
|
||||||
|
# Check if credentials are already in cache. These credentials have no expiry time.
|
||||||
|
cached_credentials: Optional[Credentials] = self.iam_cache.get_cache(
|
||||||
|
cache_key
|
||||||
|
)
|
||||||
|
if cached_credentials:
|
||||||
|
return cached_credentials
|
||||||
|
|
||||||
session = boto3.Session(
|
session = boto3.Session(
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
region_name=aws_region_name,
|
region_name=aws_region_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return session.get_credentials()
|
credentials = session.get_credentials()
|
||||||
|
|
||||||
|
if (
|
||||||
|
credentials.token is None
|
||||||
|
): # don't cache if session token exists. The expiry time for that is not known.
|
||||||
|
self.iam_cache.set_cache(cache_key, credentials, ttl=3600 - 60)
|
||||||
|
|
||||||
|
return credentials
|
||||||
|
else:
|
||||||
|
# check env var. Do not cache the response from this.
|
||||||
|
session = boto3.Session()
|
||||||
|
|
||||||
|
credentials = session.get_credentials()
|
||||||
|
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
def get_runtime_endpoint(
|
||||||
|
self,
|
||||||
|
api_base: Optional[str],
|
||||||
|
aws_bedrock_runtime_endpoint: Optional[str],
|
||||||
|
aws_region_name: str,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||||
|
if api_base is not None:
|
||||||
|
endpoint_url = api_base
|
||||||
|
elif aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||||
|
aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
endpoint_url = aws_bedrock_runtime_endpoint
|
||||||
|
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||||
|
env_aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||||
|
else:
|
||||||
|
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||||
|
|
||||||
|
# Determine proxy_endpoint_url
|
||||||
|
if env_aws_bedrock_runtime_endpoint and isinstance(
|
||||||
|
env_aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||||
|
elif aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||||
|
aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
proxy_endpoint_url = aws_bedrock_runtime_endpoint
|
||||||
|
else:
|
||||||
|
proxy_endpoint_url = endpoint_url
|
||||||
|
|
||||||
|
return endpoint_url, proxy_endpoint_url
|
||||||
|
|
2
litellm/llms/bedrock/chat/__init__.py
Normal file
2
litellm/llms/bedrock/chat/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .converse_handler import BedrockConverseLLM
|
||||||
|
from .invoke_handler import BedrockLLM
|
408
litellm/llms/bedrock/chat/converse_handler.py
Normal file
408
litellm/llms/bedrock/chat/converse_handler.py
Normal file
|
@ -0,0 +1,408 @@
|
||||||
|
import json
|
||||||
|
import urllib
|
||||||
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
AsyncHTTPHandler,
|
||||||
|
HTTPHandler,
|
||||||
|
_get_httpx_client,
|
||||||
|
get_async_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.types.utils import ModelResponse
|
||||||
|
from litellm.utils import CustomStreamWrapper, get_secret
|
||||||
|
|
||||||
|
from ...base_aws_llm import BaseAWSLLM
|
||||||
|
from ..common_utils import BedrockError
|
||||||
|
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
|
||||||
|
|
||||||
|
|
||||||
|
def make_sync_call(
|
||||||
|
client: Optional[HTTPHandler],
|
||||||
|
api_base: str,
|
||||||
|
headers: dict,
|
||||||
|
data: str,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
logging_obj,
|
||||||
|
):
|
||||||
|
if client is None:
|
||||||
|
client = _get_httpx_client() # Create a new client if none provided
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
api_base,
|
||||||
|
headers=headers,
|
||||||
|
data=data,
|
||||||
|
stream=True if "ai21" not in api_base else False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise BedrockError(status_code=response.status_code, message=response.read())
|
||||||
|
|
||||||
|
if "ai21" in api_base:
|
||||||
|
model_response: (
|
||||||
|
ModelResponse
|
||||||
|
) = litellm.AmazonConverseConfig()._transform_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=litellm.ModelResponse(),
|
||||||
|
stream=True,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params={},
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=litellm.print_verbose,
|
||||||
|
encoding=litellm.encoding,
|
||||||
|
) # type: ignore
|
||||||
|
completion_stream: Any = MockResponseIterator(model_response=model_response)
|
||||||
|
else:
|
||||||
|
decoder = AWSEventStreamDecoder(model=model)
|
||||||
|
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||||
|
|
||||||
|
# LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
original_response="first stream response received",
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
|
||||||
|
return completion_stream
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockConverseLLM(BaseAWSLLM):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def encode_model_id(self, model_id: str) -> str:
|
||||||
|
"""
|
||||||
|
Double encode the model ID to ensure it matches the expected double-encoded format.
|
||||||
|
Args:
|
||||||
|
model_id (str): The model ID to encode.
|
||||||
|
Returns:
|
||||||
|
str: The double-encoded model ID.
|
||||||
|
"""
|
||||||
|
return urllib.parse.quote(model_id, safe="") # type: ignore
|
||||||
|
|
||||||
|
async def async_streaming(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> CustomStreamWrapper:
|
||||||
|
|
||||||
|
completion_stream = await make_call(
|
||||||
|
client=client,
|
||||||
|
api_base=api_base,
|
||||||
|
headers=headers,
|
||||||
|
data=data,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="bedrock",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return streaming_response
|
||||||
|
|
||||||
|
async def async_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
client = get_async_httpx_client(
|
||||||
|
params=_params, llm_provider=litellm.LlmProviders.BEDROCK
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
client = client # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post(url=api_base, headers=headers, data=data) # type: ignore
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
error_code = err.response.status_code
|
||||||
|
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||||
|
except httpx.TimeoutException as e:
|
||||||
|
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
|
return litellm.AmazonConverseConfig()._transform_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream if isinstance(stream, bool) else False,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
def completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: Optional[str],
|
||||||
|
custom_prompt_dict: dict,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
|
acompletion: bool,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
litellm_params: dict,
|
||||||
|
logger_fn=None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
from botocore.auth import SigV4Auth
|
||||||
|
from botocore.awsrequest import AWSRequest
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
|
||||||
|
## SETUP ##
|
||||||
|
stream = optional_params.pop("stream", None)
|
||||||
|
modelId = optional_params.pop("model_id", None)
|
||||||
|
if modelId is not None:
|
||||||
|
modelId = self.encode_model_id(model_id=modelId)
|
||||||
|
else:
|
||||||
|
modelId = model
|
||||||
|
|
||||||
|
provider = model.split(".")[0]
|
||||||
|
|
||||||
|
## CREDENTIALS ##
|
||||||
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
|
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||||
|
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
|
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||||
|
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||||
|
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||||
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
|
"aws_bedrock_runtime_endpoint", None
|
||||||
|
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||||
|
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||||
|
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
|
||||||
|
|
||||||
|
### SET REGION NAME ###
|
||||||
|
if aws_region_name is None:
|
||||||
|
# check env #
|
||||||
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
|
|
||||||
|
if litellm_aws_region_name is not None and isinstance(
|
||||||
|
litellm_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = litellm_aws_region_name
|
||||||
|
|
||||||
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
|
if standard_aws_region_name is not None and isinstance(
|
||||||
|
standard_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = standard_aws_region_name
|
||||||
|
|
||||||
|
if aws_region_name is None:
|
||||||
|
aws_region_name = "us-west-2"
|
||||||
|
|
||||||
|
credentials: Credentials = self.get_credentials(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_session_token=aws_session_token,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
aws_session_name=aws_session_name,
|
||||||
|
aws_profile_name=aws_profile_name,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
|
aws_sts_endpoint=aws_sts_endpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
### SET RUNTIME ENDPOINT ###
|
||||||
|
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||||
|
api_base=api_base,
|
||||||
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
)
|
||||||
|
if (stream is not None and stream is True) and provider != "ai21":
|
||||||
|
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
|
||||||
|
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
|
||||||
|
else:
|
||||||
|
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
|
||||||
|
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
|
||||||
|
|
||||||
|
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||||
|
|
||||||
|
## TRANSFORMATION ##
|
||||||
|
|
||||||
|
_data = litellm.AmazonConverseConfig()._transform_request(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
)
|
||||||
|
data = json.dumps(_data)
|
||||||
|
## COMPLETION CALL
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers = {"Content-Type": "application/json", **extra_headers}
|
||||||
|
request = AWSRequest(
|
||||||
|
method="POST", url=endpoint_url, data=data, headers=headers
|
||||||
|
)
|
||||||
|
sigv4.add_auth(request)
|
||||||
|
if (
|
||||||
|
extra_headers is not None and "Authorization" in extra_headers
|
||||||
|
): # prevent sigv4 from overwriting the auth header
|
||||||
|
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||||
|
prepped = request.prepare()
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"api_base": proxy_endpoint_url,
|
||||||
|
"headers": prepped.headers,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||||
|
if acompletion:
|
||||||
|
if isinstance(client, HTTPHandler):
|
||||||
|
client = None
|
||||||
|
if stream is True:
|
||||||
|
return self.async_streaming(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=proxy_endpoint_url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=True,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
return self.async_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=proxy_endpoint_url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=stream, # type: ignore
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
client = _get_httpx_client(_params) # type: ignore
|
||||||
|
else:
|
||||||
|
client = client
|
||||||
|
|
||||||
|
if stream is not None and stream is True:
|
||||||
|
completion_stream = make_sync_call(
|
||||||
|
client=(
|
||||||
|
client
|
||||||
|
if client is not None and isinstance(client, HTTPHandler)
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
api_base=proxy_endpoint_url,
|
||||||
|
headers=prepped.headers, # type: ignore
|
||||||
|
data=data,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="bedrock",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
return streaming_response
|
||||||
|
|
||||||
|
### COMPLETION
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
error_code = err.response.status_code
|
||||||
|
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
|
return litellm.AmazonConverseConfig()._transform_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream if isinstance(stream, bool) else False,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
431
litellm/llms/bedrock/chat/converse_transformation.py
Normal file
431
litellm/llms/bedrock/chat/converse_transformation.py
Normal file
|
@ -0,0 +1,431 @@
|
||||||
|
"""
|
||||||
|
Translating between OpenAI's `/chat/completion` format and Amazon's `/converse` format
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||||
|
from litellm.types.llms.bedrock import *
|
||||||
|
from litellm.types.llms.openai import (
|
||||||
|
AllMessageValues,
|
||||||
|
ChatCompletionResponseMessage,
|
||||||
|
ChatCompletionToolCallChunk,
|
||||||
|
ChatCompletionToolCallFunctionChunk,
|
||||||
|
ChatCompletionToolParam,
|
||||||
|
ChatCompletionToolParamFunctionChunk,
|
||||||
|
)
|
||||||
|
from litellm.types.utils import ModelResponse, Usage
|
||||||
|
from litellm.utils import CustomStreamWrapper
|
||||||
|
|
||||||
|
from ...prompt_templates.factory import _bedrock_converse_messages_pt, _bedrock_tools_pt
|
||||||
|
from ..common_utils import BedrockError, get_bedrock_tool_name
|
||||||
|
|
||||||
|
|
||||||
|
class AmazonConverseConfig:
|
||||||
|
"""
|
||||||
|
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
|
||||||
|
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
|
||||||
|
"""
|
||||||
|
|
||||||
|
maxTokens: Optional[int]
|
||||||
|
stopSequences: Optional[List[str]]
|
||||||
|
temperature: Optional[int]
|
||||||
|
topP: Optional[int]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
maxTokens: Optional[int] = None,
|
||||||
|
stopSequences: Optional[List[str]] = None,
|
||||||
|
temperature: Optional[int] = None,
|
||||||
|
topP: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||||
|
supported_params = [
|
||||||
|
"max_tokens",
|
||||||
|
"stream",
|
||||||
|
"stream_options",
|
||||||
|
"stop",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"extra_headers",
|
||||||
|
"response_format",
|
||||||
|
]
|
||||||
|
|
||||||
|
if (
|
||||||
|
model.startswith("anthropic")
|
||||||
|
or model.startswith("mistral")
|
||||||
|
or model.startswith("cohere")
|
||||||
|
or model.startswith("meta.llama3-1")
|
||||||
|
):
|
||||||
|
supported_params.append("tools")
|
||||||
|
|
||||||
|
if model.startswith("anthropic") or model.startswith("mistral"):
|
||||||
|
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||||
|
supported_params.append("tool_choice")
|
||||||
|
|
||||||
|
return supported_params
|
||||||
|
|
||||||
|
def map_tool_choice_values(
|
||||||
|
self, model: str, tool_choice: Union[str, dict], drop_params: bool
|
||||||
|
) -> Optional[ToolChoiceValuesBlock]:
|
||||||
|
if tool_choice == "none":
|
||||||
|
if litellm.drop_params is True or drop_params is True:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise litellm.utils.UnsupportedParamsError(
|
||||||
|
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||||
|
tool_choice
|
||||||
|
),
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
elif tool_choice == "required":
|
||||||
|
return ToolChoiceValuesBlock(any={})
|
||||||
|
elif tool_choice == "auto":
|
||||||
|
return ToolChoiceValuesBlock(auto={})
|
||||||
|
elif isinstance(tool_choice, dict):
|
||||||
|
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||||
|
specific_tool = SpecificToolChoiceBlock(
|
||||||
|
name=tool_choice.get("function", {}).get("name", "")
|
||||||
|
)
|
||||||
|
return ToolChoiceValuesBlock(tool=specific_tool)
|
||||||
|
else:
|
||||||
|
raise litellm.utils.UnsupportedParamsError(
|
||||||
|
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||||
|
tool_choice
|
||||||
|
),
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_supported_image_types(self) -> List[str]:
|
||||||
|
return ["png", "jpeg", "gif", "webp"]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param == "response_format":
|
||||||
|
json_schema: Optional[dict] = None
|
||||||
|
schema_name: str = ""
|
||||||
|
if "response_schema" in value:
|
||||||
|
json_schema = value["response_schema"]
|
||||||
|
schema_name = "json_tool_call"
|
||||||
|
elif "json_schema" in value:
|
||||||
|
json_schema = value["json_schema"]["schema"]
|
||||||
|
schema_name = value["json_schema"]["name"]
|
||||||
|
"""
|
||||||
|
Follow similar approach to anthropic - translate to a single tool call.
|
||||||
|
|
||||||
|
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||||
|
- You usually want to provide a single tool
|
||||||
|
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||||
|
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
||||||
|
"""
|
||||||
|
if json_schema is not None:
|
||||||
|
_tool_choice = self.map_tool_choice_values(
|
||||||
|
model=model, tool_choice="required", drop_params=drop_params # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
_tool = ChatCompletionToolParam(
|
||||||
|
type="function",
|
||||||
|
function=ChatCompletionToolParamFunctionChunk(
|
||||||
|
name=schema_name, parameters=json_schema
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
optional_params["tools"] = [_tool]
|
||||||
|
optional_params["tool_choice"] = _tool_choice
|
||||||
|
optional_params["json_mode"] = True
|
||||||
|
else:
|
||||||
|
if litellm.drop_params is True or drop_params is True:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise litellm.utils.UnsupportedParamsError(
|
||||||
|
message="Bedrock doesn't support response_format={}. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||||
|
value
|
||||||
|
),
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
if param == "max_tokens":
|
||||||
|
optional_params["maxTokens"] = value
|
||||||
|
if param == "stream":
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "stop":
|
||||||
|
if isinstance(value, str):
|
||||||
|
if len(value) == 0: # converse raises error for empty strings
|
||||||
|
continue
|
||||||
|
value = [value]
|
||||||
|
optional_params["stopSequences"] = value
|
||||||
|
if param == "temperature":
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["topP"] = value
|
||||||
|
if param == "tools":
|
||||||
|
optional_params["tools"] = value
|
||||||
|
if param == "tool_choice":
|
||||||
|
_tool_choice_value = self.map_tool_choice_values(
|
||||||
|
model=model, tool_choice=value, drop_params=drop_params # type: ignore
|
||||||
|
)
|
||||||
|
if _tool_choice_value is not None:
|
||||||
|
optional_params["tool_choice"] = _tool_choice_value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
def _transform_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
) -> RequestObject:
|
||||||
|
system_prompt_indices = []
|
||||||
|
system_content_blocks: List[SystemContentBlock] = []
|
||||||
|
for idx, message in enumerate(messages):
|
||||||
|
if message["role"] == "system":
|
||||||
|
_system_content_block: Optional[SystemContentBlock] = None
|
||||||
|
if isinstance(message["content"], str) and len(message["content"]) > 0:
|
||||||
|
_system_content_block = SystemContentBlock(text=message["content"])
|
||||||
|
elif isinstance(message["content"], list):
|
||||||
|
for m in message["content"]:
|
||||||
|
if m.get("type", "") == "text" and len(m["text"]) > 0:
|
||||||
|
_system_content_block = SystemContentBlock(text=m["text"])
|
||||||
|
if _system_content_block is not None:
|
||||||
|
system_content_blocks.append(_system_content_block)
|
||||||
|
system_prompt_indices.append(idx)
|
||||||
|
if len(system_prompt_indices) > 0:
|
||||||
|
for idx in reversed(system_prompt_indices):
|
||||||
|
messages.pop(idx)
|
||||||
|
|
||||||
|
inference_params = copy.deepcopy(optional_params)
|
||||||
|
additional_request_keys = []
|
||||||
|
additional_request_params = {}
|
||||||
|
supported_converse_params = AmazonConverseConfig.__annotations__.keys()
|
||||||
|
supported_tool_call_params = ["tools", "tool_choice"]
|
||||||
|
supported_guardrail_params = ["guardrailConfig"]
|
||||||
|
json_mode: Optional[bool] = inference_params.pop(
|
||||||
|
"json_mode", None
|
||||||
|
) # used for handling json_schema
|
||||||
|
## TRANSFORMATION ##
|
||||||
|
|
||||||
|
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
llm_provider="bedrock_converse",
|
||||||
|
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# send all model-specific params in 'additional_request_params'
|
||||||
|
for k, v in inference_params.items():
|
||||||
|
if (
|
||||||
|
k not in supported_converse_params
|
||||||
|
and k not in supported_tool_call_params
|
||||||
|
and k not in supported_guardrail_params
|
||||||
|
):
|
||||||
|
additional_request_params[k] = v
|
||||||
|
additional_request_keys.append(k)
|
||||||
|
for key in additional_request_keys:
|
||||||
|
inference_params.pop(key, None)
|
||||||
|
|
||||||
|
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
||||||
|
inference_params.pop("tools", [])
|
||||||
|
)
|
||||||
|
bedrock_tool_config: Optional[ToolConfigBlock] = None
|
||||||
|
if len(bedrock_tools) > 0:
|
||||||
|
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
|
||||||
|
"tool_choice", None
|
||||||
|
)
|
||||||
|
bedrock_tool_config = ToolConfigBlock(
|
||||||
|
tools=bedrock_tools,
|
||||||
|
)
|
||||||
|
if tool_choice_values is not None:
|
||||||
|
bedrock_tool_config["toolChoice"] = tool_choice_values
|
||||||
|
|
||||||
|
_data: RequestObject = {
|
||||||
|
"messages": bedrock_messages,
|
||||||
|
"additionalModelRequestFields": additional_request_params,
|
||||||
|
"system": system_content_blocks,
|
||||||
|
"inferenceConfig": InferenceConfig(**inference_params),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Guardrail Config
|
||||||
|
guardrail_config: Optional[GuardrailConfigBlock] = None
|
||||||
|
request_guardrails_config = inference_params.pop("guardrailConfig", None)
|
||||||
|
if request_guardrails_config is not None:
|
||||||
|
guardrail_config = GuardrailConfigBlock(**request_guardrails_config)
|
||||||
|
_data["guardrailConfig"] = guardrail_config
|
||||||
|
|
||||||
|
# Tool Config
|
||||||
|
if bedrock_tool_config is not None:
|
||||||
|
_data["toolConfig"] = bedrock_tool_config
|
||||||
|
|
||||||
|
return _data
|
||||||
|
|
||||||
|
def _transform_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
response: httpx.Response,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
stream: bool,
|
||||||
|
logging_obj: Optional[Logging],
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
data: Union[dict, str],
|
||||||
|
messages: List,
|
||||||
|
print_verbose,
|
||||||
|
encoding,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
if logging_obj is not None:
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=response.text,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
print_verbose(f"raw model_response: {response.text}")
|
||||||
|
json_mode: Optional[bool] = optional_params.pop("json_mode", None)
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise BedrockError(
|
||||||
|
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
|
||||||
|
response.text, str(e)
|
||||||
|
),
|
||||||
|
status_code=422,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Bedrock Response Object has optional message block
|
||||||
|
|
||||||
|
completion_response["output"].get("message", None)
|
||||||
|
|
||||||
|
A message block looks like this (Example 1):
|
||||||
|
"output": {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
(Example 2):
|
||||||
|
"output": {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"toolUse": {
|
||||||
|
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
|
||||||
|
"name": "top_song",
|
||||||
|
"input": {
|
||||||
|
"sign": "WZPZ"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
"""
|
||||||
|
message: Optional[MessageBlock] = completion_response["output"]["message"]
|
||||||
|
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
|
||||||
|
content_str = ""
|
||||||
|
tools: List[ChatCompletionToolCallChunk] = []
|
||||||
|
if message is not None:
|
||||||
|
for idx, content in enumerate(message["content"]):
|
||||||
|
"""
|
||||||
|
- Content is either a tool response or text
|
||||||
|
"""
|
||||||
|
if "text" in content:
|
||||||
|
content_str += content["text"]
|
||||||
|
if "toolUse" in content:
|
||||||
|
|
||||||
|
## check tool name was formatted by litellm
|
||||||
|
_response_tool_name = content["toolUse"]["name"]
|
||||||
|
response_tool_name = get_bedrock_tool_name(
|
||||||
|
response_tool_name=_response_tool_name
|
||||||
|
)
|
||||||
|
_function_chunk = ChatCompletionToolCallFunctionChunk(
|
||||||
|
name=response_tool_name,
|
||||||
|
arguments=json.dumps(content["toolUse"]["input"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
_tool_response_chunk = ChatCompletionToolCallChunk(
|
||||||
|
id=content["toolUse"]["toolUseId"],
|
||||||
|
type="function",
|
||||||
|
function=_function_chunk,
|
||||||
|
index=idx,
|
||||||
|
)
|
||||||
|
tools.append(_tool_response_chunk)
|
||||||
|
chat_completion_message["content"] = content_str
|
||||||
|
|
||||||
|
if json_mode is True and tools is not None and len(tools) == 1:
|
||||||
|
# to support 'json_schema' logic on bedrock models
|
||||||
|
json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments")
|
||||||
|
if json_mode_content_str is not None:
|
||||||
|
chat_completion_message["content"] = json_mode_content_str
|
||||||
|
else:
|
||||||
|
chat_completion_message["tool_calls"] = tools
|
||||||
|
|
||||||
|
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||||
|
input_tokens = completion_response["usage"]["inputTokens"]
|
||||||
|
output_tokens = completion_response["usage"]["outputTokens"]
|
||||||
|
total_tokens = completion_response["usage"]["totalTokens"]
|
||||||
|
|
||||||
|
model_response.choices = [
|
||||||
|
litellm.Choices(
|
||||||
|
finish_reason=map_finish_reason(completion_response["stopReason"]),
|
||||||
|
index=0,
|
||||||
|
message=litellm.Message(**chat_completion_message),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
model_response.created = int(time.time())
|
||||||
|
model_response.model = model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=input_tokens,
|
||||||
|
completion_tokens=output_tokens,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
)
|
||||||
|
setattr(model_response, "usage", usage)
|
||||||
|
|
||||||
|
# Add "trace" from Bedrock guardrails - if user has opted in to returning it
|
||||||
|
if "trace" in completion_response:
|
||||||
|
setattr(model_response, "trace", completion_response["trace"])
|
||||||
|
|
||||||
|
return model_response
|
|
@ -52,8 +52,8 @@ from litellm.types.llms.openai import (
|
||||||
from litellm.types.utils import GenericStreamingChunk as GChunk
|
from litellm.types.utils import GenericStreamingChunk as GChunk
|
||||||
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret
|
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret
|
||||||
|
|
||||||
from ..base_aws_llm import BaseAWSLLM
|
from ...base_aws_llm import BaseAWSLLM
|
||||||
from ..prompt_templates.factory import (
|
from ...prompt_templates.factory import (
|
||||||
_bedrock_converse_messages_pt,
|
_bedrock_converse_messages_pt,
|
||||||
_bedrock_tools_pt,
|
_bedrock_tools_pt,
|
||||||
cohere_message_pt,
|
cohere_message_pt,
|
||||||
|
@ -64,7 +64,8 @@ from ..prompt_templates.factory import (
|
||||||
parse_xml_params,
|
parse_xml_params,
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
)
|
)
|
||||||
from .common_utils import BedrockError, ModelResponseIterator, get_runtime_endpoint
|
from ..common_utils import BedrockError, ModelResponseIterator, get_bedrock_tool_name
|
||||||
|
from .converse_transformation import AmazonConverseConfig
|
||||||
|
|
||||||
BEDROCK_CONVERSE_MODELS = [
|
BEDROCK_CONVERSE_MODELS = [
|
||||||
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
@ -225,10 +226,9 @@ async def make_call(
|
||||||
raise BedrockError(status_code=response.status_code, message=response.text)
|
raise BedrockError(status_code=response.status_code, message=response.text)
|
||||||
|
|
||||||
if "ai21" in api_base:
|
if "ai21" in api_base:
|
||||||
aws_bedrock_process_response = BedrockConverseLLM()
|
|
||||||
model_response: (
|
model_response: (
|
||||||
ModelResponse
|
ModelResponse
|
||||||
) = aws_bedrock_process_response.process_response(
|
) = litellm.AmazonConverseConfig()._transform_response(
|
||||||
model=model,
|
model=model,
|
||||||
response=response,
|
response=response,
|
||||||
model_response=litellm.ModelResponse(),
|
model_response=litellm.ModelResponse(),
|
||||||
|
@ -266,59 +266,6 @@ async def make_call(
|
||||||
raise BedrockError(status_code=500, message=str(e))
|
raise BedrockError(status_code=500, message=str(e))
|
||||||
|
|
||||||
|
|
||||||
def make_sync_call(
|
|
||||||
client: Optional[HTTPHandler],
|
|
||||||
api_base: str,
|
|
||||||
headers: dict,
|
|
||||||
data: str,
|
|
||||||
model: str,
|
|
||||||
messages: list,
|
|
||||||
logging_obj,
|
|
||||||
):
|
|
||||||
if client is None:
|
|
||||||
client = _get_httpx_client() # Create a new client if none provided
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
api_base,
|
|
||||||
headers=headers,
|
|
||||||
data=data,
|
|
||||||
stream=True if "ai21" not in api_base else False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise BedrockError(status_code=response.status_code, message=response.read())
|
|
||||||
|
|
||||||
if "ai21" in api_base:
|
|
||||||
aws_bedrock_process_response = BedrockConverseLLM()
|
|
||||||
model_response: ModelResponse = aws_bedrock_process_response.process_response(
|
|
||||||
model=model,
|
|
||||||
response=response,
|
|
||||||
model_response=litellm.ModelResponse(),
|
|
||||||
stream=True,
|
|
||||||
logging_obj=logging_obj,
|
|
||||||
optional_params={},
|
|
||||||
api_key="",
|
|
||||||
data=data,
|
|
||||||
messages=messages,
|
|
||||||
print_verbose=litellm.print_verbose,
|
|
||||||
encoding=litellm.encoding,
|
|
||||||
) # type: ignore
|
|
||||||
completion_stream: Any = MockResponseIterator(model_response=model_response)
|
|
||||||
else:
|
|
||||||
decoder = AWSEventStreamDecoder(model=model)
|
|
||||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
|
||||||
|
|
||||||
# LOGGING
|
|
||||||
logging_obj.post_call(
|
|
||||||
input=messages,
|
|
||||||
api_key="",
|
|
||||||
original_response="first stream response received",
|
|
||||||
additional_args={"complete_input_dict": data},
|
|
||||||
)
|
|
||||||
|
|
||||||
return completion_stream
|
|
||||||
|
|
||||||
|
|
||||||
class BedrockLLM(BaseAWSLLM):
|
class BedrockLLM(BaseAWSLLM):
|
||||||
"""
|
"""
|
||||||
Example call
|
Example call
|
||||||
|
@ -417,6 +364,7 @@ class BedrockLLM(BaseAWSLLM):
|
||||||
except:
|
except:
|
||||||
raise BedrockError(message=response.text, status_code=422)
|
raise BedrockError(message=response.text, status_code=422)
|
||||||
|
|
||||||
|
outputText: Optional[str] = None
|
||||||
try:
|
try:
|
||||||
if provider == "cohere":
|
if provider == "cohere":
|
||||||
if "text" in completion_response:
|
if "text" in completion_response:
|
||||||
|
@ -566,23 +514,27 @@ class BedrockLLM(BaseAWSLLM):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if (
|
if (
|
||||||
len(outputText) > 0
|
outputText is not None
|
||||||
|
and len(outputText) > 0
|
||||||
and hasattr(model_response.choices[0], "message")
|
and hasattr(model_response.choices[0], "message")
|
||||||
and getattr(model_response.choices[0].message, "tool_calls", None)
|
and getattr(model_response.choices[0].message, "tool_calls", None) # type: ignore
|
||||||
is None
|
is None
|
||||||
):
|
):
|
||||||
model_response.choices[0].message.content = outputText
|
model_response.choices[0].message.content = outputText # type: ignore
|
||||||
elif (
|
elif (
|
||||||
hasattr(model_response.choices[0], "message")
|
hasattr(model_response.choices[0], "message")
|
||||||
and getattr(model_response.choices[0].message, "tool_calls", None)
|
and getattr(model_response.choices[0].message, "tool_calls", None) # type: ignore
|
||||||
is not None
|
is not None
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise Exception()
|
raise Exception()
|
||||||
except:
|
except Exception as e:
|
||||||
raise BedrockError(
|
raise BedrockError(
|
||||||
message=json.dumps(outputText), status_code=response.status_code
|
message="Error parsing received text={}.\nError-{}".format(
|
||||||
|
outputText, str(e)
|
||||||
|
),
|
||||||
|
status_code=response.status_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
if stream and provider == "ai21":
|
if stream and provider == "ai21":
|
||||||
|
@ -594,8 +546,8 @@ class BedrockLLM(BaseAWSLLM):
|
||||||
streaming_choice = litellm.utils.StreamingChoices()
|
streaming_choice = litellm.utils.StreamingChoices()
|
||||||
streaming_choice.index = model_response.choices[0].index
|
streaming_choice.index = model_response.choices[0].index
|
||||||
delta_obj = litellm.utils.Delta(
|
delta_obj = litellm.utils.Delta(
|
||||||
content=getattr(model_response.choices[0].message, "content", None),
|
content=getattr(model_response.choices[0].message, "content", None), # type: ignore
|
||||||
role=model_response.choices[0].message.role,
|
role=model_response.choices[0].message.role, # type: ignore
|
||||||
)
|
)
|
||||||
streaming_choice.delta = delta_obj
|
streaming_choice.delta = delta_obj
|
||||||
streaming_model_response.choices = [streaming_choice]
|
streaming_model_response.choices = [streaming_choice]
|
||||||
|
@ -731,7 +683,7 @@ class BedrockLLM(BaseAWSLLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
### SET RUNTIME ENDPOINT ###
|
### SET RUNTIME ENDPOINT ###
|
||||||
endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
|
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
|
@ -1002,7 +954,7 @@ class BedrockLLM(BaseAWSLLM):
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except httpx.HTTPStatusError as err:
|
except httpx.HTTPStatusError as err:
|
||||||
error_code = err.response.status_code
|
error_code = err.response.status_code
|
||||||
raise BedrockError(status_code=error_code, message=response.text)
|
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||||
except httpx.TimeoutException as e:
|
except httpx.TimeoutException as e:
|
||||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
|
@ -1113,725 +1065,6 @@ class BedrockLLM(BaseAWSLLM):
|
||||||
return super().embedding(*args, **kwargs)
|
return super().embedding(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class AmazonConverseConfig:
|
|
||||||
"""
|
|
||||||
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
|
|
||||||
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
|
|
||||||
"""
|
|
||||||
|
|
||||||
maxTokens: Optional[int]
|
|
||||||
stopSequences: Optional[List[str]]
|
|
||||||
temperature: Optional[int]
|
|
||||||
topP: Optional[int]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
maxTokens: Optional[int] = None,
|
|
||||||
stopSequences: Optional[List[str]] = None,
|
|
||||||
temperature: Optional[int] = None,
|
|
||||||
topP: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
locals_ = locals()
|
|
||||||
for key, value in locals_.items():
|
|
||||||
if key != "self" and value is not None:
|
|
||||||
setattr(self.__class__, key, value)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_config(cls):
|
|
||||||
return {
|
|
||||||
k: v
|
|
||||||
for k, v in cls.__dict__.items()
|
|
||||||
if not k.startswith("__")
|
|
||||||
and not isinstance(
|
|
||||||
v,
|
|
||||||
(
|
|
||||||
types.FunctionType,
|
|
||||||
types.BuiltinFunctionType,
|
|
||||||
classmethod,
|
|
||||||
staticmethod,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
and v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
|
||||||
supported_params = [
|
|
||||||
"max_tokens",
|
|
||||||
"max_completion_tokens",
|
|
||||||
"stream",
|
|
||||||
"stream_options",
|
|
||||||
"stop",
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"extra_headers",
|
|
||||||
"response_format",
|
|
||||||
]
|
|
||||||
|
|
||||||
if (
|
|
||||||
model.startswith("anthropic")
|
|
||||||
or model.startswith("mistral")
|
|
||||||
or model.startswith("cohere")
|
|
||||||
or model.startswith("meta.llama3-1")
|
|
||||||
):
|
|
||||||
supported_params.append("tools")
|
|
||||||
|
|
||||||
if model.startswith("anthropic") or model.startswith("mistral"):
|
|
||||||
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
|
||||||
supported_params.append("tool_choice")
|
|
||||||
|
|
||||||
return supported_params
|
|
||||||
|
|
||||||
def map_tool_choice_values(
|
|
||||||
self, model: str, tool_choice: Union[str, dict], drop_params: bool
|
|
||||||
) -> Optional[ToolChoiceValuesBlock]:
|
|
||||||
if tool_choice == "none":
|
|
||||||
if litellm.drop_params is True or drop_params is True:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
raise litellm.utils.UnsupportedParamsError(
|
|
||||||
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
|
|
||||||
tool_choice
|
|
||||||
),
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
elif tool_choice == "required":
|
|
||||||
return ToolChoiceValuesBlock(any={})
|
|
||||||
elif tool_choice == "auto":
|
|
||||||
return ToolChoiceValuesBlock(auto={})
|
|
||||||
elif isinstance(tool_choice, dict):
|
|
||||||
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
|
||||||
specific_tool = SpecificToolChoiceBlock(
|
|
||||||
name=tool_choice.get("function", {}).get("name", "")
|
|
||||||
)
|
|
||||||
return ToolChoiceValuesBlock(tool=specific_tool)
|
|
||||||
else:
|
|
||||||
raise litellm.utils.UnsupportedParamsError(
|
|
||||||
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
|
|
||||||
tool_choice
|
|
||||||
),
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_supported_image_types(self) -> List[str]:
|
|
||||||
return ["png", "jpeg", "gif", "webp"]
|
|
||||||
|
|
||||||
def map_openai_params(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
non_default_params: dict,
|
|
||||||
optional_params: dict,
|
|
||||||
drop_params: bool,
|
|
||||||
) -> dict:
|
|
||||||
for param, value in non_default_params.items():
|
|
||||||
if param == "response_format":
|
|
||||||
json_schema: Optional[dict] = None
|
|
||||||
schema_name: str = ""
|
|
||||||
if "response_schema" in value:
|
|
||||||
json_schema = value["response_schema"]
|
|
||||||
schema_name = "json_tool_call"
|
|
||||||
elif "json_schema" in value:
|
|
||||||
json_schema = value["json_schema"]["schema"]
|
|
||||||
schema_name = value["json_schema"]["name"]
|
|
||||||
"""
|
|
||||||
Follow similar approach to anthropic - translate to a single tool call.
|
|
||||||
|
|
||||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
|
||||||
- You usually want to provide a single tool
|
|
||||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
|
||||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
|
||||||
"""
|
|
||||||
if json_schema is not None:
|
|
||||||
_tool_choice = self.map_tool_choice_values(
|
|
||||||
model=model, tool_choice="required", drop_params=drop_params # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
_tool = ChatCompletionToolParam(
|
|
||||||
type="function",
|
|
||||||
function=ChatCompletionToolParamFunctionChunk(
|
|
||||||
name=schema_name, parameters=json_schema
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
optional_params["tools"] = [_tool]
|
|
||||||
optional_params["tool_choice"] = _tool_choice
|
|
||||||
optional_params["json_mode"] = True
|
|
||||||
else:
|
|
||||||
if litellm.drop_params is True or drop_params is True:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise litellm.utils.UnsupportedParamsError(
|
|
||||||
message="Bedrock doesn't support response_format={}. To drop it from the call, set `litellm.drop_params = True.".format(
|
|
||||||
value
|
|
||||||
),
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
if param == "max_tokens" or param == "max_completion_tokens":
|
|
||||||
optional_params["maxTokens"] = value
|
|
||||||
if param == "stream":
|
|
||||||
optional_params["stream"] = value
|
|
||||||
if param == "stop":
|
|
||||||
if isinstance(value, str):
|
|
||||||
if len(value) == 0: # converse raises error for empty strings
|
|
||||||
continue
|
|
||||||
value = [value]
|
|
||||||
optional_params["stopSequences"] = value
|
|
||||||
if param == "temperature":
|
|
||||||
optional_params["temperature"] = value
|
|
||||||
if param == "top_p":
|
|
||||||
optional_params["topP"] = value
|
|
||||||
if param == "tools":
|
|
||||||
optional_params["tools"] = value
|
|
||||||
if param == "tool_choice":
|
|
||||||
_tool_choice_value = self.map_tool_choice_values(
|
|
||||||
model=model, tool_choice=value, drop_params=drop_params # type: ignore
|
|
||||||
)
|
|
||||||
if _tool_choice_value is not None:
|
|
||||||
optional_params["tool_choice"] = _tool_choice_value
|
|
||||||
return optional_params
|
|
||||||
|
|
||||||
|
|
||||||
class BedrockConverseLLM(BaseAWSLLM):
    def __init__(self) -> None:
        super().__init__()

    def process_response(
        self,
        model: str,
        response: Union[requests.Response, httpx.Response],
        model_response: ModelResponse,
        stream: bool,
        logging_obj: Optional[Logging],
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: List,
        print_verbose,
        encoding,
    ) -> Union[ModelResponse, CustomStreamWrapper]:

        ## LOGGING
        if logging_obj is not None:
            logging_obj.post_call(
                input=messages,
                api_key=api_key,
                original_response=response.text,
                additional_args={"complete_input_dict": data},
            )
        print_verbose(f"raw model_response: {response.text}")
        json_mode: Optional[bool] = optional_params.pop("json_mode", None)
        ## RESPONSE OBJECT
        try:
            completion_response = ConverseResponseBlock(**response.json())  # type: ignore
        except Exception as e:
            raise BedrockError(
                message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
                    response.text, str(e)
                ),
                status_code=422,
            )

        """
        Bedrock Response Object has optional message block

        completion_response["output"].get("message", None)

        A message block looks like this (Example 1):
        "output": {
            "message": {
                "role": "assistant",
                "content": [
                    {
                        "text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
                    }
                ]
            }
        },
        (Example 2):
        "output": {
            "message": {
                "role": "assistant",
                "content": [
                    {
                        "toolUse": {
                            "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
                            "name": "top_song",
                            "input": {
                                "sign": "WZPZ"
                            }
                        }
                    }
                ]
            }
        }

        """
        message: Optional[MessageBlock] = completion_response["output"]["message"]
        chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
        content_str = ""
        tools: List[ChatCompletionToolCallChunk] = []
        if message is not None:
            for idx, content in enumerate(message["content"]):
                """
                - Content is either a tool response or text
                """
                if "text" in content:
                    content_str += content["text"]
                if "toolUse" in content:

                    ## check tool name was formatted by litellm
                    _response_tool_name = content["toolUse"]["name"]
                    response_tool_name = get_bedrock_tool_name(
                        response_tool_name=_response_tool_name
                    )
                    _function_chunk = ChatCompletionToolCallFunctionChunk(
                        name=response_tool_name,
                        arguments=json.dumps(content["toolUse"]["input"]),
                    )

                    _tool_response_chunk = ChatCompletionToolCallChunk(
                        id=content["toolUse"]["toolUseId"],
                        type="function",
                        function=_function_chunk,
                        index=idx,
                    )

                    tools.append(_tool_response_chunk)
            chat_completion_message["content"] = content_str

        if json_mode is True and tools is not None and len(tools) == 1:
            # to support 'json_schema' logic on bedrock models
            json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments")
            if json_mode_content_str is not None:
                chat_completion_message["content"] = json_mode_content_str
        else:
            chat_completion_message["tool_calls"] = tools

        ## CALCULATING USAGE - bedrock returns usage in the headers
        input_tokens = completion_response["usage"]["inputTokens"]
        output_tokens = completion_response["usage"]["outputTokens"]
        total_tokens = completion_response["usage"]["totalTokens"]

        model_response.choices = [
            litellm.Choices(
                finish_reason=map_finish_reason(completion_response["stopReason"]),
                index=0,
                message=litellm.Message(**chat_completion_message),
            )
        ]
        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=total_tokens,
        )
        setattr(model_response, "usage", usage)

        # Add "trace" from Bedrock guardrails - if user has opted in to returning it
        if "trace" in completion_response:
            setattr(model_response, "trace", completion_response["trace"])

        return model_response

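    # Illustration (not part of the diff): for Example 2 in the docstring above (values taken
    # from that example), the tool-use block is expected to surface in the OpenAI-style
    # message roughly as:
    #
    #   {
    #       "role": "assistant",
    #       "content": "",
    #       "tool_calls": [
    #           {
    #               "id": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
    #               "type": "function",
    #               "function": {"name": "top_song", "arguments": '{"sign": "WZPZ"}'},
    #               "index": 0,
    #           }
    #       ],
    #   }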
    def encode_model_id(self, model_id: str) -> str:
        """
        Double encode the model ID to ensure it matches the expected double-encoded format.

        Args:
            model_id (str): The model ID to encode.

        Returns:
            str: The double-encoded model ID.
        """
        return urllib.parse.quote(model_id, safe="")

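    # Illustration (not part of the diff): what the URL-encoding above does to an ARN-style
    # model id; the ARN below is made up.
    #
    #   urllib.parse.quote("arn:aws:bedrock:us-west-2:123456789012:provisioned-model/abc123", safe="")
    #   -> "arn%3Aaws%3Abedrock%3Aus-west-2%3A123456789012%3Aprovisioned-model%2Fabc123"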
    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        data: str,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj,
        stream,
        optional_params: dict,
        litellm_params=None,
        logger_fn=None,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
    ) -> CustomStreamWrapper:
        streaming_response = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_call,
                client=client,
                api_base=api_base,
                headers=headers,
                data=data,
                model=model,
                messages=messages,
                logging_obj=logging_obj,
            ),
            model=model,
            custom_llm_provider="bedrock",
            logging_obj=logging_obj,
        )
        return streaming_response

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        data: str,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj,
        stream,
        optional_params: dict,
        litellm_params=None,
        logger_fn=None,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        if client is None or not isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = get_async_httpx_client(
                params=_params, llm_provider=litellm.LlmProviders.BEDROCK
            )
        else:
            client = client  # type: ignore

        try:
            response = await client.post(url=api_base, headers=headers, data=data)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException as e:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream if isinstance(stream, bool) else False,
            logging_obj=logging_obj,
            api_key="",
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    def completion(
        self,
        model: str,
        messages: list,
        api_base: Optional[str],
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        optional_params: dict,
        acompletion: bool,
        timeout: Optional[Union[float, httpx.Timeout]],
        litellm_params: dict,
        logger_fn=None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
    ):
        try:
            import boto3
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        ## SETUP ##
        stream = optional_params.pop("stream", None)
        modelId = optional_params.pop("model_id", None)
        if modelId is not None:
            modelId = self.encode_model_id(model_id=modelId)
        else:
            modelId = model

        provider = model.split(".")[0]

        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_bedrock_runtime_endpoint = optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)

        ### SET REGION NAME ###
        if aws_region_name is None:
            # check env #
            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)

            if litellm_aws_region_name is not None and isinstance(
                litellm_aws_region_name, str
            ):
                aws_region_name = litellm_aws_region_name

            standard_aws_region_name = get_secret("AWS_REGION", None)
            if standard_aws_region_name is not None and isinstance(
                standard_aws_region_name, str
            ):
                aws_region_name = standard_aws_region_name

            if aws_region_name is None:
                aws_region_name = "us-west-2"

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )

        ### SET RUNTIME ENDPOINT ###
        endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=aws_region_name,
        )
        if (stream is not None and stream is True) and provider != "ai21":
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
        else:
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"

        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)

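        # Illustration (not part of the diff): with the default region and a plain model id,
        # the URL selection above resolves roughly as follows; ai21 models always use the
        # non-streaming /converse route. Values are examples only.
        #
        #   endpoint_url = "https://bedrock-runtime.us-west-2.amazonaws.com"
        #   modelId = "anthropic.claude-3-sonnet-20240229-v1:0"
        #   stream=True, provider != "ai21" -> .../model/<modelId>/converse-stream
        #   otherwise                        -> .../model/<modelId>/converse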
        # Separate system prompt from rest of message
        system_prompt_indices = []
        system_content_blocks: List[SystemContentBlock] = []
        for idx, message in enumerate(messages):
            if message["role"] == "system":
                _system_content_block: Optional[SystemContentBlock] = None
                if isinstance(message["content"], str) and len(message["content"]) > 0:
                    _system_content_block = SystemContentBlock(text=message["content"])
                elif isinstance(message["content"], list):
                    for m in message["content"]:
                        if m.get("type", "") == "text" and len(m["text"]) > 0:
                            _system_content_block = SystemContentBlock(text=m["text"])
                if _system_content_block is not None:
                    system_content_blocks.append(_system_content_block)
                system_prompt_indices.append(idx)
        if len(system_prompt_indices) > 0:
            for idx in reversed(system_prompt_indices):
                messages.pop(idx)

        inference_params = copy.deepcopy(optional_params)
        additional_request_keys = []
        additional_request_params = {}
        supported_converse_params = AmazonConverseConfig.__annotations__.keys()
        supported_tool_call_params = ["tools", "tool_choice"]
        supported_guardrail_params = ["guardrailConfig"]
        json_mode: Optional[bool] = inference_params.pop(
            "json_mode", None
        )  # used for handling json_schema
        ## TRANSFORMATION ##

        bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
            messages=messages,
            model=model,
            llm_provider="bedrock_converse",
            user_continue_message=litellm_params.pop("user_continue_message", None),
        )

        # send all model-specific params in 'additional_request_params'
        for k, v in inference_params.items():
            if (
                k not in supported_converse_params
                and k not in supported_tool_call_params
                and k not in supported_guardrail_params
            ):
                additional_request_params[k] = v
                additional_request_keys.append(k)
        for key in additional_request_keys:
            inference_params.pop(key, None)

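        # Illustration (not part of the diff): after the filtering above, any param that is not
        # a Converse inference/tool/guardrail param is forwarded untouched via
        # additionalModelRequestFields. Assuming a caller passed a provider-specific key such
        # as top_k=5 alongside maxTokens/temperature, the split would look roughly like:
        #   inference_params          -> {"maxTokens": 256, "temperature": 0.1}
        #   additional_request_params -> {"top_k": 5}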
        bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
            inference_params.pop("tools", [])
        )
        bedrock_tool_config: Optional[ToolConfigBlock] = None
        if len(bedrock_tools) > 0:
            tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
                "tool_choice", None
            )
            bedrock_tool_config = ToolConfigBlock(
                tools=bedrock_tools,
            )
            if tool_choice_values is not None:
                bedrock_tool_config["toolChoice"] = tool_choice_values

        _data: RequestObject = {
            "messages": bedrock_messages,
            "additionalModelRequestFields": additional_request_params,
            "system": system_content_blocks,
            "inferenceConfig": InferenceConfig(**inference_params),
        }

        # Guardrail Config
        guardrail_config: Optional[GuardrailConfigBlock] = None
        request_guardrails_config = inference_params.pop("guardrailConfig", None)
        if request_guardrails_config is not None:
            guardrail_config = GuardrailConfigBlock(**request_guardrails_config)
            _data["guardrailConfig"] = guardrail_config

        # Tool Config
        if bedrock_tool_config is not None:
            _data["toolConfig"] = bedrock_tool_config

        data = json.dumps(_data)
        ## COMPLETION CALL

        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}
        request = AWSRequest(
            method="POST", url=endpoint_url, data=data, headers=headers
        )
        sigv4.add_auth(request)
        if (
            extra_headers is not None and "Authorization" in extra_headers
        ):  # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]
        prepped = request.prepare()

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )

        ### ROUTING (ASYNC, STREAMING, SYNC)
        if acompletion:
            if isinstance(client, HTTPHandler):
                client = None
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=proxy_endpoint_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=True,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=prepped.headers,
                    timeout=timeout,
                    client=client,
                )  # type: ignore
            ### ASYNC COMPLETION
            return self.async_completion(
                model=model,
                messages=messages,
                data=data,
                api_base=proxy_endpoint_url,
                model_response=model_response,
                print_verbose=print_verbose,
                encoding=encoding,
                logging_obj=logging_obj,
                optional_params=optional_params,
                stream=stream,  # type: ignore
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                headers=prepped.headers,
                timeout=timeout,
                client=client,
            )  # type: ignore

        if stream is not None and stream is True:

            streaming_response = CustomStreamWrapper(
                completion_stream=None,
                make_call=partial(
                    make_sync_call,
                    client=None,
                    api_base=proxy_endpoint_url,
                    headers=prepped.headers,  # type: ignore
                    data=data,
                    model=model,
                    messages=messages,
                    logging_obj=logging_obj,
                ),
                model=model,
                custom_llm_provider="bedrock",
                logging_obj=logging_obj,
            )

            return streaming_response
        ### COMPLETION

        if client is None or isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = _get_httpx_client(_params)  # type: ignore
        else:
            client = client
        try:
            response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream if isinstance(stream, bool) else False,
            logging_obj=logging_obj,
            optional_params=optional_params,
            api_key="",
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
        )

def get_response_stream_shape():
    global _response_stream_shape_cache
    if _response_stream_shape_cache is None:
@ -1847,24 +1080,6 @@ def get_response_stream_shape():
        return _response_stream_shape_cache

def get_bedrock_tool_name(response_tool_name: str) -> str:
    """
    If litellm formatted the input tool name, we need to convert it back to the original name.

    Args:
        response_tool_name (str): The name of the tool as received from the response.

    Returns:
        str: The original name of the tool.
    """
    if response_tool_name in litellm.bedrock_tool_name_mappings.cache_dict:
        response_tool_name = litellm.bedrock_tool_name_mappings.cache_dict[
            response_tool_name
        ]
    return response_tool_name

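# Illustration (not part of the diff): if litellm previously rewrote a tool name into a
# Bedrock-safe form, the in-memory mapping restores the caller's original name; unmapped
# names pass through unchanged. The names below are made up.
#
#   litellm.bedrock_tool_name_mappings.cache_dict == {"get_current_weather_123": "get-current-weather!"}
#   get_bedrock_tool_name("get_current_weather_123")  -> "get-current-weather!"
#   get_bedrock_tool_name("top_song")                  -> "top_song"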
class AWSEventStreamDecoder:
    def __init__(self, model: str) -> None:
        from botocore.parsers import EventStreamJSONParser

@ -583,7 +583,7 @@ def init_bedrock_client(
    # Iterate over parameters and update if needed
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)
            params_to_check[i] = get_secret(param)  # type: ignore
    # Assign updated values back to parameters
    (
        aws_access_key_id,

@ -626,13 +626,13 @@ def init_bedrock_client(
    import boto3

    if isinstance(timeout, float):
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)  # type: ignore
    elif isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(
        config = boto3.session.Config(  # type: ignore
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    else:
        config = boto3.session.Config()
        config = boto3.session.Config()  # type: ignore

    ### CHECK STS ###
    if (

@ -733,40 +733,6 @@ def init_bedrock_client(
    return client

def get_runtime_endpoint(
    api_base: Optional[str],
    aws_bedrock_runtime_endpoint: Optional[str],
    aws_region_name: str,
) -> Tuple[str, str]:
    env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
    if api_base is not None:
        endpoint_url = api_base
    elif aws_bedrock_runtime_endpoint is not None and isinstance(
        aws_bedrock_runtime_endpoint, str
    ):
        endpoint_url = aws_bedrock_runtime_endpoint
    elif env_aws_bedrock_runtime_endpoint and isinstance(
        env_aws_bedrock_runtime_endpoint, str
    ):
        endpoint_url = env_aws_bedrock_runtime_endpoint
    else:
        endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

    # Determine proxy_endpoint_url
    if env_aws_bedrock_runtime_endpoint and isinstance(
        env_aws_bedrock_runtime_endpoint, str
    ):
        proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
    elif aws_bedrock_runtime_endpoint is not None and isinstance(
        aws_bedrock_runtime_endpoint, str
    ):
        proxy_endpoint_url = aws_bedrock_runtime_endpoint
    else:
        proxy_endpoint_url = endpoint_url

    return endpoint_url, proxy_endpoint_url

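# Illustration (not part of the diff): endpoint resolution when no api_base, no explicit
# runtime endpoint, and no AWS_BEDROCK_RUNTIME_ENDPOINT are set - both URLs fall back to the
# regional default. The region value is just an example.
#
#   get_runtime_endpoint(api_base=None, aws_bedrock_runtime_endpoint=None, aws_region_name="us-east-1")
#   -> ("https://bedrock-runtime.us-east-1.amazonaws.com",
#       "https://bedrock-runtime.us-east-1.amazonaws.com")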
class ModelResponseIterator:
    def __init__(self, model_response):
        self.model_response = model_response

@ -791,3 +757,21 @@ class ModelResponseIterator:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response


def get_bedrock_tool_name(response_tool_name: str) -> str:
    """
    If litellm formatted the input tool name, we need to convert it back to the original name.

    Args:
        response_tool_name (str): The name of the tool as received from the response.

    Returns:
        str: The original name of the tool.
    """
    if response_tool_name in litellm.bedrock_tool_name_mappings.cache_dict:
        response_tool_name = litellm.bedrock_tool_name_mappings.cache_dict[
            response_tool_name
        ]
    return response_tool_name

@ -23,7 +23,7 @@ from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRe
from litellm.types.utils import Embedding, EmbeddingResponse, Usage

from ...base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError, get_runtime_endpoint
from ..common_utils import BedrockError
from .amazon_titan_g1_transformation import AmazonTitanG1Config
from .amazon_titan_multimodal_transformation import (
    AmazonTitanMultimodalEmbeddingG1Config,

@ -141,7 +141,7 @@ class BedrockEmbedding(BaseAWSLLM):
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=response.text)
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

@ -197,7 +197,7 @@ class BedrockEmbedding(BaseAWSLLM):
            client=client,
            timeout=timeout,
            api_base=prepped.url,
            headers=prepped.headers,
            headers=prepped.headers,  # type: ignore
            data=data,
        )

@ -288,7 +288,7 @@ class BedrockEmbedding(BaseAWSLLM):
            client=client,
            timeout=timeout,
            api_base=prepped.url,
            headers=prepped.headers,
            headers=prepped.headers,  # type: ignore
            data=data,
        )

@ -342,8 +342,8 @@ class BedrockEmbedding(BaseAWSLLM):
        timeout: Optional[Union[float, httpx.Timeout]],
        aembedding: Optional[bool],
        extra_headers: Optional[dict],
        optional_params=None,
        optional_params: dict,
        litellm_params=None,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        try:
            import boto3

@ -392,10 +392,21 @@ class BedrockEmbedding(BaseAWSLLM):
                transformed_request = AmazonTitanV2Config()._transform_request(
                    input=i, inference_params=inference_params
                )
            else:
                raise Exception(
                    "Unmapped model. Received={}. Expected={}".format(
                        model,
                        [
                            "amazon.titan-embed-image-v1",
                            "amazon.titan-embed-text-v1",
                            "amazon.titan-embed-text-v2:0",
                        ],
                    )
                )
            batch_data.append(transformed_request)

        ### SET RUNTIME ENDPOINT ###
        endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=optional_params.pop(
                "aws_bedrock_runtime_endpoint", None

@ -443,6 +454,7 @@ class BedrockEmbedding(BaseAWSLLM):
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        request = AWSRequest(
            method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
        )

@ -467,170 +479,5 @@ class BedrockEmbedding(BaseAWSLLM):
            aembedding=aembedding,
            timeout=timeout,
            client=client,
            headers=prepped.headers,
            headers=prepped.headers,  # type: ignore
        )

# def _embedding_func_single(
|
|
||||||
# model: str,
|
|
||||||
# input: str,
|
|
||||||
# client: Any,
|
|
||||||
# optional_params=None,
|
|
||||||
# encoding=None,
|
|
||||||
# logging_obj=None,
|
|
||||||
# ):
|
|
||||||
# if isinstance(input, str) is False:
|
|
||||||
# raise BedrockError(
|
|
||||||
# message="Bedrock Embedding API input must be type str | List[str]",
|
|
||||||
# status_code=400,
|
|
||||||
# )
|
|
||||||
# # logic for parsing in - calling - parsing out model embedding calls
|
|
||||||
# ## FORMAT EMBEDDING INPUT ##
|
|
||||||
# provider = model.split(".")[0]
|
|
||||||
# inference_params = copy.deepcopy(optional_params)
|
|
||||||
# inference_params.pop(
|
|
||||||
# "user", None
|
|
||||||
# ) # make sure user is not passed in for bedrock call
|
|
||||||
# modelId = (
|
|
||||||
# optional_params.pop("model_id", None) or model
|
|
||||||
# ) # default to model if not passed
|
|
||||||
# if provider == "amazon":
|
|
||||||
# input = input.replace(os.linesep, " ")
|
|
||||||
# data = {"inputText": input, **inference_params}
|
|
||||||
# # data = json.dumps(data)
|
|
||||||
# elif provider == "cohere":
|
|
||||||
# inference_params["input_type"] = inference_params.get(
|
|
||||||
# "input_type", "search_document"
|
|
||||||
# ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
|
|
||||||
# data = {"texts": [input], **inference_params} # type: ignore
|
|
||||||
# body = json.dumps(data).encode("utf-8") # type: ignore
|
|
||||||
# ## LOGGING
|
|
||||||
# request_str = f"""
|
|
||||||
# response = client.invoke_model(
|
|
||||||
# body={body},
|
|
||||||
# modelId={modelId},
|
|
||||||
# accept="*/*",
|
|
||||||
# contentType="application/json",
|
|
||||||
# )""" # type: ignore
|
|
||||||
# logging_obj.pre_call(
|
|
||||||
# input=input,
|
|
||||||
# api_key="", # boto3 is used for init.
|
|
||||||
# additional_args={
|
|
||||||
# "complete_input_dict": {"model": modelId, "texts": input},
|
|
||||||
# "request_str": request_str,
|
|
||||||
# },
|
|
||||||
# )
|
|
||||||
# try:
|
|
||||||
# response = client.invoke_model(
|
|
||||||
# body=body,
|
|
||||||
# modelId=modelId,
|
|
||||||
# accept="*/*",
|
|
||||||
# contentType="application/json",
|
|
||||||
# )
|
|
||||||
# response_body = json.loads(response.get("body").read())
|
|
||||||
# ## LOGGING
|
|
||||||
# logging_obj.post_call(
|
|
||||||
# input=input,
|
|
||||||
# api_key="",
|
|
||||||
# additional_args={"complete_input_dict": data},
|
|
||||||
# original_response=json.dumps(response_body),
|
|
||||||
# )
|
|
||||||
# if provider == "cohere":
|
|
||||||
# response = response_body.get("embeddings")
|
|
||||||
# # flatten list
|
|
||||||
# response = [item for sublist in response for item in sublist]
|
|
||||||
# return response
|
|
||||||
# elif provider == "amazon":
|
|
||||||
# return response_body.get("embedding")
|
|
||||||
# except Exception as e:
|
|
||||||
# raise BedrockError(
|
|
||||||
# message=f"Embedding Error with model {model}: {e}", status_code=500
|
|
||||||
# )
|
|
||||||
|
|
||||||
# def embedding(
|
|
||||||
# model: str,
|
|
||||||
# input: Union[list, str],
|
|
||||||
# model_response: litellm.EmbeddingResponse,
|
|
||||||
# api_key: Optional[str] = None,
|
|
||||||
# logging_obj=None,
|
|
||||||
# optional_params=None,
|
|
||||||
# encoding=None,
|
|
||||||
# ):
|
|
||||||
# ### BOTO3 INIT ###
|
|
||||||
# # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
|
||||||
# aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
|
||||||
# aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
|
||||||
# aws_region_name = optional_params.pop("aws_region_name", None)
|
|
||||||
# aws_role_name = optional_params.pop("aws_role_name", None)
|
|
||||||
# aws_session_name = optional_params.pop("aws_session_name", None)
|
|
||||||
# aws_bedrock_runtime_endpoint = optional_params.pop(
|
|
||||||
# "aws_bedrock_runtime_endpoint", None
|
|
||||||
# )
|
|
||||||
# aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
|
||||||
|
|
||||||
# # use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
|
||||||
# client = init_bedrock_client(
|
|
||||||
# aws_access_key_id=aws_access_key_id,
|
|
||||||
# aws_secret_access_key=aws_secret_access_key,
|
|
||||||
# aws_region_name=aws_region_name,
|
|
||||||
# aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
|
||||||
# aws_web_identity_token=aws_web_identity_token,
|
|
||||||
# aws_role_name=aws_role_name,
|
|
||||||
# aws_session_name=aws_session_name,
|
|
||||||
# )
|
|
||||||
# if isinstance(input, str):
|
|
||||||
# ## Embedding Call
|
|
||||||
# embeddings = [
|
|
||||||
# _embedding_func_single(
|
|
||||||
# model,
|
|
||||||
# input,
|
|
||||||
# optional_params=optional_params,
|
|
||||||
# client=client,
|
|
||||||
# logging_obj=logging_obj,
|
|
||||||
# )
|
|
||||||
# ]
|
|
||||||
# elif isinstance(input, list):
|
|
||||||
# ## Embedding Call - assuming this is a List[str]
|
|
||||||
# embeddings = [
|
|
||||||
# _embedding_func_single(
|
|
||||||
# model,
|
|
||||||
# i,
|
|
||||||
# optional_params=optional_params,
|
|
||||||
# client=client,
|
|
||||||
# logging_obj=logging_obj,
|
|
||||||
# )
|
|
||||||
# for i in input
|
|
||||||
# ] # [TODO]: make these parallel calls
|
|
||||||
# else:
|
|
||||||
# # enters this branch if input = int, ex. input=2
|
|
||||||
# raise BedrockError(
|
|
||||||
# message="Bedrock Embedding API input must be type str | List[str]",
|
|
||||||
# status_code=400,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# ## Populate OpenAI compliant dictionary
|
|
||||||
# embedding_response = []
|
|
||||||
# for idx, embedding in enumerate(embeddings):
|
|
||||||
# embedding_response.append(
|
|
||||||
# {
|
|
||||||
# "object": "embedding",
|
|
||||||
# "index": idx,
|
|
||||||
# "embedding": embedding,
|
|
||||||
# }
|
|
||||||
# )
|
|
||||||
# model_response.object = "list"
|
|
||||||
# model_response.data = embedding_response
|
|
||||||
# model_response.model = model
|
|
||||||
# input_tokens = 0
|
|
||||||
|
|
||||||
# input_str = "".join(input)
|
|
||||||
|
|
||||||
# input_tokens += len(encoding.encode(input_str))
|
|
||||||
|
|
||||||
# usage = Usage(
|
|
||||||
# prompt_tokens=input_tokens,
|
|
||||||
# completion_tokens=0,
|
|
||||||
# total_tokens=input_tokens + 0,
|
|
||||||
# )
|
|
||||||
# model_response.usage = usage
|
|
||||||
|
|
||||||
# return model_response
|
|
||||||
|
|
|
@ -2385,6 +2385,7 @@ def completion(
            )

        if model in litellm.BEDROCK_CONVERSE_MODELS:

            response = bedrock_converse_chat_completion.completion(
                model=model,
                messages=messages,

@ -3570,7 +3571,7 @@ def embedding(
            client=client,
            timeout=timeout,
            aembedding=aembedding,
            litellm_params=litellm_params,
            litellm_params={},
            api_base=api_base,
            print_verbose=print_verbose,
            extra_headers=extra_headers,

@ -1,7 +1,20 @@
model_list:
  - model_name: gpt-3.5-turbo
  - model_name: fake-claude-endpoint
    litellm_params:
      model: anthropic.claude-3-sonnet-20240229-v1:0
      api_base: https://exampleopenaiendpoint-production.up.railway.app
# aws_session_token: "IQoJb3JpZ2luX2VjELj//////////wEaCXVzLXdlc3QtMiJHMEUCIQDatCRVkIZERLcrR6P7Qd1vNfZ8r8xB/LUeaVaTW/lBTwIgAgmHSBe41d65GVRKSkpgVonjsCmOmAS7s/yklM9NsZcq3AEI4P//////////ARABGgw4ODg2MDIyMjM0MjgiDJrio0/CHYEfyt5EqyqwAfyWO4t3bFVWAOIwTyZ1N6lszeJKfMNus2hzVc+r73hia2Anv88uwPxNg2uqnXQNJumEo0DcBt30ZwOw03Isboy0d5l05h8gjb4nl9feyeKmKAnRdcqElrEWtCC1Qcefv78jQv53AbUipH1ssa5NPvptqZZpZYDPMlBEnV3YdvJJiuE23u2yOkCt+EoUJLaOYjOryoRyrSfbWB+JaUsB68R3rNTHzReeN3Nob/9Ic4HrMMmzmLcGOpgBZxclO4w8Z7i6TcVqbCwDOskxuR6bZaiFxKFG+9tDrWS7jaQKpq/YP9HUT0YwYpZplaBEEZR5sbIndg5yb4dRZrSHplblqKz8XLaUf5tuuyRJmwr96PTpw/dyEVk9gicFX6JfLBEv0v5rN2Z0JMFLdfIP4kC1U2PjcPOWoglWO3fLmJ4Lol2a3c5XDSMwMxjcJXq+c8Ue1v0="
      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
  - model_name: gemini-vision
    litellm_params:
      model: vertex_ai/gemini-1.0-pro-vision-001
      api_base: https://exampleopenaiendpoint-production.up.railway.app/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro-vision-001
      vertex_project: "adroit-crow-413218"
      vertex_location: "us-central1"

  - model_name: fake-openai-endpoint
    litellm_params:
      model: gpt-3.5-turbo
      api_base: https://exampleopenaiendpoint-production.up.railway.app
router_settings:
  model_group_alias: {"gpt-4": {"model": "gpt-3.5-turbo", "hidden": false}}

@ -80,7 +80,13 @@ async def gemini_proxy_route(
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    gemini_api_key = litellm.utils.get_secret(secret_name="GEMINI_API_KEY")
    gemini_api_key: Optional[str] = litellm.utils.get_secret(  # type: ignore
        secret_name="GEMINI_API_KEY"
    )
    if gemini_api_key is None:
        raise Exception(
            "Required 'GEMINI_API_KEY' in environment to make pass-through calls to Google AI Studio."
        )
    # Merge query parameters, giving precedence to those in updated_url
    merged_params = dict(request.query_params)
    merged_params.update({"key": gemini_api_key})

@ -99,8 +105,8 @@ async def gemini_proxy_route(
        request,
        fastapi_response,
        user_api_key_dict,
        query_params=merged_params,
        query_params=merged_params,  # type: ignore
        stream=is_streaming_request,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value

@ -142,7 +148,7 @@ async def cohere_proxy_route(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value

@ -208,15 +214,15 @@ async def bedrock_proxy_route(
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(prepped.url),
        custom_headers=prepped.headers,
        custom_headers=prepped.headers,  # type: ignore
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,
        stream=is_streaming_request,  # type: ignore
        custom_body=data,
        custom_body=data,  # type: ignore
        query_params={},
        query_params={},  # type: ignore
    )

    return received_value

@ -27,7 +27,7 @@ from litellm import (
    completion_cost,
    embedding,
)
from litellm.llms.bedrock.chat import BedrockLLM, ToolBlock
from litellm.llms.bedrock.chat import BedrockLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import _bedrock_tools_pt

@ -1287,3 +1287,41 @@ def test_bedrock_converse_translation_tool_message():
            ],
        }
    ]


def test_base_aws_llm_get_credentials():
    import time

    import boto3

    from litellm.llms.base_aws_llm import BaseAWSLLM

    start_time = time.time()
    session = boto3.Session(
        aws_access_key_id="test",
        aws_secret_access_key="test2",
        region_name="test3",
    )
    credentials = session.get_credentials().get_frozen_credentials()
    end_time = time.time()

    print(
        "Total time for credentials - {}. Credentials - {}".format(
            end_time - start_time, credentials
        )
    )

    start_time = time.time()
    credentials = BaseAWSLLM().get_credentials(
        aws_access_key_id="test",
        aws_secret_access_key="test2",
        aws_region_name="test3",
    )

    end_time = time.time()

    print(
        "Total time for credentials - {}. Credentials - {}".format(
            end_time - start_time, credentials.get_frozen_credentials()
        )
    )

@ -1454,7 +1454,7 @@ async def test_bedrock_httpx_streaming(sync_mode, model, region):
                has_finish_reason = True
                break
            complete_response += chunk
        if has_finish_reason == False:
        if has_finish_reason is False:
            raise Exception("finish reason not set")
        if complete_response.strip() == "":
            raise Exception("Empty response received")

@ -8159,9 +8159,7 @@ def exception_type(
            exception_mapping_worked = True
            if hasattr(original_exception, "request"):
                raise APIConnectionError(
                    message="{}\n{}".format(
                    message="{} - {}".format(exception_provider, error_str),
                        str(original_exception), traceback.format_exc()
                    ),
                    llm_provider=custom_llm_provider,
                    model=model,
                    request=original_exception.request,