feat(bedrock_httpx.py): add support for bedrock converse api

closes https://github.com/BerriAI/litellm/issues/4000
This commit is contained in:
Krrish Dholakia 2024-06-05 21:20:36 -07:00
parent 0d3e52373c
commit a76a9b7d11
6 changed files with 846 additions and 31 deletions

View file

@ -765,7 +765,7 @@ from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
from .llms.bedrock import ( from .llms.bedrock import (
AmazonTitanConfig, AmazonTitanConfig,
AmazonAI21Config, AmazonAI21Config,

View file

@ -38,6 +38,8 @@ from .prompt_templates.factory import (
extract_between_tags, extract_between_tags,
parse_xml_params, parse_xml_params,
contains_tag, contains_tag,
_bedrock_converse_messages_pt,
_bedrock_tools_pt,
) )
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM from .base import BaseLLM
@ -118,6 +120,8 @@ class AmazonCohereChatConfig:
"presence_penalty", "presence_penalty",
"seed", "seed",
"stop", "stop",
"tools",
"tool_choice",
] ]
def map_openai_params( def map_openai_params(
@ -1069,6 +1073,384 @@ class BedrockLLM(BaseLLM):
return super().embedding(*args, **kwargs) return super().embedding(*args, **kwargs)
class AmazonConverseConfig:
"""
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
"""
maxTokens: Optional[int]
stopSequences: Optional[List[str]]
temperature: Optional[int]
topP: Optional[int]
def __init__(
self,
maxTokens: Optional[int] = None,
stopSequences: Optional[List[str]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self) -> List[str]:
return [
"max_tokens",
"stream",
"stream_options",
"stop",
"temperature",
"top_p",
"tools",
"tool_choice",
]
def get_supported_image_types(self) -> List[str]:
return ["png", "jpeg", "gif", "webp"]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["maxTokens"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
value = [value]
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["topP"] = value
return optional_params
class BedrockConverseLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def encode_model_id(self, model_id: str) -> str:
"""
Double encode the model ID to ensure it matches the expected double-encoded format.
Args:
model_id (str): The model ID to encode.
Returns:
str: The double-encoded model ID.
"""
return urllib.parse.quote(model_id, safe="")
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
## CHECK IS 'os.environ/' passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
) = params_to_check
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
session = boto3.Session(
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
return session.get_credentials()
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
# Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"]
from botocore.credentials import Credentials
credentials = Credentials(
access_key=sts_credentials["AccessKeyId"],
secret_key=sts_credentials["SecretAccessKey"],
token=sts_credentials["SessionToken"],
)
return credentials
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
def completion(
self,
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
):
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError as e:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## SETUP ##
stream = optional_params.pop("stream", None)
modelId = optional_params.pop("model_id", None)
if modelId is not None:
modelId = self.encode_model_id(model_id=modelId)
else:
modelId = model
provider = model.split(".")[0]
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
)
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
if (stream is not None and stream == True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
else:
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
# Separate system prompt from rest of message
system_prompt_indices = []
system_content_blocks: List[SystemContentBlock] = []
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = SystemContentBlock(text=message["content"])
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
inference_params = copy.deepcopy(optional_params)
additional_request_keys = []
additional_request_params = {}
supported_converse_params = AmazonConverseConfig().get_config().keys()
## TRANSFORMATION ##
# send all model-specific params in 'additional_request_params'
for k, v in inference_params.items():
if k not in supported_converse_params:
additional_request_params[k] = v
additional_request_keys.append(k)
for key in additional_request_keys:
inference_params.pop(key, None)
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages
)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.get("tools", [])
)
bedrock_tool_config: Optional[ToolConfigBlock] = None
if len(bedrock_tools) > 0:
bedrock_tool_config = ToolConfigBlock(
tools=bedrock_tools,
toolChoice=inference_params.get("tool_choice", None),
)
data: RequestObject = {
"messages": bedrock_messages,
"additionalModelRequestFields": additional_request_params,
"system": system_content_blocks,
}
if bedrock_tool_config is not None:
data["toolConfig"] = bedrock_tool_config
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepped.url,
"headers": prepped.headers,
},
)
### ROUTING (ASYNC, STREAMING, SYNC)
try:
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
except httpx.TimeoutException as e:
raise BedrockError(status_code=408, message="Timeout error occurred.")
def get_response_stream_shape(): def get_response_stream_shape():
from botocore.model import ServiceModel from botocore.model import ServiceModel
from botocore.loaders import Loader from botocore.loaders import Loader

View file

@ -3,14 +3,7 @@ import requests, traceback
import json, re, xml.etree.ElementTree as ET import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, meta, BaseLoader from jinja2 import Template, exceptions, meta, BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2.sandbox import ImmutableSandboxedEnvironment
from typing import ( from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
Any,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
)
import litellm import litellm
import litellm.types import litellm.types
from litellm.types.completion import ( from litellm.types.completion import (
@ -24,7 +17,7 @@ from litellm.types.completion import (
import litellm.types.llms import litellm.types.llms
from litellm.types.llms.anthropic import * from litellm.types.llms.anthropic import *
import uuid import uuid
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
import litellm.types.llms.vertex_ai import litellm.types.llms.vertex_ai
@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url):
try: try:
from PIL import Image from PIL import Image
except: except:
raise Exception( raise Exception("image conversion failed please run `pip install Pillow`")
"gemini image conversion failed please run `pip install Pillow`"
)
from io import BytesIO from io import BytesIO
try: try:
@ -1613,6 +1604,379 @@ def azure_text_pt(messages: list):
return prompt return prompt
###### AMAZON BEDROCK #######
from litellm.types.llms.bedrock import (
ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultBlock as BedrockToolResultBlock,
ToolConfigBlock as BedrockToolConfigBlock,
ToolUseBlock as BedrockToolUseBlock,
ImageSourceBlock as BedrockImageSourceBlock,
ImageBlock as BedrockImageBlock,
ContentBlock as BedrockContentBlock,
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
ToolSpecBlock as BedrockToolSpecBlock,
ToolBlock as BedrockToolBlock,
)
def get_image_details(image_url) -> Tuple[bytes, str]:
try:
import base64
# Send a GET request to the image URL
response = requests.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors
# Check the response's content type to ensure it is an image
content_type = response.headers.get("content-type")
if not content_type or "image" not in content_type:
raise ValueError(
f"URL does not point to a valid image (content-type: {content_type})"
)
# Convert the image content to base64 bytes
base64_bytes = base64.b64encode(response.content)
# Get mime-type
mime_type = content_type.split("/")[
1
] # Extract mime-type from content-type header
return base64_bytes, mime_type
except requests.RequestException as e:
raise Exception(f"Request failed: {e}")
except Exception as e:
raise e
def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
if "base64" in image_url:
# Case 1: Images with base64 encoding
import base64, re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
image_metadata, img_without_base_64 = image_url.split(",")
image_format = image_metadata.split("/")[1]
# read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match:
mime_type = mime_type_match.group(1)
else:
mime_type = "jpeg"
decoded_img = base64.b64decode(img_without_base_64)
_blob = BedrockImageSourceBlock(bytes=decoded_img)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
elif "https:/" in image_url:
# Case 2: Images with direct links
image_bytes, image_format = get_image_details(image_url)
_blob = BedrockImageSourceBlock(bytes=image_bytes)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
else:
raise ValueError(
"Unsupported image type. Expected either image url or base64 encoded string"
)
def _convert_to_bedrock_tool_call_invoke(
tool_calls: list,
) -> List[BedrockContentBlock]:
"""
OpenAI tool invokes:
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\n\"location\": \"Boston, MA\"\n}"
}
}
]
},
"""
"""
Bedrock tool invokes:
[
{
"role": "assistant",
"toolUse": {
"input": {"location": "Boston, MA", ..},
"name": "get_current_weather",
"toolUseId": "call_abc123"
}
}
]
"""
"""
- json.loads argument
- extract name
- extract id
"""
try:
_parts_list: List[BedrockContentBlock] = []
for tool in tool_calls:
if "function" in tool:
id = tool["id"]
name = tool["function"].get("name", "")
arguments = tool["function"].get("arguments", "")
arguments_dict = json.loads(arguments)
bedrock_tool = BedrockToolUseBlock(
input=arguments_dict, name=name, toolUseId=id
)
bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
_parts_list.append(bedrock_content_block)
return _parts_list
except Exception as e:
raise Exception(
"Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format(
tool_calls, str(e)
)
)
def _convert_to_bedrock_tool_call_result(
message: dict,
) -> BedrockMessageBlock:
"""
OpenAI message with a tool result looks like:
{
"tool_call_id": "tool_1",
"role": "tool",
"name": "get_current_weather",
"content": "function result goes here",
},
OpenAI message with a function call result looks like:
{
"role": "function",
"name": "get_current_weather",
"content": "function result goes here",
}
"""
"""
Bedrock result looks like this:
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q",
"content": [
{
"json": {
"song": "Elemental Hotel",
"artist": "8 Storey Hike"
}
}
]
}
}
]
}
"""
"""
-
"""
content = message.get("content", "")
name = message.get("name", "")
id = message.get("tool_call_id", str(uuid.uuid4()))
tool_result_content_block = BedrockToolResultContentBlock(text=content)
tool_result = BedrockToolResultBlock(
content=tool_result_content_block,
toolUseId=id,
)
content_block = BedrockContentBlock(toolResult=tool_result)
return BedrockMessageBlock(role="user", content=[content_block])
def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]:
"""
Converts given messages from OpenAI format to Bedrock format
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
- Please ensure that function response turn comes immediately after a function call turn
"""
contents: List[BedrockMessageBlock] = []
msg_i = 0
while msg_i < len(messages):
user_content: List[BedrockContentBlock] = []
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "user":
if isinstance(messages[msg_i]["content"], list):
_parts: List[BedrockContentBlock] = []
for element in messages[msg_i]["content"]:
if isinstance(element, dict):
if element["type"] == "text":
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
_parts.append(BedrockContentBlock(image=_part)) # type: ignore
user_content.extend(_parts)
else:
_part = BedrockContentBlock(text=messages[msg_i]["content"])
user_content.append(_part)
msg_i += 1
if user_content:
contents.append(BedrockMessageBlock(role="user", content=user_content))
assistant_content: List[BedrockContentBlock] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
if isinstance(messages[msg_i]["content"], list):
assistants_parts: List[BedrockContentBlock] = []
for element in messages[msg_i]["content"]:
if isinstance(element, dict):
if element["type"] == "text":
assistants_part = BedrockContentBlock(text=element["text"])
assistants_parts.append(assistants_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
assistants_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
assistants_parts.append(
BedrockContentBlock(image=assistants_part) # type: ignore
)
assistant_content.extend(assistants_parts)
elif messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke convertion
assistant_content.extend(
_convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
)
else:
assistant_text = (
messages[msg_i].get("content") or ""
) # either string or none
if assistant_text:
assistant_content.append(BedrockContentBlock(text=assistant_text))
msg_i += 1
if assistant_content:
contents.append(
BedrockMessageBlock(role="assistant", content=assistant_content)
)
## APPEND TOOL CALL MESSAGES ##
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
contents.append(tool_call_result)
msg_i += 1
if msg_i == init_msg_i: # prevent infinite loops
raise Exception(
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
messages[msg_i]
)
)
return contents
def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
"""
OpenAI tools looks like:
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
}
]
"""
"""
Bedrock toolConfig looks like:
"tools": [
{
"toolSpec": {
"name": "top_song",
"description": "Get the most popular song played on a radio station.",
"inputSchema": {
"json": {
"type": "object",
"properties": {
"sign": {
"type": "string",
"description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP."
}
},
"required": [
"sign"
]
}
}
}
}
]
"""
tool_block_list: List[BedrockToolBlock] = []
for tool in tools:
parameters = tool.get("function", {}).get("parameters", None)
name = tool.get("function", {}).get("name", "")
description = tool.get("function", {}).get("description", "")
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
tool_spec = BedrockToolSpecBlock(
inputSchema=tool_input_schema, name=name, description=description
)
tool_block = BedrockToolBlock(toolSpec=tool_spec)
tool_block_list.append(tool_block)
return tool_block_list
# Function call template # Function call template
def function_call_prompt(messages: list, functions: list): def function_call_prompt(messages: list, functions: list):
function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:""" function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""

View file

@ -300,7 +300,14 @@ def test_completion_claude_3():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_completion_claude_3_function_call(): @pytest.mark.parametrize(
"model",
[
# "anthropic/claude-3-opus-20240229",
"cohere.command-r-plus-v1:0"
],
)
def test_completion_claude_3_function_call(model):
litellm.set_verbose = True litellm.set_verbose = True
tools = [ tools = [
{ {
@ -331,7 +338,7 @@ def test_completion_claude_3_function_call():
try: try:
# test without max tokens # test without max tokens
response = completion( response = completion(
model="anthropic/claude-3-opus-20240229", model=model,
messages=messages, messages=messages,
tools=tools, tools=tools,
tool_choice={ tool_choice={
@ -364,7 +371,7 @@ def test_completion_claude_3_function_call():
) )
# In the second response, Claude should deduce answer from tool results # In the second response, Claude should deduce answer from tool results
second_response = completion( second_response = completion(
model="anthropic/claude-3-opus-20240229", model=model,
messages=messages, messages=messages,
tools=tools, tools=tools,
tool_choice="auto", tool_choice="auto",

View file

@ -1,4 +1,4 @@
from typing import TypedDict, Any, Union, Optional from typing import TypedDict, Any, Union, Optional, Literal, List
import json import json
from typing_extensions import ( from typing_extensions import (
Self, Self,
@ -11,6 +11,81 @@ from typing_extensions import (
) )
class SystemContentBlock(TypedDict):
text: str
class ImageSourceBlock(TypedDict):
bytes: Optional[bytes]
class ImageBlock(TypedDict):
format: Literal["png", "jpeg", "gif", "webp"]
source: ImageSourceBlock
class ToolResultContentBlock(TypedDict, total=False):
image: ImageBlock
json: dict
text: str
class ToolResultBlock(TypedDict, total=False):
content: Required[ToolResultContentBlock]
toolUseId: Required[str]
status: Literal["success", "error"]
class ToolUseBlock(TypedDict):
input: dict
name: str
toolUseId: str
class ContentBlock(TypedDict, total=False):
text: str
image: ImageBlock
toolResult: ToolResultBlock
toolUse: ToolUseBlock
class MessageBlock(TypedDict):
content: List[ContentBlock]
role: Literal["user", "assistant"]
class ToolInputSchemaBlock(TypedDict):
json: Optional[dict]
class ToolSpecBlock(TypedDict, total=False):
inputSchema: Required[ToolInputSchemaBlock]
name: Required[str]
description: str
class ToolBlock(TypedDict):
toolSpec: Optional[ToolSpecBlock]
class SpecificToolChoiceBlock(TypedDict):
name: str
class ToolConfigBlock(TypedDict, total=False):
tools: Required[List[ToolBlock]]
toolChoice: Union[str, SpecificToolChoiceBlock]
class RequestObject(TypedDict, total=False):
additionalModelRequestFields: dict
additionalModelResponseFieldPaths: List[str]
inferenceConfig: dict
messages: Required[List[MessageBlock]]
system: List[SystemContentBlock]
toolConfig: ToolConfigBlock
class GenericStreamingChunk(TypedDict): class GenericStreamingChunk(TypedDict):
text: Required[str] text: Required[str]
is_finished: Required[bool] is_finished: Required[bool]

View file

@ -6401,20 +6401,7 @@ def get_supported_openai_params(
- None if unmapped - None if unmapped
""" """
if custom_llm_provider == "bedrock": if custom_llm_provider == "bedrock":
if model.startswith("anthropic.claude-3"): return litellm.AmazonConverseConfig().get_supported_openai_params()
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
elif model.startswith("anthropic"):
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
elif model.startswith("ai21"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("amazon"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
elif model.startswith("meta"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("cohere"):
return ["stream", "temperature", "max_tokens"]
elif model.startswith("mistral"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
elif custom_llm_provider == "ollama": elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_supported_openai_params() return litellm.OllamaConfig().get_supported_openai_params()
elif custom_llm_provider == "ollama_chat": elif custom_llm_provider == "ollama_chat":