fix(bedrock_httpx.py): working claude 3 function calling

Krrish Dholakia 2024-06-06 20:12:41 -07:00
parent a76a9b7d11
commit a995a0b172
11 changed files with 321 additions and 108 deletions


@@ -16,11 +16,11 @@ repos:
name: Check if files match
entry: python3 ci_cd/check_files_match.py
language: system
- repo: local
hooks:
- id: mypy
name: mypy
entry: python3 -m mypy --ignore-missing-imports
language: system
types: [python]
files: ^litellm/
# - repo: local
# hooks:
# - id: mypy
# name: mypy
# entry: python3 -m mypy --ignore-missing-imports
# language: system
# types: [python]
# files: ^litellm/
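The mypy pre-commit hook above is commented out rather than deleted. For running the same check by hand, a rough equivalent of the hook's `entry`, driven from Python, might look like the sketch below; the `litellm/` target comes from the hook's `files` pattern, and having mypy installed locally is an assumption.

```python
# Hedged sketch: invoke the same mypy command the disabled hook used.
# Assumes mypy is available in the current Python environment.
import subprocess

result = subprocess.run(
    ["python3", "-m", "mypy", "--ignore-missing-imports", "litellm/"],
    capture_output=True,
    text=True,
)
print(result.stdout or result.stderr)
```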


@@ -47,6 +47,11 @@ import httpx # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
from litellm.types.llms.bedrock import *
import urllib.parse
from litellm.types.llms.openai import (
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
)
class AmazonCohereChatConfig:
@@ -1004,12 +1009,12 @@ class BedrockLLM(BaseLLM):
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
client = client # type: ignore
try:
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
response = await client.post(api_base, headers=headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
@@ -1125,11 +1130,55 @@ class AmazonConverseConfig:
"tool_choice",
]
def map_tool_choice_values(
self, model: str, tool_choice: Union[str, dict], drop_params: bool
) -> Optional[ToolChoiceValuesBlock]:
if not model.startswith("anthropic") and not model.startswith("mistral"):
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
if drop_params == True or litellm.drop_params == True:
return None
else:
raise litellm.utils.UnsupportedParamsError(
message="Only Anthropic and Mistral on Bedrock support 'tool_choice'. To drop it from the call, set `litellm.drop_params = True.`",
status_code=400,
)
if tool_choice == "none":
if litellm.drop_params is True or drop_params is True:
return None
else:
raise litellm.utils.UnsupportedParamsError(
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
elif tool_choice == "required":
return ToolChoiceValuesBlock(any={})
elif tool_choice == "auto":
return ToolChoiceValuesBlock(auto={})
elif isinstance(tool_choice, dict):
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
specific_tool = SpecificToolChoiceBlock(
name=tool_choice.get("function", {}).get("name", "")
)
return ToolChoiceValuesBlock(tool=specific_tool)
else:
raise litellm.utils.UnsupportedParamsError(
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
def get_supported_image_types(self) -> List[str]:
return ["png", "jpeg", "gif", "webp"]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
self,
model: str,
non_default_params: dict,
optional_params: dict,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
@@ -1144,6 +1193,14 @@ class AmazonConverseConfig:
optional_params["temperature"] = value
if param == "top_p":
optional_params["topP"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice_value = self.map_tool_choice_values(
model=model, tool_choice=value, drop_params=drop_params
)
if _tool_choice_value is not None:
optional_params["tool_choice"] = _tool_choice_value
return optional_params
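To make the mapping above concrete, here is a hedged sketch of how each OpenAI `tool_choice` form translates to a Bedrock `ToolChoice` block. The class and method names mirror this diff; the import path is an assumption based on the file being edited.

```python
# Sketch of map_tool_choice_values behavior, assuming the class lives in
# litellm/llms/bedrock_httpx.py as this commit suggests.
from litellm.llms.bedrock_httpx import AmazonConverseConfig

config = AmazonConverseConfig()
model = "anthropic.claude-3-sonnet-20240229-v1:0"

# "auto" -> {"auto": {}}; "required" -> {"any": {}}
print(config.map_tool_choice_values(model=model, tool_choice="auto", drop_params=False))
print(config.map_tool_choice_values(model=model, tool_choice="required", drop_params=False))

# A specific function -> {"tool": {"name": ...}}
print(
    config.map_tool_choice_values(
        model=model,
        tool_choice={"type": "function", "function": {"name": "top_song"}},
        drop_params=False,
    )
)

# Non-Anthropic/Mistral models don't support tool_choice on Bedrock:
# with drop_params=True the value is silently dropped (returns None),
# otherwise UnsupportedParamsError is raised.
print(
    config.map_tool_choice_values(
        model="cohere.command-r-plus-v1:0", tool_choice="auto", drop_params=True
    )
)  # -> None
```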
@@ -1151,6 +1208,124 @@ class BedrockConverseLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> Union[ModelResponse, CustomStreamWrapper]:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
except Exception as e:
raise BedrockError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
response.text, str(e)
),
status_code=422,
)
"""
Bedrock Response Object has optional message block
completion_response["output"].get("message", None)
A message block looks like this (Example 1):
"output": {
"message": {
"role": "assistant",
"content": [
{
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
}
]
}
},
(Example 2):
"output": {
"message": {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
"name": "top_song",
"input": {
"sign": "WZPZ"
}
}
}
]
}
}
"""
message: Optional[MessageBlock] = completion_response["output"]["message"]
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
if message is not None:
for content in message["content"]:
"""
- Content is either a tool response or text
"""
if "text" in content:
content_str += content["text"]
if "toolUse" in content:
_function_chunk = ChatCompletionToolCallFunctionChunk(
name=content["toolUse"]["name"],
arguments=json.dumps(content["toolUse"]["input"]),
)
_tool_response_chunk = ChatCompletionToolCallChunk(
id=content["toolUse"]["toolUseId"],
type="function",
function=_function_chunk,
)
tools.append(_tool_response_chunk)
chat_completion_message["content"] = content_str
chat_completion_message["tool_calls"] = tools
## CALCULATING USAGE - bedrock returns usage in the headers
input_tokens = completion_response["usage"]["inputTokens"]
output_tokens = completion_response["usage"]["outputTokens"]
total_tokens = completion_response["usage"]["totalTokens"]
model_response.choices = [
litellm.Choices(
finish_reason=map_finish_reason(completion_response["stopReason"]),
index=0,
message=litellm.Message(**chat_completion_message),
)
]
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
total_tokens=total_tokens,
)
setattr(model_response, "usage", usage)
return model_response
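The heart of `process_response` is the translation from Bedrock `toolUse` content blocks into OpenAI-style `tool_calls`. The loop isolated as a standalone sketch, fed with the Example 2 payload from the docstring above:

```python
import json

# Example 2 payload from the docstring above.
output = {
    "message": {
        "role": "assistant",
        "content": [
            {
                "toolUse": {
                    "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
                    "name": "top_song",
                    "input": {"sign": "WZPZ"},
                }
            }
        ],
    }
}

content_str, tools = "", []
for block in output["message"]["content"]:
    if "text" in block:
        content_str += block["text"]
    if "toolUse" in block:
        tools.append(
            {
                "id": block["toolUse"]["toolUseId"],
                "type": "function",
                "function": {
                    "name": block["toolUse"]["name"],
                    # OpenAI clients expect arguments as a JSON string, not a dict.
                    "arguments": json.dumps(block["toolUse"]["input"]),
                },
            }
        )

print(tools)
# [{'id': 'tooluse_hbTgdi0CSLq_hM4P8csZJA', 'type': 'function',
#   'function': {'name': 'top_song', 'arguments': '{"sign": "WZPZ"}'}}]
```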
def encode_model_id(self, model_id: str) -> str:
"""
Double encode the model ID to ensure it matches the expected double-encoded format.
@@ -1387,11 +1562,14 @@ class BedrockConverseLLM(BaseLLM):
additional_request_keys = []
additional_request_params = {}
supported_converse_params = AmazonConverseConfig().get_config().keys()
supported_tool_call_params = ["tools", "tool_choice"]
## TRANSFORMATION ##
# send all model-specific params in 'additional_request_params'
for k, v in inference_params.items():
if k not in supported_converse_params:
if (
k not in supported_converse_params
and k not in supported_tool_call_params
):
additional_request_params[k] = v
additional_request_keys.append(k)
for key in additional_request_keys:
@@ -1401,23 +1579,27 @@ class BedrockConverseLLM(BaseLLM):
messages=messages
)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.get("tools", [])
inference_params.pop("tools", [])
)
bedrock_tool_config: Optional[ToolConfigBlock] = None
if len(bedrock_tools) > 0:
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
"tool_choice", None
)
bedrock_tool_config = ToolConfigBlock(
tools=bedrock_tools,
toolChoice=inference_params.get("tool_choice", None),
)
if tool_choice_values is not None:
bedrock_tool_config["toolChoice"] = tool_choice_values
data: RequestObject = {
_data: RequestObject = {
"messages": bedrock_messages,
"additionalModelRequestFields": additional_request_params,
"system": system_content_blocks,
}
if bedrock_tool_config is not None:
data["toolConfig"] = bedrock_tool_config
_data["toolConfig"] = bedrock_tool_config
data = json.dumps(_data)
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
@@ -1441,8 +1623,18 @@ class BedrockConverseLLM(BaseLLM):
)
### ROUTING (ASYNC, STREAMING, SYNC)
### COMPLETION
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = HTTPHandler(**_params) # type: ignore
else:
client = client
try:
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
@@ -1450,6 +1642,20 @@ class BedrockConverseLLM(BaseLLM):
except httpx.TimeoutException as e:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
def get_response_stream_shape():
from botocore.model import ServiceModel
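The sync completion path above no longer posts through `self.client`; it resolves a request-scoped handler instead, reusing a caller-supplied sync client and falling back to a fresh `HTTPHandler` that carries the timeout. Isolated as a helper for clarity (litellm inlines this logic; the import path is an assumption):

```python
from typing import Optional, Union

import httpx

# Import path assumed; litellm keeps these httpx wrappers in a helper module.
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler


def resolve_sync_client(
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    timeout: Optional[Union[float, int, httpx.Timeout]] = None,
) -> HTTPHandler:
    # An async handler can't serve the sync path, so treat it like None.
    if client is None or isinstance(client, AsyncHTTPHandler):
        params = {}
        if timeout is not None:
            if isinstance(timeout, (float, int)):
                timeout = httpx.Timeout(timeout)
            params["timeout"] = timeout
        return HTTPHandler(**params)
    return client
```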


@@ -156,12 +156,13 @@ class HTTPHandler:
self,
url: str,
data: Optional[Union[dict, str]] = None,
json: Optional[Union[dict, str]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
)
response = self.client.send(req, stream=stream)
return response
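With the new `json=` parameter, a caller can hand `post` a plain dict and let httpx serialize it (and set the `Content-Type` header), instead of pre-dumping to a string for `data=`. A hedged usage sketch; the URL is a placeholder:

```python
from litellm.llms.custom_httpx.http_handler import HTTPHandler  # path assumed

client = HTTPHandler()
response = client.post(
    "https://example.com/converse",  # placeholder endpoint
    json={"messages": [{"role": "user", "content": [{"text": "hi"}]}]},
)
print(response.status_code)
```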


@@ -1617,6 +1617,7 @@ from litellm.types.llms.bedrock import (
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
ToolSpecBlock as BedrockToolSpecBlock,
ToolBlock as BedrockToolBlock,
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
)
@@ -1814,7 +1815,7 @@ def _convert_to_bedrock_tool_call_result(
tool_result_content_block = BedrockToolResultContentBlock(text=content)
tool_result = BedrockToolResultBlock(
content=tool_result_content_block,
content=[tool_result_content_block],
toolUseId=id,
)
content_block = BedrockContentBlock(toolResult=tool_result)
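The one-line change above fixes the shape of tool-result messages: in the Converse API, `toolResult.content` is a list of content blocks, not a single block. Side by side (payload values are illustrative):

```python
# Before (rejected by Bedrock): a bare content block.
broken = {
    "toolResult": {
        "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
        "content": {"text": "22 degrees celsius"},
    }
}

# After (matches ToolResultBlock in this commit): a list of content blocks.
fixed = {
    "toolResult": {
        "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
        "content": [{"text": "22 degrees celsius"}],
    }
}
```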


@@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
@@ -121,7 +121,7 @@ azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
####### COMPLETION ENDPOINTS ################
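With `bedrock_chat_completion` now pointing at `BedrockConverseLLM`, Bedrock chat requests go through the Converse API, tool calling included. An end-to-end sketch mirroring the updated test below; the model name is taken from the test, and configured AWS credentials are assumed:

```python
import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="anthropic.claude-3-sonnet-20240229-v1:0",  # from the test below
    messages=[{"role": "user", "content": "What's the weather like in Boston?"}],
    tools=tools,
    tool_choice="auto",
    drop_params=True,  # drop unsupported params instead of raising
)
print(response.choices[0].message.tool_calls)
```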


@@ -302,10 +302,7 @@ def test_completion_claude_3():
@pytest.mark.parametrize(
"model",
[
# "anthropic/claude-3-opus-20240229",
"cohere.command-r-plus-v1:0"
],
["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"],
)
def test_completion_claude_3_function_call(model):
litellm.set_verbose = True
@@ -345,6 +342,7 @@ def test_completion_claude_3_function_call(model):
"type": "function",
"function": {"name": "get_current_weather"},
},
drop_params=True,
)
# Add any assertions, here to check response args
@@ -375,6 +373,7 @@ def test_completion_claude_3_function_call(model):
messages=messages,
tools=tools,
tool_choice="auto",
drop_params=True,
)
print(second_response)
except Exception as e:


@@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import (
claude_2_1_pt,
llama_2_chat_pt,
prompt_factory,
_bedrock_tools_pt,
)
@@ -128,3 +129,29 @@ def test_anthropic_messages_pt():
# codellama_prompt_format()
def test_bedrock_tool_calling_pt():
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
converted_tools = _bedrock_tools_pt(tools=tools)
print(converted_tools)
assert False
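Note the trailing `assert False`: the test is committed as a scaffold that always fails so the printed conversion can be inspected. For reference, a hedged sketch of the `toolSpec` shape `_bedrock_tools_pt` should emit for the tool above, inferred from the `ToolBlock`/`ToolSpecBlock`/`ToolInputSchemaBlock` types in this commit:

```python
# Assumed Converse-style conversion of the OpenAI tool above.
expected_converted_tools = [
    {
        "toolSpec": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "inputSchema": {
                "json": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                }
            },
        }
    }
]
```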


@@ -31,7 +31,7 @@ class ToolResultContentBlock(TypedDict, total=False):
class ToolResultBlock(TypedDict, total=False):
content: Required[ToolResultContentBlock]
content: Required[List[ToolResultContentBlock]]
toolUseId: Required[str]
status: Literal["success", "error"]
@@ -54,6 +54,30 @@ class MessageBlock(TypedDict):
role: Literal["user", "assistant"]
class ConverseMetricsBlock(TypedDict):
latencyMs: float # time in ms
class ConverseResponseOutputBlock(TypedDict):
message: Optional[MessageBlock]
class ConverseTokenUsageBlock(TypedDict):
inputTokens: int
outputTokens: int
totalTokens: int
class ConverseResponseBlock(TypedDict):
additionalModelResponseFields: dict
metrics: ConverseMetricsBlock
output: ConverseResponseOutputBlock
stopReason: (
str # end_turn | tool_use | max_tokens | stop_sequence | content_filtered
)
usage: ConverseTokenUsageBlock
class ToolInputSchemaBlock(TypedDict):
json: Optional[dict]
@@ -72,9 +96,15 @@ class SpecificToolChoiceBlock(TypedDict):
name: str
class ToolChoiceValuesBlock(TypedDict, total=False):
any: dict
auto: dict
tool: SpecificToolChoiceBlock
class ToolConfigBlock(TypedDict, total=False):
tools: Required[List[ToolBlock]]
toolChoice: Union[str, SpecificToolChoiceBlock]
toolChoice: Union[str, ToolChoiceValuesBlock]
class RequestObject(TypedDict, total=False):
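The new `ToolChoiceValuesBlock` turns `toolChoice` into a tagged union of the three Converse forms instead of a raw string. The three valid shapes, using this commit's types:

```python
from litellm.types.llms.bedrock import SpecificToolChoiceBlock, ToolChoiceValuesBlock

auto_choice: ToolChoiceValuesBlock = {"auto": {}}       # model decides
required_choice: ToolChoiceValuesBlock = {"any": {}}    # must call some tool
specific: SpecificToolChoiceBlock = {"name": "top_song"}
forced_choice: ToolChoiceValuesBlock = {"tool": specific}  # must call this tool
```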


@@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False):
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]]
timeout: Optional[float]
class ChatCompletionToolCallFunctionChunk(TypedDict):
name: str
arguments: str
class ChatCompletionToolCallChunk(TypedDict):
id: str
type: Literal["function"]
function: ChatCompletionToolCallFunctionChunk
class ChatCompletionResponseMessage(TypedDict, total=False):
content: Optional[str]
tool_calls: List[ChatCompletionToolCallChunk]
role: Literal["assistant"]
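These TypedDicts give `process_response` a typed target for the assistant message it assembles. A minimal construction, reusing the tool-use id from the Bedrock docstring example earlier in this commit:

```python
import json

from litellm.types.llms.openai import (
    ChatCompletionResponseMessage,
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
)

function_chunk: ChatCompletionToolCallFunctionChunk = {
    "name": "top_song",
    "arguments": json.dumps({"sign": "WZPZ"}),  # arguments travel as a JSON string
}
tool_call: ChatCompletionToolCallChunk = {
    "id": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
    "type": "function",
    "function": function_chunk,
}
message: ChatCompletionResponseMessage = {
    "role": "assistant",
    "content": "",
    "tool_calls": [tool_call],
}
```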


@@ -5618,84 +5618,13 @@ def get_optional_params(
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
if "ai21" in model:
_check_valid_arg(supported_params=supported_params)
# params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
# https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
if max_tokens is not None:
optional_params["maxTokens"] = max_tokens
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["topP"] = top_p
if stream:
optional_params["stream"] = stream
elif "anthropic" in model:
_check_valid_arg(supported_params=supported_params)
# anthropic params on bedrock
# \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif "amazon" in model: # amazon titan llms
_check_valid_arg(supported_params=supported_params)
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
if max_tokens is not None:
optional_params["maxTokenCount"] = max_tokens
if temperature is not None:
optional_params["temperature"] = temperature
if stop is not None:
filtered_stop = _map_and_modify_arg(
{"stop": stop}, provider="bedrock", model=model
)
optional_params["stopSequences"] = filtered_stop["stop"]
if top_p is not None:
optional_params["topP"] = top_p
if stream:
optional_params["stream"] = stream
elif "meta" in model: # amazon / meta llms
_check_valid_arg(supported_params=supported_params)
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
if max_tokens is not None:
optional_params["max_gen_len"] = max_tokens
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if stream:
optional_params["stream"] = stream
elif "cohere" in model: # cohere models on bedrock
_check_valid_arg(supported_params=supported_params)
# handle cohere params
if stream:
optional_params["stream"] = stream
if temperature is not None:
optional_params["temperature"] = temperature
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
elif "mistral" in model:
_check_valid_arg(supported_params=supported_params)
# mistral params on bedrock
# \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}"
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if stop is not None:
optional_params["stop"] = stop
if stream is not None:
optional_params["stream"] = stream
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=drop_params,
)
elif custom_llm_provider == "aleph_alpha":
supported_params = [
"max_tokens",

ruff.toml (new file, +3 lines)

@@ -0,0 +1,3 @@
ignore = ["F405"]
extend-select = ["E501"]
line-length = 120