forked from phoenix/litellm-mirror
feat(bedrock_httpx.py): add support for bedrock converse api
closes https://github.com/BerriAI/litellm/issues/4000
This commit is contained in:
parent
0d3e52373c
commit
a76a9b7d11
6 changed files with 846 additions and 31 deletions
|
@ -765,7 +765,7 @@ from .llms.sagemaker import SagemakerConfig
|
||||||
from .llms.ollama import OllamaConfig
|
from .llms.ollama import OllamaConfig
|
||||||
from .llms.ollama_chat import OllamaChatConfig
|
from .llms.ollama_chat import OllamaChatConfig
|
||||||
from .llms.maritalk import MaritTalkConfig
|
from .llms.maritalk import MaritTalkConfig
|
||||||
from .llms.bedrock_httpx import AmazonCohereChatConfig
|
from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
|
||||||
from .llms.bedrock import (
|
from .llms.bedrock import (
|
||||||
AmazonTitanConfig,
|
AmazonTitanConfig,
|
||||||
AmazonAI21Config,
|
AmazonAI21Config,
|
||||||
|
|
|
@ -38,6 +38,8 @@ from .prompt_templates.factory import (
|
||||||
extract_between_tags,
|
extract_between_tags,
|
||||||
parse_xml_params,
|
parse_xml_params,
|
||||||
contains_tag,
|
contains_tag,
|
||||||
|
_bedrock_converse_messages_pt,
|
||||||
|
_bedrock_tools_pt,
|
||||||
)
|
)
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
from .base import BaseLLM
|
from .base import BaseLLM
|
||||||
|
@ -118,6 +120,8 @@ class AmazonCohereChatConfig:
|
||||||
"presence_penalty",
|
"presence_penalty",
|
||||||
"seed",
|
"seed",
|
||||||
"stop",
|
"stop",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
]
|
]
|
||||||
|
|
||||||
def map_openai_params(
|
def map_openai_params(
|
||||||
|
@ -1069,6 +1073,384 @@ class BedrockLLM(BaseLLM):
|
||||||
return super().embedding(*args, **kwargs)
|
return super().embedding(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AmazonConverseConfig:
    """
    Inference parameters accepted by the Amazon Bedrock Converse API.

    Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html

    Values passed to ``__init__`` are stored on the class itself (not on the
    instance), so they are reported by ``get_config()`` for every later caller.
    """

    maxTokens: Optional[int]
    stopSequences: Optional[List[str]]
    temperature: Optional[float]
    topP: Optional[float]

    def __init__(
        self,
        maxTokens: Optional[int] = None,
        stopSequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        topP: Optional[float] = None,
    ) -> None:
        """
        Record every non-None argument as a class-level default.

        ``top_p`` is kept for backward compatibility; it is stored under the
        Converse field name ``topP`` so that it matches the declared attribute.
        An explicit ``topP`` wins over ``top_p``.
        """
        if topP is None:
            topP = top_p
        overrides = {
            "maxTokens": maxTokens,
            "stopSequences": stopSequences,
            "temperature": temperature,
            "topP": topP,
        }
        for key, value in overrides.items():
            if value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return the class-level settings that have been given a non-None value."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self) -> List[str]:
        """List the OpenAI-style parameter names this config advertises."""
        return [
            "max_tokens",
            "stream",
            "stream_options",
            "stop",
            "temperature",
            "top_p",
            "tools",
            "tool_choice",
        ]

    def get_supported_image_types(self) -> List[str]:
        """List the image formats accepted for image content blocks."""
        return ["png", "jpeg", "gif", "webp"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        """
        Translate OpenAI-style params into Converse field names.

        Args:
            non_default_params: OpenAI-style params supplied by the caller.
            optional_params: dict that receives the translated params; it is
                updated in place and also returned.

        Returns:
            The updated ``optional_params``.

        NOTE: "tools", "tool_choice" and "stream_options" are advertised by
        ``get_supported_openai_params`` but are not translated here.
        """
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["maxTokens"] = value
            elif param == "stream":
                optional_params["stream"] = value
            elif param == "stop":
                # Converse expects a list under the camelCase key `stopSequences`.
                optional_params["stopSequences"] = (
                    [value] if isinstance(value, str) else value
                )
            elif param == "temperature":
                optional_params["temperature"] = value
            elif param == "top_p":
                optional_params["topP"] = value
        return optional_params
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockConverseLLM(BaseLLM):
    """Calls the Amazon Bedrock Converse API with SigV4-signed HTTP requests."""

    def __init__(self) -> None:
        super().__init__()

    def encode_model_id(self, model_id: str) -> str:
        """
        Percent-encode the model ID so it can be embedded in the request path.

        No character is treated as safe, so "/" and ":" are encoded as well.

        Args:
            model_id (str): The model ID to encode.
        Returns:
            str: The percent-encoded model ID.
        """
        return urllib.parse.quote(model_id, safe="")

    def get_credentials(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_session_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        aws_role_name: Optional[str] = None,
        aws_web_identity_token: Optional[str] = None,
    ):
        """
        Resolve the AWS credentials used to sign the request.

        Any argument given as "os.environ/<NAME>" is first resolved through
        `get_secret`. Credentials are then obtained, in order of precedence:

        1. web identity token + role + session name -> STS AssumeRoleWithWebIdentity
        2. role + session name -> STS AssumeRole
        3. profile name -> boto3 session for that profile
        4. otherwise -> boto3 session built from the given keys / region

        Returns:
            A credentials object from boto3 / botocore.

        Raises:
            BedrockError: (401) when the OIDC token cannot be retrieved.
        """
        import boto3

        def _resolve(param: Optional[str]) -> Optional[str]:
            # Swap an 'os.environ/...' reference for its secret value when one
            # is available; otherwise keep the value that was passed in.
            if param and param.startswith("os.environ/"):
                _v = get_secret(param)
                if _v is not None and isinstance(_v, str):
                    return _v
            return param

        aws_access_key_id = _resolve(aws_access_key_id)
        aws_secret_access_key = _resolve(aws_secret_access_key)
        aws_region_name = _resolve(aws_region_name)
        aws_session_name = _resolve(aws_session_name)
        aws_profile_name = _resolve(aws_profile_name)
        aws_role_name = _resolve(aws_role_name)
        aws_web_identity_token = _resolve(aws_web_identity_token)

        ### CHECK STS ###
        if (
            aws_web_identity_token is not None
            and aws_role_name is not None
            and aws_session_name is not None
        ):
            oidc_token = get_secret(aws_web_identity_token)

            if oidc_token is None:
                raise BedrockError(
                    message="OIDC token could not be retrieved from secret manager.",
                    status_code=401,
                )

            sts_client = boto3.client("sts")

            # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
            # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
            sts_response = sts_client.assume_role_with_web_identity(
                RoleArn=aws_role_name,
                RoleSessionName=aws_session_name,
                WebIdentityToken=oidc_token,
                DurationSeconds=3600,
            )

            session = boto3.Session(
                aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
                aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
                aws_session_token=sts_response["Credentials"]["SessionToken"],
                region_name=aws_region_name,
            )
            return session.get_credentials()

        if aws_role_name is not None and aws_session_name is not None:
            sts_client = boto3.client(
                "sts",
                aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
                aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
            )
            sts_response = sts_client.assume_role(
                RoleArn=aws_role_name, RoleSessionName=aws_session_name
            )

            # Extract the credentials from the response and convert to Session Credentials
            sts_credentials = sts_response["Credentials"]
            from botocore.credentials import Credentials

            return Credentials(
                access_key=sts_credentials["AccessKeyId"],
                secret_key=sts_credentials["SecretAccessKey"],
                token=sts_credentials["SessionToken"],
            )

        if aws_profile_name is not None:  ### CHECK SESSION ###
            # uses auth values from AWS profile usually stored in ~/.aws/credentials
            return boto3.Session(profile_name=aws_profile_name).get_credentials()

        session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=aws_region_name,
        )
        return session.get_credentials()

    def completion(
        self,
        model: str,
        messages: list,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        optional_params: dict,
        acompletion: bool,
        timeout: Optional[Union[float, httpx.Timeout]],
        litellm_params=None,
        logger_fn=None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
    ):
        """
        Build, sign and send a Bedrock Converse request.

        AWS settings (credentials, region, endpoint, model_id) and `stream` are
        popped from `optional_params`. System messages are sent in the `system`
        field, Converse inference params in `inferenceConfig`, and every other
        remaining param in `additionalModelRequestFields`. The `messages` list
        passed in is not modified.

        Raises:
            ImportError: boto3 / botocore is not installed.
            BedrockError: the API returned an error status, or the call timed
                out (408).

        NOTE(review): only the synchronous request is issued here. `acompletion`
        and streaming responses are not handled, and the HTTP response is not
        yet converted into `model_response`; nothing is returned.
        """
        try:
            import boto3  # noqa: F401 - fail fast when boto3 is missing

            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError as e:
            raise ImportError(
                "Missing boto3 to call bedrock. Run 'pip install boto3'."
            ) from e

        import json

        ## SETUP ##
        stream = optional_params.pop("stream", None)
        modelId = optional_params.pop("model_id", None)
        if modelId is not None:
            modelId = self.encode_model_id(model_id=modelId)
        else:
            modelId = model

        provider = model.split(".")[0]

        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_bedrock_runtime_endpoint = optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)

        ### SET REGION NAME ###
        if aws_region_name is None:
            # check env #
            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
            if litellm_aws_region_name is not None and isinstance(
                litellm_aws_region_name, str
            ):
                aws_region_name = litellm_aws_region_name

            # AWS_REGION is checked second, so it wins when both are set.
            standard_aws_region_name = get_secret("AWS_REGION", None)
            if standard_aws_region_name is not None and isinstance(
                standard_aws_region_name, str
            ):
                aws_region_name = standard_aws_region_name

            if aws_region_name is None:
                aws_region_name = "us-west-2"

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
        )

        ### SET RUNTIME ENDPOINT ###
        env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
        if aws_bedrock_runtime_endpoint is not None and isinstance(
            aws_bedrock_runtime_endpoint, str
        ):
            endpoint_url = aws_bedrock_runtime_endpoint
        elif env_aws_bedrock_runtime_endpoint and isinstance(
            env_aws_bedrock_runtime_endpoint, str
        ):
            endpoint_url = env_aws_bedrock_runtime_endpoint
        else:
            endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

        if stream is True and provider != "ai21":
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
        else:
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse"

        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)

        # Separate system prompt from rest of message. A new list is built so
        # the caller's `messages` is left untouched.
        system_content_blocks: List[SystemContentBlock] = []
        chat_messages: list = []
        for message in messages:
            if message["role"] == "system":
                system_content_blocks.append(
                    SystemContentBlock(text=message["content"])
                )
            else:
                chat_messages.append(message)

        ## TRANSFORMATION ##
        inference_params = copy.deepcopy(optional_params)
        # The declared fields of AmazonConverseConfig are the Converse inference
        # params; get_config() only lists fields that were given a value.
        supported_converse_params = AmazonConverseConfig.__annotations__.keys()

        # Tool params are consumed below; they must not be forwarded as
        # model-specific fields.
        openai_tools = inference_params.pop("tools", None) or []
        openai_tool_choice = inference_params.pop("tool_choice", None)

        # send all model-specific params in 'additional_request_params'
        additional_request_params = {
            k: v
            for k, v in inference_params.items()
            if k not in supported_converse_params
        }
        inference_config = {
            k: v for k, v in inference_params.items() if k in supported_converse_params
        }

        bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
            messages=chat_messages
        )
        bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(openai_tools)
        bedrock_tool_config: Optional[ToolConfigBlock] = None
        if len(bedrock_tools) > 0:
            bedrock_tool_config = ToolConfigBlock(
                tools=bedrock_tools,
                toolChoice=openai_tool_choice,
            )

        data: dict = {
            "messages": bedrock_messages,
            "additionalModelRequestFields": additional_request_params,
            "system": system_content_blocks,
            "inferenceConfig": inference_config,
        }
        if bedrock_tool_config is not None:
            data["toolConfig"] = bedrock_tool_config

        # The signature is computed over the exact bytes sent, so serialize once
        # and use the same JSON string for signing and for the request.
        request_body = json.dumps(data)

        ## COMPLETION CALL
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}
        request = AWSRequest(
            method="POST", url=endpoint_url, data=request_body, headers=headers
        )
        sigv4.add_auth(request)
        prepped = request.prepare()

        ## LOGGING
        logging_obj.pre_call(
            input=chat_messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": prepped.url,
                "headers": prepped.headers,
            },
        )

        ### ROUTING (ASYNC, STREAMING, SYNC)
        if client is None or isinstance(client, AsyncHTTPHandler):
            # A synchronous call needs a synchronous handler.
            _params = {}
            if timeout is not None:
                if isinstance(timeout, (float, int)):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            http_client = HTTPHandler(**_params)  # type: ignore
        else:
            http_client = client

        try:
            response = http_client.post(url=prepped.url, headers=prepped.headers, data=request_body)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            raise BedrockError(
                status_code=err.response.status_code, message=err.response.text
            )
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
|
|
||||||
def get_response_stream_shape():
|
def get_response_stream_shape():
|
||||||
from botocore.model import ServiceModel
|
from botocore.model import ServiceModel
|
||||||
from botocore.loaders import Loader
|
from botocore.loaders import Loader
|
||||||
|
|
|
@ -3,14 +3,7 @@ import requests, traceback
|
||||||
import json, re, xml.etree.ElementTree as ET
|
import json, re, xml.etree.ElementTree as ET
|
||||||
from jinja2 import Template, exceptions, meta, BaseLoader
|
from jinja2 import Template, exceptions, meta, BaseLoader
|
||||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||||
from typing import (
|
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
|
||||||
Any,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
MutableMapping,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
)
|
|
||||||
import litellm
|
import litellm
|
||||||
import litellm.types
|
import litellm.types
|
||||||
from litellm.types.completion import (
|
from litellm.types.completion import (
|
||||||
|
@ -24,7 +17,7 @@ from litellm.types.completion import (
|
||||||
import litellm.types.llms
|
import litellm.types.llms
|
||||||
from litellm.types.llms.anthropic import *
|
from litellm.types.llms.anthropic import *
|
||||||
import uuid
|
import uuid
|
||||||
|
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
|
||||||
import litellm.types.llms.vertex_ai
|
import litellm.types.llms.vertex_ai
|
||||||
|
|
||||||
|
|
||||||
|
@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url):
|
||||||
try:
|
try:
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
except:
|
except:
|
||||||
raise Exception(
|
raise Exception("image conversion failed please run `pip install Pillow`")
|
||||||
"gemini image conversion failed please run `pip install Pillow`"
|
|
||||||
)
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -1613,6 +1604,379 @@ def azure_text_pt(messages: list):
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
###### AMAZON BEDROCK #######
|
||||||
|
|
||||||
|
from litellm.types.llms.bedrock import (
|
||||||
|
ToolResultContentBlock as BedrockToolResultContentBlock,
|
||||||
|
ToolResultBlock as BedrockToolResultBlock,
|
||||||
|
ToolConfigBlock as BedrockToolConfigBlock,
|
||||||
|
ToolUseBlock as BedrockToolUseBlock,
|
||||||
|
ImageSourceBlock as BedrockImageSourceBlock,
|
||||||
|
ImageBlock as BedrockImageBlock,
|
||||||
|
ContentBlock as BedrockContentBlock,
|
||||||
|
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
|
||||||
|
ToolSpecBlock as BedrockToolSpecBlock,
|
||||||
|
ToolBlock as BedrockToolBlock,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_details(image_url) -> Tuple[bytes, str]:
    """
    Download an image and return it base64-encoded together with its format.

    Args:
        image_url: URL of the image to fetch.

    Returns:
        (base64-encoded image bytes, image format), where the format is the
        subtype of the response's content-type, e.g. "jpeg" for
        "image/jpeg; charset=binary".

    Raises:
        ValueError: the response's content-type is not an image type.
        Exception: the HTTP request failed.
    """
    import base64

    try:
        # Send a GET request to the image URL
        response = requests.get(image_url)
        response.raise_for_status()  # Raise an exception for HTTP errors
    except requests.RequestException as e:
        raise Exception(f"Request failed: {e}") from e

    # Check the response's content type to ensure it is an image
    content_type = response.headers.get("content-type")
    if not content_type or "image" not in content_type:
        raise ValueError(
            f"URL does not point to a valid image (content-type: {content_type})"
        )

    # Drop any parameters (e.g. "; charset=binary") before reading the subtype,
    # otherwise they end up inside the returned format.
    media_type = content_type.split(";", 1)[0].strip().lower()
    _, _, mime_type = media_type.partition("/")
    if not mime_type:
        raise ValueError(
            f"URL does not point to a valid image (content-type: {content_type})"
        )

    # Convert the image content to base64 bytes
    base64_bytes = base64.b64encode(response.content)

    return base64_bytes, mime_type
|
||||||
|
|
||||||
|
|
||||||
|
def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
    """
    Convert an OpenAI `image_url` value into a Bedrock Converse image block.

    Two inputs are accepted:
      1. a data URL: data:image/jpeg;base64,<base-64-encoded-image>
      2. an https link, which is downloaded via `get_image_details`

    In both cases the image is carried as a base64 string, since the request
    body is sent as JSON.

    Raises:
        ValueError: the input is neither of the two forms, or the image format
            is not one of the supported image types.
    """
    if "base64" in image_url:
        # Case 1: Images with base64 encoding
        import re

        image_metadata, img_without_base_64 = image_url.split(",", 1)

        # image_metadata looks like "data:image/jpeg;base64". The format is the
        # subtype of the mime type, not everything after the "/".
        mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
        mime_type = mime_type_match.group(1) if mime_type_match else "jpeg"
        image_format = mime_type.split("/")[-1]
        encoded_image = img_without_base_64
    elif "https:/" in image_url:
        # Case 2: Images with direct links
        image_bytes, image_format = get_image_details(image_url)
        encoded_image = (
            image_bytes.decode("utf-8")
            if isinstance(image_bytes, bytes)
            else image_bytes
        )
    else:
        raise ValueError(
            "Unsupported image type. Expected either image url or base64 encoded string"
        )

    supported_image_formats = litellm.AmazonConverseConfig().get_supported_image_types()
    if image_format not in supported_image_formats:
        # Handle the case when the image format is not supported
        raise ValueError(
            "Unsupported image format: {}. Supported formats: {}".format(
                image_format, supported_image_formats
            )
        )

    _blob = BedrockImageSourceBlock(bytes=encoded_image)  # type: ignore
    return BedrockImageBlock(source=_blob, format=image_format)  # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_to_bedrock_tool_call_invoke(
    tool_calls: list,
) -> List[BedrockContentBlock]:
    """
    Convert OpenAI assistant tool calls into Bedrock `toolUse` content blocks.

    OpenAI tool invokes:
    {
      "role": "assistant",
      "content": null,
      "tool_calls": [
        {
          "id": "call_abc123",
          "type": "function",
          "function": {
            "name": "get_current_weather",
            "arguments": "{\n\"location\": \"Boston, MA\"\n}"
          }
        }
      ]
    },

    Bedrock tool invokes:
    [
        {
            "role": "assistant",
            "toolUse": {
                "input": {"location": "Boston, MA", ..},
                "name": "get_current_weather",
                "toolUseId": "call_abc123"
            }
        }
    ]

    For each entry that has a "function" key: json.loads the arguments (empty
    or missing arguments become an empty dict), then extract the name and id.
    Entries without a "function" key are skipped.

    Raises:
        Exception: any tool call could not be converted (e.g. invalid JSON
            arguments or a missing "id").
    """
    try:
        _parts_list: List[BedrockContentBlock] = []
        for tool in tool_calls:
            if "function" not in tool:
                continue
            tool_use_id = tool["id"]
            function_call = tool["function"]
            name = function_call.get("name", "")
            arguments = function_call.get("arguments", "")
            # A call without arguments may carry "" (or nothing) here, which
            # json.loads cannot parse.
            arguments_dict = json.loads(arguments) if arguments else {}
            bedrock_tool = BedrockToolUseBlock(
                input=arguments_dict, name=name, toolUseId=tool_use_id
            )
            _parts_list.append(BedrockContentBlock(toolUse=bedrock_tool))
        return _parts_list
    except Exception as e:
        raise Exception(
            "Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format(
                tool_calls, str(e)
            )
        ) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_to_bedrock_tool_call_result(
    message: dict,
) -> BedrockMessageBlock:
    """
    Convert an OpenAI tool / function result message into a Bedrock user
    message holding a single `toolResult` block.

    OpenAI message with a tool result looks like:
    {
        "tool_call_id": "tool_1",
        "role": "tool",
        "name": "get_current_weather",
        "content": "function result goes here",
    },

    OpenAI message with a function call result looks like:
    {
        "role": "function",
        "name": "get_current_weather",
        "content": "function result goes here",
    }

    Bedrock result looks like this:
    {
        "role": "user",
        "content": [
            {
                "toolResult": {
                    "toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q",
                    "content": [
                        {
                            "json": {
                                "song": "Elemental Hotel",
                                "artist": "8 Storey Hike"
                            }
                        }
                    ]
                }
            }
        ]
    }

    The message content is sent as a text block. When the message has no
    "tool_call_id", a random UUID is used as the `toolUseId`.
    """
    content = message.get("content", "")
    tool_use_id = message.get("tool_call_id", str(uuid.uuid4()))

    tool_result_content_block = BedrockToolResultContentBlock(text=content)
    tool_result = BedrockToolResultBlock(
        # `toolResult.content` is a list of content blocks (see the example above).
        content=[tool_result_content_block],
        toolUseId=tool_use_id,
    )
    content_block = BedrockContentBlock(toolResult=tool_result)

    return BedrockMessageBlock(role="user", content=[content_block])
|
||||||
|
|
||||||
|
|
||||||
|
def _bedrock_converse_content_list_pt(elements: list) -> List[BedrockContentBlock]:
    """Convert a list-style OpenAI `content` (text / image_url parts) into Bedrock content blocks."""
    blocks: List[BedrockContentBlock] = []
    for element in elements:
        if not isinstance(element, dict):
            continue
        if element["type"] == "text":
            blocks.append(BedrockContentBlock(text=element["text"]))
        elif element["type"] == "image_url":
            image_block = _process_bedrock_converse_image_block(  # type: ignore
                image_url=element["image_url"]["url"]
            )
            blocks.append(BedrockContentBlock(image=image_block))  # type: ignore
    return blocks


def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]:
    """
    Converts given messages from OpenAI format to Bedrock format

    - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
    - Please ensure that function response turn comes immediately after a function call turn

    Consecutive messages of the same kind are merged into one Bedrock message:
    user messages into a 'user' turn, assistant messages (text, images and tool
    calls) into an 'assistant' turn, and tool results into a 'user' turn.

    Raises:
        Exception: a message has a role that none of the steps above consumes.
    """
    contents: List[BedrockMessageBlock] = []
    msg_i = 0
    while msg_i < len(messages):
        init_msg_i = msg_i

        ## MERGE CONSECUTIVE USER CONTENT ##
        user_content: List[BedrockContentBlock] = []
        while msg_i < len(messages) and messages[msg_i]["role"] == "user":
            user_message_content = messages[msg_i]["content"]
            if isinstance(user_message_content, list):
                user_content.extend(
                    _bedrock_converse_content_list_pt(user_message_content)
                )
            else:
                user_content.append(BedrockContentBlock(text=user_message_content))
            msg_i += 1

        if user_content:
            contents.append(BedrockMessageBlock(role="user", content=user_content))

        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
        assistant_content: List[BedrockContentBlock] = []
        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
            assistant_message = messages[msg_i]
            # tool-call messages may carry `content: None` or no content at all
            assistant_message_content = assistant_message.get("content")
            if isinstance(assistant_message_content, list):
                assistant_content.extend(
                    _bedrock_converse_content_list_pt(assistant_message_content)
                )
            elif assistant_message_content:  # either string or none
                assistant_content.append(
                    BedrockContentBlock(text=assistant_message_content)
                )

            # support assistant tool invoke conversion - handled independently
            # of the content so text sent alongside tool calls is not dropped
            tool_calls = assistant_message.get("tool_calls") or []
            if tool_calls:
                assistant_content.extend(
                    _convert_to_bedrock_tool_call_invoke(tool_calls)
                )
            msg_i += 1

        if assistant_content:
            contents.append(
                BedrockMessageBlock(role="assistant", content=assistant_content)
            )

        ## APPEND TOOL CALL MESSAGES ##
        # Consecutive tool results are merged into a single 'user' turn so that
        # roles keep alternating.
        tool_content: List[BedrockContentBlock] = []
        while msg_i < len(messages) and messages[msg_i]["role"] == "tool":
            tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
            # assumes BedrockMessageBlock is a dict-like TypedDict - TODO confirm
            tool_content.extend(tool_call_result["content"])
            msg_i += 1

        if tool_content:
            contents.append(BedrockMessageBlock(role="user", content=tool_content))

        if msg_i == init_msg_i:  # prevent infinite loops
            raise Exception(
                "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
                    messages[msg_i]
                )
            )

    return contents
|
||||||
|
|
||||||
|
|
||||||
|
def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
    """
    Convert OpenAI tool definitions into Bedrock `toolSpec` blocks.

    OpenAI tools looks like:
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            }
        }
    ]

    Bedrock toolConfig looks like:
    "tools": [
        {
            "toolSpec": {
                "name": "top_song",
                "description": "Get the most popular song played on a radio station.",
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                            "sign": {
                                "type": "string",
                                "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP."
                            }
                        },
                        "required": [
                            "sign"
                        ]
                    }
                }
            }
        }
    ]

    A tool without "parameters" gets an empty object schema, and a tool
    without a description uses its name as the description.
    """
    tool_block_list: List[BedrockToolBlock] = []
    for tool in tools:
        function_spec = tool.get("function", {})
        name = function_spec.get("name", "")
        # The input schema must be a JSON object, so fall back to an empty one.
        parameters = function_spec.get("parameters") or {
            "type": "object",
            "properties": {},
        }
        # The Converse API does not accept an empty description.
        description = function_spec.get("description") or name

        tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
        tool_spec = BedrockToolSpecBlock(
            inputSchema=tool_input_schema, name=name, description=description
        )
        tool_block_list.append(BedrockToolBlock(toolSpec=tool_spec))

    return tool_block_list
|
||||||
|
|
||||||
|
|
||||||
# Function call template
|
# Function call template
|
||||||
def function_call_prompt(messages: list, functions: list):
|
def function_call_prompt(messages: list, functions: list):
|
||||||
function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""
|
function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""
|
||||||
|
|
|
@ -300,7 +300,14 @@ def test_completion_claude_3():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
def test_completion_claude_3_function_call():
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
# "anthropic/claude-3-opus-20240229",
|
||||||
|
"cohere.command-r-plus-v1:0"
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_completion_claude_3_function_call(model):
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
|
@ -331,7 +338,7 @@ def test_completion_claude_3_function_call():
|
||||||
try:
|
try:
|
||||||
# test without max tokens
|
# test without max tokens
|
||||||
response = completion(
|
response = completion(
|
||||||
model="anthropic/claude-3-opus-20240229",
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice={
|
tool_choice={
|
||||||
|
@ -364,7 +371,7 @@ def test_completion_claude_3_function_call():
|
||||||
)
|
)
|
||||||
# In the second response, Claude should deduce answer from tool results
|
# In the second response, Claude should deduce answer from tool results
|
||||||
second_response = completion(
|
second_response = completion(
|
||||||
model="anthropic/claude-3-opus-20240229",
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import TypedDict, Any, Union, Optional
|
from typing import TypedDict, Any, Union, Optional, Literal, List
|
||||||
import json
|
import json
|
||||||
from typing_extensions import (
|
from typing_extensions import (
|
||||||
Self,
|
Self,
|
||||||
|
@ -11,6 +11,81 @@ from typing_extensions import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SystemContentBlock(TypedDict):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSourceBlock(TypedDict):
|
||||||
|
bytes: Optional[bytes]
|
||||||
|
|
||||||
|
|
||||||
|
class ImageBlock(TypedDict):
|
||||||
|
format: Literal["png", "jpeg", "gif", "webp"]
|
||||||
|
source: ImageSourceBlock
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResultContentBlock(TypedDict, total=False):
|
||||||
|
image: ImageBlock
|
||||||
|
json: dict
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResultBlock(TypedDict, total=False):
|
||||||
|
content: Required[ToolResultContentBlock]
|
||||||
|
toolUseId: Required[str]
|
||||||
|
status: Literal["success", "error"]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolUseBlock(TypedDict):
|
||||||
|
input: dict
|
||||||
|
name: str
|
||||||
|
toolUseId: str
|
||||||
|
|
||||||
|
|
||||||
|
class ContentBlock(TypedDict, total=False):
|
||||||
|
text: str
|
||||||
|
image: ImageBlock
|
||||||
|
toolResult: ToolResultBlock
|
||||||
|
toolUse: ToolUseBlock
|
||||||
|
|
||||||
|
|
||||||
|
class MessageBlock(TypedDict):
|
||||||
|
content: List[ContentBlock]
|
||||||
|
role: Literal["user", "assistant"]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolInputSchemaBlock(TypedDict):
|
||||||
|
json: Optional[dict]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolSpecBlock(TypedDict, total=False):
|
||||||
|
inputSchema: Required[ToolInputSchemaBlock]
|
||||||
|
name: Required[str]
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolBlock(TypedDict):
|
||||||
|
toolSpec: Optional[ToolSpecBlock]
|
||||||
|
|
||||||
|
|
||||||
|
class SpecificToolChoiceBlock(TypedDict):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolConfigBlock(TypedDict, total=False):
|
||||||
|
tools: Required[List[ToolBlock]]
|
||||||
|
toolChoice: Union[str, SpecificToolChoiceBlock]
|
||||||
|
|
||||||
|
|
||||||
|
class RequestObject(TypedDict, total=False):
|
||||||
|
additionalModelRequestFields: dict
|
||||||
|
additionalModelResponseFieldPaths: List[str]
|
||||||
|
inferenceConfig: dict
|
||||||
|
messages: Required[List[MessageBlock]]
|
||||||
|
system: List[SystemContentBlock]
|
||||||
|
toolConfig: ToolConfigBlock
|
||||||
|
|
||||||
|
|
||||||
class GenericStreamingChunk(TypedDict):
|
class GenericStreamingChunk(TypedDict):
|
||||||
text: Required[str]
|
text: Required[str]
|
||||||
is_finished: Required[bool]
|
is_finished: Required[bool]
|
||||||
|
|
|
@ -6401,20 +6401,7 @@ def get_supported_openai_params(
|
||||||
- None if unmapped
|
- None if unmapped
|
||||||
"""
|
"""
|
||||||
if custom_llm_provider == "bedrock":
|
if custom_llm_provider == "bedrock":
|
||||||
if model.startswith("anthropic.claude-3"):
|
return litellm.AmazonConverseConfig().get_supported_openai_params()
|
||||||
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
|
|
||||||
elif model.startswith("anthropic"):
|
|
||||||
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
|
|
||||||
elif model.startswith("ai21"):
|
|
||||||
return ["max_tokens", "temperature", "top_p", "stream"]
|
|
||||||
elif model.startswith("amazon"):
|
|
||||||
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
|
|
||||||
elif model.startswith("meta"):
|
|
||||||
return ["max_tokens", "temperature", "top_p", "stream"]
|
|
||||||
elif model.startswith("cohere"):
|
|
||||||
return ["stream", "temperature", "max_tokens"]
|
|
||||||
elif model.startswith("mistral"):
|
|
||||||
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
|
|
||||||
elif custom_llm_provider == "ollama":
|
elif custom_llm_provider == "ollama":
|
||||||
return litellm.OllamaConfig().get_supported_openai_params()
|
return litellm.OllamaConfig().get_supported_openai_params()
|
||||||
elif custom_llm_provider == "ollama_chat":
|
elif custom_llm_provider == "ollama_chat":
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue