From a995a0b172bf4cfa03fa5d0e4e6028aee01c33ec Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 6 Jun 2024 20:12:41 -0700 Subject: [PATCH] fix(bedrock_httpx.py): working claude 3 function calling --- .pre-commit-config.yaml | 16 +- litellm/llms/bedrock_httpx.py | 230 ++++++++++++++++++++-- litellm/llms/custom_httpx/http_handler.py | 3 +- litellm/llms/prompt_templates/factory.py | 3 +- litellm/main.py | 4 +- litellm/tests/test_completion.py | 7 +- litellm/tests/test_prompt_factory.py | 27 +++ litellm/types/llms/bedrock.py | 34 +++- litellm/types/llms/openai.py | 17 ++ litellm/utils.py | 85 +------- ruff.toml | 3 + 11 files changed, 321 insertions(+), 108 deletions(-) create mode 100644 ruff.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc41d85f1..e8bb1ff66 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,11 +16,11 @@ repos: name: Check if files match entry: python3 ci_cd/check_files_match.py language: system -- repo: local - hooks: - - id: mypy - name: mypy - entry: python3 -m mypy --ignore-missing-imports - language: system - types: [python] - files: ^litellm/ \ No newline at end of file +# - repo: local +# hooks: +# - id: mypy +# name: mypy +# entry: python3 -m mypy --ignore-missing-imports +# language: system +# types: [python] +# files: ^litellm/ \ No newline at end of file diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index e21265006..ce6a93174 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -47,6 +47,11 @@ import httpx # type: ignore from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator from litellm.types.llms.bedrock import * import urllib.parse +from litellm.types.llms.openai import ( + ChatCompletionResponseMessage, + ChatCompletionToolCallChunk, + ChatCompletionToolCallFunctionChunk, +) class AmazonCohereChatConfig: @@ -1004,12 +1009,12 @@ class BedrockLLM(BaseLLM): if isinstance(timeout, float) or isinstance(timeout, int): timeout = httpx.Timeout(timeout) _params["timeout"] = timeout - self.client = AsyncHTTPHandler(**_params) # type: ignore + client = AsyncHTTPHandler(**_params) # type: ignore else: - self.client = client # type: ignore + client = client # type: ignore try: - response = await self.client.post(api_base, headers=headers, data=data) # type: ignore + response = await client.post(api_base, headers=headers, data=data) # type: ignore response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code @@ -1125,11 +1130,55 @@ class AmazonConverseConfig: "tool_choice", ] + def map_tool_choice_values( + self, model: str, tool_choice: Union[str, dict], drop_params: bool + ) -> Optional[ToolChoiceValuesBlock]: + if not model.startswith("anthropic") and not model.startswith("mistral"): + # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + if drop_params == True or litellm.drop_params == True: + return None + else: + raise litellm.utils.UnsupportedParamsError( + message="Only Anthropic and Mistral on Bedrock support 'tool_choice'. To drop it from the call, set `litellm.drop_params = True.`", + status_code=400, + ) + if tool_choice == "none": + if litellm.drop_params is True or drop_params is True: + return None + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support tool_choice={}. 
To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + elif tool_choice == "required": + return ToolChoiceValuesBlock(any={}) + elif tool_choice == "auto": + return ToolChoiceValuesBlock(auto={}) + elif isinstance(tool_choice, dict): + # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + specific_tool = SpecificToolChoiceBlock( + name=tool_choice.get("function", {}).get("name", "") + ) + return ToolChoiceValuesBlock(tool=specific_tool) + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + def get_supported_image_types(self) -> List[str]: return ["png", "jpeg", "gif", "webp"] def map_openai_params( - self, non_default_params: dict, optional_params: dict + self, + model: str, + non_default_params: dict, + optional_params: dict, + drop_params: bool, ) -> dict: for param, value in non_default_params.items(): if param == "max_tokens": @@ -1144,6 +1193,14 @@ class AmazonConverseConfig: optional_params["temperature"] = value if param == "top_p": optional_params["topP"] = value + if param == "tools": + optional_params["tools"] = value + if param == "tool_choice": + _tool_choice_value = self.map_tool_choice_values( + model=model, tool_choice=value, drop_params=drop_params + ) + if _tool_choice_value is not None: + optional_params["tool_choice"] = _tool_choice_value return optional_params @@ -1151,6 +1208,124 @@ class BedrockConverseLLM(BaseLLM): def __init__(self) -> None: super().__init__() + def process_response( + self, + model: str, + response: Union[requests.Response, httpx.Response], + model_response: ModelResponse, + stream: bool, + logging_obj: Logging, + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: List, + print_verbose, + encoding, + ) -> Union[ModelResponse, CustomStreamWrapper]: + + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + + ## RESPONSE OBJECT + try: + completion_response = ConverseResponseBlock(**response.json()) # type: ignore + except Exception as e: + raise BedrockError( + message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( + response.text, str(e) + ), + status_code=422, + ) + + """ + Bedrock Response Object has optional message block + + completion_response["output"].get("message", None) + + A message block looks like this (Example 1): + "output": { + "message": { + "role": "assistant", + "content": [ + { + "text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?" 
+ } + ] + } + }, + (Example 2): + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA", + "name": "top_song", + "input": { + "sign": "WZPZ" + } + } + } + ] + } + } + + """ + message: Optional[MessageBlock] = completion_response["output"]["message"] + chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"} + content_str = "" + tools: List[ChatCompletionToolCallChunk] = [] + if message is not None: + for content in message["content"]: + """ + - Content is either a tool response or text + """ + if "text" in content: + content_str += content["text"] + if "toolUse" in content: + _function_chunk = ChatCompletionToolCallFunctionChunk( + name=content["toolUse"]["name"], + arguments=json.dumps(content["toolUse"]["input"]), + ) + _tool_response_chunk = ChatCompletionToolCallChunk( + id=content["toolUse"]["toolUseId"], + type="function", + function=_function_chunk, + ) + tools.append(_tool_response_chunk) + chat_completion_message["content"] = content_str + chat_completion_message["tool_calls"] = tools + + ## CALCULATING USAGE - bedrock returns usage in the headers + input_tokens = completion_response["usage"]["inputTokens"] + output_tokens = completion_response["usage"]["outputTokens"] + total_tokens = completion_response["usage"]["totalTokens"] + + model_response.choices = [ + litellm.Choices( + finish_reason=map_finish_reason(completion_response["stopReason"]), + index=0, + message=litellm.Message(**chat_completion_message), + ) + ] + model_response["created"] = int(time.time()) + model_response["model"] = model + usage = Usage( + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + total_tokens=total_tokens, + ) + setattr(model_response, "usage", usage) + + return model_response + def encode_model_id(self, model_id: str) -> str: """ Double encode the model ID to ensure it matches the expected double-encoded format. 
@@ -1387,11 +1562,14 @@ class BedrockConverseLLM(BaseLLM): additional_request_keys = [] additional_request_params = {} supported_converse_params = AmazonConverseConfig().get_config().keys() - + supported_tool_call_params = ["tools", "tool_choice"] ## TRANSFORMATION ## # send all model-specific params in 'additional_request_params' for k, v in inference_params.items(): - if k not in supported_converse_params: + if ( + k not in supported_converse_params + and k not in supported_tool_call_params + ): additional_request_params[k] = v additional_request_keys.append(k) for key in additional_request_keys: @@ -1401,23 +1579,27 @@ class BedrockConverseLLM(BaseLLM): messages=messages ) bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( - inference_params.get("tools", []) + inference_params.pop("tools", []) ) bedrock_tool_config: Optional[ToolConfigBlock] = None if len(bedrock_tools) > 0: + tool_choice_values: ToolChoiceValuesBlock = inference_params.pop( + "tool_choice", None + ) bedrock_tool_config = ToolConfigBlock( tools=bedrock_tools, - toolChoice=inference_params.get("tool_choice", None), ) + if tool_choice_values is not None: + bedrock_tool_config["toolChoice"] = tool_choice_values - data: RequestObject = { + _data: RequestObject = { "messages": bedrock_messages, "additionalModelRequestFields": additional_request_params, "system": system_content_blocks, } if bedrock_tool_config is not None: - data["toolConfig"] = bedrock_tool_config - + _data["toolConfig"] = bedrock_tool_config + data = json.dumps(_data) ## COMPLETION CALL headers = {"Content-Type": "application/json"} @@ -1441,8 +1623,18 @@ class BedrockConverseLLM(BaseLLM): ) ### ROUTING (ASYNC, STREAMING, SYNC) + ### COMPLETION + if client is None or isinstance(client, AsyncHTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = HTTPHandler(**_params) # type: ignore + else: + client = client try: - response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore + response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code @@ -1450,6 +1642,20 @@ class BedrockConverseLLM(BaseLLM): except httpx.TimeoutException as e: raise BedrockError(status_code=408, message="Timeout error occurred.") + return self.process_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + optional_params=optional_params, + api_key="", + data=data, + messages=messages, + print_verbose=print_verbose, + encoding=encoding, + ) + def get_response_stream_shape(): from botocore.model import ServiceModel diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index b91aaee2a..5ec9c79bb 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -156,12 +156,13 @@ class HTTPHandler: self, url: str, data: Optional[Union[dict, str]] = None, + json: Optional[Union[dict, str]] = None, params: Optional[dict] = None, headers: Optional[dict] = None, stream: bool = False, ): req = self.client.build_request( - "POST", url, data=data, params=params, headers=headers # type: ignore + "POST", url, data=data, json=json, params=params, headers=headers # type: ignore ) response = self.client.send(req, stream=stream) return response diff 
--git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index d5ef69687..ddd0e1909 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -1617,6 +1617,7 @@ from litellm.types.llms.bedrock import ( ToolInputSchemaBlock as BedrockToolInputSchemaBlock, ToolSpecBlock as BedrockToolSpecBlock, ToolBlock as BedrockToolBlock, + ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock, ) @@ -1814,7 +1815,7 @@ def _convert_to_bedrock_tool_call_result( tool_result_content_block = BedrockToolResultContentBlock(text=content) tool_result = BedrockToolResultBlock( - content=tool_result_content_block, + content=[tool_result_content_block], toolUseId=id, ) content_block = BedrockContentBlock(toolResult=tool_result) diff --git a/litellm/main.py b/litellm/main.py index f76d6c521..c95b419ba 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion from .llms.anthropic_text import AnthropicTextCompletion from .llms.huggingface_restapi import Huggingface from .llms.predibase import PredibaseChatCompletion -from .llms.bedrock_httpx import BedrockLLM +from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM from .llms.vertex_httpx import VertexLLM from .llms.triton import TritonChatCompletion from .llms.prompt_templates.factory import ( @@ -121,7 +121,7 @@ azure_text_completions = AzureTextCompletion() huggingface = Huggingface() predibase_chat_completions = PredibaseChatCompletion() triton_chat_completions = TritonChatCompletion() -bedrock_chat_completion = BedrockLLM() +bedrock_chat_completion = BedrockConverseLLM() vertex_chat_completion = VertexLLM() ####### COMPLETION ENDPOINTS ################ diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 1befa1392..bcbe4944c 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -302,10 +302,7 @@ def test_completion_claude_3(): @pytest.mark.parametrize( "model", - [ - # "anthropic/claude-3-opus-20240229", - "cohere.command-r-plus-v1:0" - ], + ["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"], ) def test_completion_claude_3_function_call(model): litellm.set_verbose = True @@ -345,6 +342,7 @@ def test_completion_claude_3_function_call(model): "type": "function", "function": {"name": "get_current_weather"}, }, + drop_params=True, ) # Add any assertions, here to check response args @@ -375,6 +373,7 @@ def test_completion_claude_3_function_call(model): messages=messages, tools=tools, tool_choice="auto", + drop_params=True, ) print(second_response) except Exception as e: diff --git a/litellm/tests/test_prompt_factory.py b/litellm/tests/test_prompt_factory.py index 2fc04ec52..9f112a0b1 100644 --- a/litellm/tests/test_prompt_factory.py +++ b/litellm/tests/test_prompt_factory.py @@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import ( claude_2_1_pt, llama_2_chat_pt, prompt_factory, + _bedrock_tools_pt, ) @@ -128,3 +129,29 @@ def test_anthropic_messages_pt(): # codellama_prompt_format() +def test_bedrock_tool_calling_pt(): + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + converted_tools = _bedrock_tools_pt(tools=tools) + + print(converted_tools) + + assert False diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index 9333ea1b9..647dc1d7b 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -31,7 +31,7 @@ class ToolResultContentBlock(TypedDict, total=False): class ToolResultBlock(TypedDict, total=False): - content: Required[ToolResultContentBlock] + content: Required[List[ToolResultContentBlock]] toolUseId: Required[str] status: Literal["success", "error"] @@ -54,6 +54,30 @@ class MessageBlock(TypedDict): role: Literal["user", "assistant"] +class ConverseMetricsBlock(TypedDict): + latencyMs: float # time in ms + + +class ConverseResponseOutputBlock(TypedDict): + message: Optional[MessageBlock] + + +class ConverseTokenUsageBlock(TypedDict): + inputTokens: int + outputTokens: int + totalTokens: int + + +class ConverseResponseBlock(TypedDict): + additionalModelResponseFields: dict + metrics: ConverseMetricsBlock + output: ConverseResponseOutputBlock + stopReason: ( + str # end_turn | tool_use | max_tokens | stop_sequence | content_filtered + ) + usage: ConverseTokenUsageBlock + + class ToolInputSchemaBlock(TypedDict): json: Optional[dict] @@ -72,9 +96,15 @@ class SpecificToolChoiceBlock(TypedDict): name: str +class ToolChoiceValuesBlock(TypedDict, total=False): + any: dict + auto: dict + tool: SpecificToolChoiceBlock + + class ToolConfigBlock(TypedDict, total=False): tools: Required[List[ToolBlock]] - toolChoice: Union[str, SpecificToolChoiceBlock] + toolChoice: Union[str, ToolChoiceValuesBlock] class RequestObject(TypedDict, total=False): diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index bc0c82434..7861e394c 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False): extra_headers: Optional[Dict[str, str]] extra_body: Optional[Dict[str, str]] timeout: Optional[float] + + +class ChatCompletionToolCallFunctionChunk(TypedDict): + name: str + arguments: str + + +class ChatCompletionToolCallChunk(TypedDict): + id: str + type: Literal["function"] + function: ChatCompletionToolCallFunctionChunk + + +class ChatCompletionResponseMessage(TypedDict, total=False): + content: Optional[str] + tool_calls: List[ChatCompletionToolCallChunk] + role: Literal["assistant"] diff --git a/litellm/utils.py b/litellm/utils.py index 65a34058b..6db5f540c 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5618,84 +5618,13 @@ def get_optional_params( supported_params = get_supported_openai_params( model=model, custom_llm_provider=custom_llm_provider ) - if "ai21" in model: - _check_valid_arg(supported_params=supported_params) - # params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[], - # https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra - if max_tokens is not None: - optional_params["maxTokens"] = max_tokens - if temperature is not None: - optional_params["temperature"] = temperature - if top_p is not None: - optional_params["topP"] = top_p - if stream: - optional_params["stream"] = stream - elif "anthropic" in model: - _check_valid_arg(supported_params=supported_params) - # anthropic params on bedrock - # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}" - if 
model.startswith("anthropic.claude-3"): - optional_params = ( - litellm.AmazonAnthropicClaude3Config().map_openai_params( - non_default_params=non_default_params, - optional_params=optional_params, - ) - ) - else: - optional_params = litellm.AmazonAnthropicConfig().map_openai_params( - non_default_params=non_default_params, - optional_params=optional_params, - ) - elif "amazon" in model: # amazon titan llms - _check_valid_arg(supported_params=supported_params) - # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large - if max_tokens is not None: - optional_params["maxTokenCount"] = max_tokens - if temperature is not None: - optional_params["temperature"] = temperature - if stop is not None: - filtered_stop = _map_and_modify_arg( - {"stop": stop}, provider="bedrock", model=model - ) - optional_params["stopSequences"] = filtered_stop["stop"] - if top_p is not None: - optional_params["topP"] = top_p - if stream: - optional_params["stream"] = stream - elif "meta" in model: # amazon / meta llms - _check_valid_arg(supported_params=supported_params) - # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large - if max_tokens is not None: - optional_params["max_gen_len"] = max_tokens - if temperature is not None: - optional_params["temperature"] = temperature - if top_p is not None: - optional_params["top_p"] = top_p - if stream: - optional_params["stream"] = stream - elif "cohere" in model: # cohere models on bedrock - _check_valid_arg(supported_params=supported_params) - # handle cohere params - if stream: - optional_params["stream"] = stream - if temperature is not None: - optional_params["temperature"] = temperature - if max_tokens is not None: - optional_params["max_tokens"] = max_tokens - elif "mistral" in model: - _check_valid_arg(supported_params=supported_params) - # mistral params on bedrock - # \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}" - if max_tokens is not None: - optional_params["max_tokens"] = max_tokens - if temperature is not None: - optional_params["temperature"] = temperature - if top_p is not None: - optional_params["top_p"] = top_p - if stop is not None: - optional_params["stop"] = stop - if stream is not None: - optional_params["stream"] = stream + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.AmazonConverseConfig().map_openai_params( + model=model, + non_default_params=non_default_params, + optional_params=optional_params, + drop_params=drop_params, + ) elif custom_llm_provider == "aleph_alpha": supported_params = [ "max_tokens", diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..dfb323c1b --- /dev/null +++ b/ruff.toml @@ -0,0 +1,3 @@ +ignore = ["F405"] +extend-select = ["E501"] +line-length = 120