From a76a9b7d11ab095ca4f1b52005e28f9723f155e0 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 5 Jun 2024 21:20:36 -0700
Subject: [PATCH] feat(bedrock_httpx.py): add support for bedrock converse api

closes https://github.com/BerriAI/litellm/issues/4000
---
 litellm/__init__.py                      |   2 +-
 litellm/llms/bedrock_httpx.py            | 382 ++++++++++++++++++++++
 litellm/llms/prompt_templates/factory.py | 388 ++++++++++++++++++++++-
 litellm/tests/test_completion.py         |  13 +-
 litellm/types/llms/bedrock.py            |  77 ++++-
 litellm/utils.py                         |  15 +-
 6 files changed, 846 insertions(+), 31 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index f67a252eb..2fc47a992 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -765,7 +765,7 @@ from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
-from .llms.bedrock_httpx import AmazonCohereChatConfig
+from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
 from .llms.bedrock import (
     AmazonTitanConfig,
     AmazonAI21Config,
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index dbd7e7c69..e21265006 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -38,6 +38,8 @@ from .prompt_templates.factory import (
     extract_between_tags,
     parse_xml_params,
     contains_tag,
+    _bedrock_converse_messages_pt,
+    _bedrock_tools_pt,
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
@@ -118,6 +120,8 @@ class AmazonCohereChatConfig:
         "presence_penalty",
         "seed",
         "stop",
+        "tools",
+        "tool_choice",
     ]
 
     def map_openai_params(
@@ -1069,6 +1073,384 @@ class BedrockLLM(BaseLLM):
         return super().embedding(*args, **kwargs)
 
 
+class AmazonConverseConfig:
+    """
+    Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+    """
+
+    maxTokens: Optional[int]
+    stopSequences: Optional[List[str]]
+    temperature: Optional[float]
+    topP: Optional[float]
+
+    def __init__(
+        self,
+        maxTokens: Optional[int] = None,
+        stopSequences: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        topP: Optional[float] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self) -> List[str]:
+        return [
+            "max_tokens",
+            "stream",
+            "stream_options",
+            "stop",
+            "temperature",
+            "top_p",
+            "tools",
+            "tool_choice",
+        ]
+
+    def get_supported_image_types(self) -> List[str]:
+        return ["png", "jpeg", "gif", "webp"]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["maxTokens"] = value
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                if isinstance(value, str):
+                    value = [value]
+                optional_params["stopSequences"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["topP"] = value
+        return optional_params
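+
+# Illustrative example (not part of this diff's test suite): map_openai_params
+# translates OpenAI-style kwargs into Converse inference params, e.g.
+#   AmazonConverseConfig().map_openai_params(
+#       non_default_params={"max_tokens": 100, "top_p": 0.9, "stop": "\n"},
+#       optional_params={},
+#   )
+#   # -> {"maxTokens": 100, "topP": 0.9, "stopSequences": ["\n"]}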
+
+
+class BedrockConverseLLM(BaseLLM):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def encode_model_id(self, model_id: str) -> str:
+        """
+        URL-encode the model ID (e.g. ARNs containing '/') so it matches the
+        format expected in the request path.
+        Args:
+            model_id (str): The model ID to encode.
+        Returns:
+            str: The URL-encoded model ID.
+        """
+        return urllib.parse.quote(model_id, safe="")
+
+    def get_credentials(
+        self,
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_region_name: Optional[str] = None,
+        aws_session_name: Optional[str] = None,
+        aws_profile_name: Optional[str] = None,
+        aws_role_name: Optional[str] = None,
+        aws_web_identity_token: Optional[str] = None,
+    ):
+        """
+        Return a botocore credentials object
+        """
+        import boto3
+
+        ## CHECK IF 'os.environ/' IS PASSED IN ##
+        params_to_check: List[Optional[str]] = [
+            aws_access_key_id,
+            aws_secret_access_key,
+            aws_region_name,
+            aws_session_name,
+            aws_profile_name,
+            aws_role_name,
+            aws_web_identity_token,
+        ]
+
+        # Iterate over parameters and update if needed
+        for i, param in enumerate(params_to_check):
+            if param and param.startswith("os.environ/"):
+                _v = get_secret(param)
+                if _v is not None and isinstance(_v, str):
+                    params_to_check[i] = _v
+        # Assign updated values back to parameters
+        (
+            aws_access_key_id,
+            aws_secret_access_key,
+            aws_region_name,
+            aws_session_name,
+            aws_profile_name,
+            aws_role_name,
+            aws_web_identity_token,
+        ) = params_to_check
+
+        ### CHECK STS ###
+        if (
+            aws_web_identity_token is not None
+            and aws_role_name is not None
+            and aws_session_name is not None
+        ):
+            oidc_token = get_secret(aws_web_identity_token)
+
+            if oidc_token is None:
+                raise BedrockError(
+                    message="OIDC token could not be retrieved from secret manager.",
+                    status_code=401,
+                )
+
+            sts_client = boto3.client("sts")
+
+            # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
+            # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
+            sts_response = sts_client.assume_role_with_web_identity(
+                RoleArn=aws_role_name,
+                RoleSessionName=aws_session_name,
+                WebIdentityToken=oidc_token,
+                DurationSeconds=3600,
+            )
+
+            session = boto3.Session(
+                aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
+                aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
+                aws_session_token=sts_response["Credentials"]["SessionToken"],
+                region_name=aws_region_name,
+            )
+
+            return session.get_credentials()
+        elif aws_role_name is not None and aws_session_name is not None:
+            sts_client = boto3.client(
+                "sts",
+                aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
+                aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
+            )
+
+            sts_response = sts_client.assume_role(
+                RoleArn=aws_role_name, RoleSessionName=aws_session_name
+            )
+
+            # Extract the credentials from the response and convert to Session Credentials
+            sts_credentials = sts_response["Credentials"]
+            from botocore.credentials import Credentials
+
+            credentials = Credentials(
+                access_key=sts_credentials["AccessKeyId"],
+                secret_key=sts_credentials["SecretAccessKey"],
+                token=sts_credentials["SessionToken"],
+            )
+            return credentials
+        elif aws_profile_name is not None:  ### CHECK SESSION ###
+            # uses auth values from AWS profile usually stored in ~/.aws/credentials
+            client = boto3.Session(profile_name=aws_profile_name)
+
+            return client.get_credentials()
+        else:
+            session = boto3.Session(
+                aws_access_key_id=aws_access_key_id,
+                aws_secret_access_key=aws_secret_access_key,
+                region_name=aws_region_name,
+            )
+
+            return session.get_credentials()
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        logging_obj,
+        optional_params: dict,
+        acompletion: bool,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        litellm_params=None,
+        logger_fn=None,
+        extra_headers: Optional[dict] = None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+    ):
+        try:
+            import boto3
+
+            from botocore.auth import SigV4Auth
+            from botocore.awsrequest import AWSRequest
+            from botocore.credentials import Credentials
+        except ImportError:
+            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+
+        ## SETUP ##
+        stream = optional_params.pop("stream", None)
+        modelId = optional_params.pop("model_id", None)
+        if modelId is not None:
+            modelId = self.encode_model_id(model_id=modelId)
+        else:
+            modelId = model
+
+        provider = model.split(".")[0]
+
+        ## CREDENTIALS ##
+        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs,
+        # since completion calls fail with them
+        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
+        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
+        aws_region_name = optional_params.pop("aws_region_name", None)
+        aws_role_name = optional_params.pop("aws_role_name", None)
+        aws_session_name = optional_params.pop("aws_session_name", None)
+        aws_profile_name = optional_params.pop("aws_profile_name", None)
+        aws_bedrock_runtime_endpoint = optional_params.pop(
+            "aws_bedrock_runtime_endpoint", None
+        )  # https://bedrock-runtime.{region_name}.amazonaws.com
+        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
+
+        ### SET REGION NAME ###
+        if aws_region_name is None:
+            # check env #
+            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
+
+            if litellm_aws_region_name is not None and isinstance(
+                litellm_aws_region_name, str
+            ):
+                aws_region_name = litellm_aws_region_name
+
+            standard_aws_region_name = get_secret("AWS_REGION", None)
+            if standard_aws_region_name is not None and isinstance(
+                standard_aws_region_name, str
+            ):
+                aws_region_name = standard_aws_region_name
+
+            if aws_region_name is None:
+                aws_region_name = "us-west-2"
+
+        credentials: Credentials = self.get_credentials(
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_region_name=aws_region_name,
+            aws_session_name=aws_session_name,
+            aws_profile_name=aws_profile_name,
+            aws_role_name=aws_role_name,
+            aws_web_identity_token=aws_web_identity_token,
+        )
+
+        ### SET RUNTIME ENDPOINT ###
+        endpoint_url = ""
+        env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
+        if aws_bedrock_runtime_endpoint is not None and isinstance(
+            aws_bedrock_runtime_endpoint, str
+        ):
+            endpoint_url = aws_bedrock_runtime_endpoint
+        elif env_aws_bedrock_runtime_endpoint and isinstance(
+            env_aws_bedrock_runtime_endpoint, str
+        ):
+            endpoint_url = env_aws_bedrock_runtime_endpoint
+        else:
+            endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
+
+        if stream is True and provider != "ai21":
+            endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
+        else:
+            endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
+
+        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
+
+        # Separate system prompt from the rest of the messages
+        system_prompt_indices = []
+        system_content_blocks: List[SystemContentBlock] = []
+        for idx, message in enumerate(messages):
+            if message["role"] == "system":
+                _system_content_block = SystemContentBlock(text=message["content"])
+                system_content_blocks.append(_system_content_block)
+                system_prompt_indices.append(idx)
+        if len(system_prompt_indices) > 0:
+            for idx in reversed(system_prompt_indices):
+                messages.pop(idx)
+
+        inference_params = copy.deepcopy(optional_params)
+        # extract tool params first, so they aren't treated as model-specific
+        # params and dropped by the filtering loop below
+        _tools = inference_params.pop("tools", [])
+        _tool_choice = inference_params.pop("tool_choice", None)
+        additional_request_keys = []
+        additional_request_params = {}
+        supported_converse_params = AmazonConverseConfig().get_config().keys()
+
+        ## TRANSFORMATION ##
+        # send all model-specific params in 'additionalModelRequestFields'
+        for k, v in inference_params.items():
+            if k not in supported_converse_params:
+                additional_request_params[k] = v
+                additional_request_keys.append(k)
+        for key in additional_request_keys:
+            inference_params.pop(key, None)
+
+        bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
+            messages=messages
+        )
+        bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(_tools)
+        bedrock_tool_config: Optional[ToolConfigBlock] = None
+        if len(bedrock_tools) > 0:
+            bedrock_tool_config = ToolConfigBlock(
+                tools=bedrock_tools,
+                toolChoice=_tool_choice,
+            )
+
+        data: RequestObject = {
+            "messages": bedrock_messages,
+            "additionalModelRequestFields": additional_request_params,
+            "system": system_content_blocks,
+            "inferenceConfig": inference_params,
+        }
+        if bedrock_tool_config is not None:
+            data["toolConfig"] = bedrock_tool_config
+
+        ## COMPLETION CALL
+
+        headers = {"Content-Type": "application/json"}
+        if extra_headers is not None:
+            headers = {"Content-Type": "application/json", **extra_headers}
+        request = AWSRequest(
+            method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
+        )
+        sigv4.add_auth(request)
+        prepped = request.prepare()
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": prepped.url,
+                "headers": prepped.headers,
+            },
+        )
+
+        ### ROUTING (ASYNC, STREAMING, SYNC)
+        if client is None or isinstance(client, AsyncHTTPHandler):
+            client = HTTPHandler()  # sync call - only the sync path is wired up here
+        try:
+            response = client.post(url=prepped.url, headers=prepped.headers, data=json.dumps(data))  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise BedrockError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise BedrockError(status_code=408, message="Timeout error occurred.")
+
+
 def get_response_stream_shape():
     from botocore.model import ServiceModel
     from botocore.loaders import Loader
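For reference, the signed request produced by completion() above has roughly this shape (illustrative values; the field names follow the RequestObject TypedDict added in litellm/types/llms/bedrock.py further down):

    POST https://bedrock-runtime.us-west-2.amazonaws.com/model/cohere.command-r-plus-v1:0/converse
    {
        "messages": [{"role": "user", "content": [{"text": "What's the weather in Boston?"}]}],
        "system": [{"text": "You are a helpful assistant"}],
        "inferenceConfig": {"maxTokens": 100, "temperature": 0.7},
        "additionalModelRequestFields": {},
        "toolConfig": {"tools": [{"toolSpec": {...}}], "toolChoice": "auto"}
    }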
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 41ecb486c..d5ef69687 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -3,14 +3,7 @@ import requests, traceback
 import json, re, xml.etree.ElementTree as ET
 from jinja2 import Template, exceptions, meta, BaseLoader
 from jinja2.sandbox import ImmutableSandboxedEnvironment
-from typing import (
-    Any,
-    List,
-    Mapping,
-    MutableMapping,
-    Optional,
-    Sequence,
-)
+from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
 import litellm
 import litellm.types
 from litellm.types.completion import (
@@ -24,7 +17,7 @@ from litellm.types.completion import (
 import litellm.types.llms
 from litellm.types.llms.anthropic import *
 import uuid
-
+from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
 import litellm.types.llms.vertex_ai
 
 
@@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url):
     try:
         from PIL import Image
     except:
-        raise Exception(
-            "gemini image conversion failed please run `pip install Pillow`"
-        )
+        raise Exception("image conversion failed. Please run `pip install Pillow`")
 
     from io import BytesIO
 
     try:
@@ -1613,6 +1604,379 @@ def azure_text_pt(messages: list):
     return prompt
 
 
+###### AMAZON BEDROCK #######
+
+from litellm.types.llms.bedrock import (
+    ToolResultContentBlock as BedrockToolResultContentBlock,
+    ToolResultBlock as BedrockToolResultBlock,
+    ToolConfigBlock as BedrockToolConfigBlock,
+    ToolUseBlock as BedrockToolUseBlock,
+    ImageSourceBlock as BedrockImageSourceBlock,
+    ImageBlock as BedrockImageBlock,
+    ContentBlock as BedrockContentBlock,
+    ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
+    ToolSpecBlock as BedrockToolSpecBlock,
+    ToolBlock as BedrockToolBlock,
+)
+
+
+def get_image_details(image_url) -> Tuple[bytes, str]:
+    try:
+        import base64
+
+        # Send a GET request to the image URL
+        response = requests.get(image_url)
+        response.raise_for_status()  # Raise an exception for HTTP errors
+
+        # Check the response's content type to ensure it is an image
+        content_type = response.headers.get("content-type")
+        if not content_type or "image" not in content_type:
+            raise ValueError(
+                f"URL does not point to a valid image (content-type: {content_type})"
+            )
+
+        # Convert the image content to base64-encoded bytes
+        base64_bytes = base64.b64encode(response.content)
+
+        # Extract the image format (e.g. 'jpeg') from the content-type header
+        image_format = content_type.split("/")[1]
+
+        return base64_bytes, image_format
+
+    except requests.RequestException as e:
+        raise Exception(f"Request failed: {e}")
+    except Exception as e:
+        raise e
+
+
+def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
+    if "base64" in image_url:
+        # Case 1: Images with base64 encoding
+        import base64, re
+
+        # base64 images are passed as data:image/jpeg;base64,<data>
+        image_metadata, img_without_base_64 = image_url.split(",")
+
+        # Extract the MIME type (e.g. 'image/jpeg') using a regular expression
+        mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
+        if mime_type_match:
+            mime_type = mime_type_match.group(1)
+        else:
+            mime_type = "image/jpeg"
+        image_format = mime_type.split("/")[1]  # e.g. 'jpeg'
+
+        decoded_img = base64.b64decode(img_without_base_64)
+        _blob = BedrockImageSourceBlock(bytes=decoded_img)
+        supported_image_formats = (
+            litellm.AmazonConverseConfig().get_supported_image_types()
+        )
+        if image_format in supported_image_formats:
+            return BedrockImageBlock(source=_blob, format=image_format)  # type: ignore
+        else:
+            # Handle the case when the image format is not supported
+            raise ValueError(
+                "Unsupported image format: {}. Supported formats: {}".format(
+                    image_format, supported_image_formats
+                )
+            )
+    elif "https://" in image_url:
+        # Case 2: Images with direct links
+        image_bytes, image_format = get_image_details(image_url)
+        _blob = BedrockImageSourceBlock(bytes=image_bytes)
+        supported_image_formats = (
+            litellm.AmazonConverseConfig().get_supported_image_types()
+        )
+        if image_format in supported_image_formats:
+            return BedrockImageBlock(source=_blob, format=image_format)  # type: ignore
+        else:
+            # Handle the case when the image format is not supported
+            raise ValueError(
+                "Unsupported image format: {}. Supported formats: {}".format(
+                    image_format, supported_image_formats
+                )
+            )
+    else:
+        raise ValueError(
+            "Unsupported image type. Expected either image url or base64 encoded string"
+        )
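+
+# Illustrative example (assumed data URL, not from the test suite):
+#   _process_bedrock_converse_image_block("data:image/png;base64,iVBORw0KGgo...")
+#   # -> {"format": "png", "source": {"bytes": b"\x89PNG\r\n..."}}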
+
+
+def _convert_to_bedrock_tool_call_invoke(
+    tool_calls: list,
+) -> List[BedrockContentBlock]:
+    """
+    OpenAI tool invokes:
+    {
+      "role": "assistant",
+      "content": null,
+      "tool_calls": [
+        {
+          "id": "call_abc123",
+          "type": "function",
+          "function": {
+            "name": "get_current_weather",
+            "arguments": "{\n\"location\": \"Boston, MA\"\n}"
+          }
+        }
+      ]
+    },
+    """
+    """
+    Bedrock tool invokes:
+    [
+        {
+            "role": "assistant",
+            "toolUse": {
+                "input": {"location": "Boston, MA", ..},
+                "name": "get_current_weather",
+                "toolUseId": "call_abc123"
+            }
+        }
+    ]
+    """
+    """
+    - json.loads the function arguments
+    - extract the tool name
+    - extract the tool call id
+    """
+
+    try:
+        _parts_list: List[BedrockContentBlock] = []
+        for tool in tool_calls:
+            if "function" in tool:
+                id = tool["id"]
+                name = tool["function"].get("name", "")
+                arguments = tool["function"].get("arguments", "")
+                arguments_dict = json.loads(arguments)
+                bedrock_tool = BedrockToolUseBlock(
+                    input=arguments_dict, name=name, toolUseId=id
+                )
+                bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
+                _parts_list.append(bedrock_content_block)
+        return _parts_list
+    except Exception as e:
+        raise Exception(
+            "Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format(
+                tool_calls, str(e)
+            )
+        )
+
+
+def _convert_to_bedrock_tool_call_result(
+    message: dict,
+) -> BedrockMessageBlock:
+    """
+    OpenAI message with a tool result looks like:
+    {
+        "tool_call_id": "tool_1",
+        "role": "tool",
+        "name": "get_current_weather",
+        "content": "function result goes here",
+    },
+
+    OpenAI message with a function call result looks like:
+    {
+        "role": "function",
+        "name": "get_current_weather",
+        "content": "function result goes here",
+    }
+    """
+    """
+    Bedrock result looks like this:
+    {
+        "role": "user",
+        "content": [
+            {
+                "toolResult": {
+                    "toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q",
+                    "content": [
+                        {
+                            "json": {
+                                "song": "Elemental Hotel",
+                                "artist": "8 Storey Hike"
+                            }
+                        }
+                    ]
+                }
+            }
+        ]
+    }
+    """
+    content = message.get("content", "")
+    name = message.get("name", "")
+    id = message.get("tool_call_id", str(uuid.uuid4()))
+
+    tool_result_content_block = BedrockToolResultContentBlock(text=content)
+    tool_result = BedrockToolResultBlock(
+        content=[tool_result_content_block],
+        toolUseId=id,
+    )
+    content_block = BedrockContentBlock(toolResult=tool_result)
+
+    return BedrockMessageBlock(role="user", content=[content_block])
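+
+# Illustrative example of the merging performed below (not from the test suite):
+# OpenAI-style roles [user, user, assistant, tool] collapse into alternating
+# Bedrock turns:
+#   [{"role": "user", "content": [...merged user blocks...]},
+#    {"role": "assistant", "content": [...assistant blocks...]},
+#    {"role": "user", "content": [{"toolResult": {...}}]}]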
+
+
+def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]:
+    """
+    Converts given messages from OpenAI format to Bedrock format
+
+    - Roles must alternate b/w 'user' and 'assistant' (same as anthropic -> merge consecutive roles)
+    - Ensure that a function-response turn comes immediately after a function-call turn
+    """
+
+    contents: List[BedrockMessageBlock] = []
+    msg_i = 0
+    while msg_i < len(messages):
+        user_content: List[BedrockContentBlock] = []
+        init_msg_i = msg_i
+        ## MERGE CONSECUTIVE USER CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "user":
+            if isinstance(messages[msg_i]["content"], list):
+                _parts: List[BedrockContentBlock] = []
+                for element in messages[msg_i]["content"]:
+                    if isinstance(element, dict):
+                        if element["type"] == "text":
+                            _part = BedrockContentBlock(text=element["text"])
+                            _parts.append(_part)
+                        elif element["type"] == "image_url":
+                            image_url = element["image_url"]["url"]
+                            _part = _process_bedrock_converse_image_block(  # type: ignore
+                                image_url=image_url
+                            )
+                            _parts.append(BedrockContentBlock(image=_part))  # type: ignore
+                user_content.extend(_parts)
+            else:
+                _part = BedrockContentBlock(text=messages[msg_i]["content"])
+                user_content.append(_part)
+
+            msg_i += 1
+
+        if user_content:
+            contents.append(BedrockMessageBlock(role="user", content=user_content))
+        assistant_content: List[BedrockContentBlock] = []
+        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+            if isinstance(messages[msg_i]["content"], list):
+                assistants_parts: List[BedrockContentBlock] = []
+                for element in messages[msg_i]["content"]:
+                    if isinstance(element, dict):
+                        if element["type"] == "text":
+                            assistants_part = BedrockContentBlock(text=element["text"])
+                            assistants_parts.append(assistants_part)
+                        elif element["type"] == "image_url":
+                            image_url = element["image_url"]["url"]
+                            assistants_part = _process_bedrock_converse_image_block(  # type: ignore
+                                image_url=image_url
+                            )
+                            assistants_parts.append(
+                                BedrockContentBlock(image=assistants_part)  # type: ignore
+                            )
+                assistant_content.extend(assistants_parts)
+            elif messages[msg_i].get(
+                "tool_calls", []
+            ):  # support assistant tool invoke conversion
+                assistant_content.extend(
+                    _convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
+                )
+            else:
+                assistant_text = (
+                    messages[msg_i].get("content") or ""
+                )  # either str or None
+                if assistant_text:
+                    assistant_content.append(BedrockContentBlock(text=assistant_text))
+
+            msg_i += 1
+
+        if assistant_content:
+            contents.append(
+                BedrockMessageBlock(role="assistant", content=assistant_content)
+            )
+
+        ## APPEND TOOL CALL MESSAGES ##
+        if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
+            tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
+            contents.append(tool_call_result)
+            msg_i += 1
+        if msg_i == init_msg_i:  # prevent infinite loops
+            raise Exception(
+                "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+                    messages[msg_i]
+                )
+            )
+
+    return contents
+
+
+def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
+    """
+    OpenAI tools looks like:
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            }
+        }
+    ]
+    """
+    """
+    Bedrock toolConfig looks like:
+    "tools": [
+        {
+            "toolSpec": {
+                "name": "top_song",
+                "description": "Get the most popular song played on a radio station.",
+                "inputSchema": {
+                    "json": {
+                        "type": "object",
+                        "properties": {
+                            "sign": {
+                                "type": "string",
+                                "description": "The call sign for the radio station for which you want the most popular song. Example call signs are WZPZ and WKRP."
+                            }
+                        },
+                        "required": [
+                            "sign"
+                        ]
+                    }
+                }
+            }
+        }
+    ]
+    """
+    tool_block_list: List[BedrockToolBlock] = []
+    for tool in tools:
+        parameters = tool.get("function", {}).get("parameters", None)
+        name = tool.get("function", {}).get("name", "")
+        description = tool.get("function", {}).get("description", "")
+        tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
+        tool_spec = BedrockToolSpecBlock(
+            inputSchema=tool_input_schema, name=name, description=description
+        )
+        tool_block = BedrockToolBlock(toolSpec=tool_spec)
+        tool_block_list.append(tool_block)
+
+    return tool_block_list
+
+
 # Function call template
 def function_call_prompt(messages: list, functions: list):
     function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""
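As a quick illustration of the tool translation above (hypothetical tool, output abridged):

    _bedrock_tools_pt(
        [{"type": "function", "function": {"name": "get_current_weather", "description": "...", "parameters": {...}}}]
    )
    # -> [{"toolSpec": {"inputSchema": {"json": {...}}, "name": "get_current_weather", "description": "..."}}]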
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 47c55ca4f..1befa1392 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -300,7 +300,14 @@ def test_completion_claude_3():
         pytest.fail(f"Error occurred: {e}")
 
 
-def test_completion_claude_3_function_call():
+@pytest.mark.parametrize(
+    "model",
+    [
+        # "anthropic/claude-3-opus-20240229",
+        "cohere.command-r-plus-v1:0"
+    ],
+)
+def test_completion_claude_3_function_call(model):
     litellm.set_verbose = True
     tools = [
         {
@@ -331,7 +338,7 @@ def test_completion_claude_3_function_call():
     try:
         # test without max tokens
         response = completion(
-            model="anthropic/claude-3-opus-20240229",
+            model=model,
             messages=messages,
             tools=tools,
             tool_choice={
@@ -364,7 +371,7 @@ def test_completion_claude_3_function_call():
         )  # In the second response, Claude should deduce answer from tool results
         second_response = completion(
-            model="anthropic/claude-3-opus-20240229",
+            model=model,
             messages=messages,
             tools=tools,
             tool_choice="auto",
diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py
index 0c8259682..9333ea1b9 100644
--- a/litellm/types/llms/bedrock.py
+++ b/litellm/types/llms/bedrock.py
@@ -1,4 +1,4 @@
-from typing import TypedDict, Any, Union, Optional
+from typing import TypedDict, Any, Union, Optional, Literal, List
 import json
 from typing_extensions import (
     Self,
@@ -11,6 +11,81 @@ from typing_extensions import (
 )
 
 
+class SystemContentBlock(TypedDict):
+    text: str
+
+
+class ImageSourceBlock(TypedDict):
+    bytes: Optional[bytes]
+
+
+class ImageBlock(TypedDict):
+    format: Literal["png", "jpeg", "gif", "webp"]
+    source: ImageSourceBlock
+
+
+class ToolResultContentBlock(TypedDict, total=False):
+    image: ImageBlock
+    json: dict
+    text: str
+
+
+class ToolResultBlock(TypedDict, total=False):
+    content: Required[List[ToolResultContentBlock]]
+    toolUseId: Required[str]
+    status: Literal["success", "error"]
+
+
+class ToolUseBlock(TypedDict):
+    input: dict
+    name: str
+    toolUseId: str
+
+
+class ContentBlock(TypedDict, total=False):
+    text: str
+    image: ImageBlock
+    toolResult: ToolResultBlock
+    toolUse: ToolUseBlock
+
+
+class MessageBlock(TypedDict):
+    content: List[ContentBlock]
+    role: Literal["user", "assistant"]
+
+
+class ToolInputSchemaBlock(TypedDict):
+    json: Optional[dict]
+
+
+class ToolSpecBlock(TypedDict, total=False):
+    inputSchema: Required[ToolInputSchemaBlock]
+    name: Required[str]
+    description: str
+
+
+class ToolBlock(TypedDict):
+    toolSpec: Optional[ToolSpecBlock]
+
+
+class SpecificToolChoiceBlock(TypedDict):
+    name: str
+
+
+class ToolConfigBlock(TypedDict, total=False):
+    tools: Required[List[ToolBlock]]
+    toolChoice: Union[str, SpecificToolChoiceBlock]
+
+
+class RequestObject(TypedDict, total=False):
+    additionalModelRequestFields: dict
+    additionalModelResponseFieldPaths: List[str]
+    inferenceConfig: dict
+    messages: Required[List[MessageBlock]]
+    system: List[SystemContentBlock]
+    toolConfig: ToolConfigBlock
+
+
 class GenericStreamingChunk(TypedDict):
     text: Required[str]
     is_finished: Required[bool]
diff --git a/litellm/utils.py b/litellm/utils.py
index 178860094..65a34058b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6401,20 +6401,7 @@ def get_supported_openai_params(
     - None if unmapped
     """
     if custom_llm_provider == "bedrock":
-        if model.startswith("anthropic.claude-3"):
-            return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
-        elif model.startswith("anthropic"):
-            return litellm.AmazonAnthropicConfig().get_supported_openai_params()
-        elif model.startswith("ai21"):
-            return ["max_tokens", "temperature", "top_p", "stream"]
-        elif model.startswith("amazon"):
-            return ["max_tokens", "temperature", "stop", "top_p", "stream"]
-        elif model.startswith("meta"):
-            return ["max_tokens", "temperature", "top_p", "stream"]
-        elif model.startswith("cohere"):
-            return ["stream", "temperature", "max_tokens"]
-        elif model.startswith("mistral"):
-            return ["max_tokens", "temperature", "stop", "top_p", "stream"]
+        return litellm.AmazonConverseConfig().get_supported_openai_params()
     elif custom_llm_provider == "ollama":
         return litellm.OllamaConfig().get_supported_openai_params()
    elif custom_llm_provider == "ollama_chat":
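
The parametrized test above exercises the new path end to end. A minimal usage sketch, assuming AWS credentials are configured in the environment and the account has access to Command R+ on Bedrock:

    import litellm

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]

    response = litellm.completion(
        model="cohere.command-r-plus-v1:0",
        messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
        tools=tools,          # translated to Bedrock toolSpec format via _bedrock_tools_pt
        tool_choice="auto",   # forwarded as toolConfig.toolChoice
        max_tokens=100,       # mapped to the Converse maxTokens inference param
    )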