From aa1e13f65bc6578ed7ea1ac69ce5ed0ecb9c9591 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 16 Apr 2025 09:01:26 -0700 Subject: [PATCH] add test_function_calling_with_tool_response to base llm tests --- .../prompt_templates/factory.py | 40 ++++++++++++------- litellm/types/llms/bedrock.py | 13 +++++- tests/llm_translation/base_llm_unit_tests.py | 2 + .../test_bedrock_completion.py | 22 ++++++++++ 4 files changed, 62 insertions(+), 15 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index aa5dc0d49a..03e60d0f80 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -15,6 +15,12 @@ from litellm import verbose_logger from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client from litellm.types.llms.anthropic import * from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock +from litellm.types.llms.bedrock import ToolBlock as BedrockToolBlock +from litellm.types.llms.bedrock import ( + ToolInputSchemaBlock as BedrockToolInputSchemaBlock, +) +from litellm.types.llms.bedrock import ToolJsonSchemaBlock as BedrockToolJsonSchemaBlock +from litellm.types.llms.bedrock import ToolSpecBlock as BedrockToolSpecBlock from litellm.types.llms.custom_http import httpxSpecialProvider from litellm.types.llms.ollama import OllamaVisionModelObject from litellm.types.llms.openai import ( @@ -1041,10 +1047,10 @@ def convert_to_gemini_tool_call_invoke( if tool_calls is not None: for tool in tool_calls: if "function" in tool: - gemini_function_call: Optional[ - VertexFunctionCall - ] = _gemini_tool_call_invoke_helper( - function_call_params=tool["function"] + gemini_function_call: Optional[VertexFunctionCall] = ( + _gemini_tool_call_invoke_helper( + function_call_params=tool["function"] + ) ) if gemini_function_call is not None: _parts_list.append( @@ -1139,7 +1145,7 @@ def convert_to_gemini_tool_call_result( def convert_to_anthropic_tool_result( - message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage] + message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage], ) -> AnthropicMessagesToolResultParam: """ OpenAI message with a tool result looks like: @@ -1449,9 +1455,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_content_element[ - "cache_control" - ] = _content_element["cache_control"] + _anthropic_content_element["cache_control"] = ( + _content_element["cache_control"] + ) user_content.append(_anthropic_content_element) elif m.get("type", "") == "text": m = cast(ChatCompletionTextObject, m) @@ -1502,9 +1508,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_content_text_element[ - "cache_control" - ] = _content_element["cache_control"] + _anthropic_content_text_element["cache_control"] = ( + _content_element["cache_control"] + ) user_content.append(_anthropic_content_text_element) @@ -2491,7 +2497,7 @@ def _convert_to_bedrock_tool_call_invoke( def _convert_to_bedrock_tool_call_result( - message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage] + message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage], ) -> BedrockContentBlock: """ OpenAI message with a tool result looks like: @@ -2664,7 +2670,7 @@ def get_user_message_block_or_continue_message( def return_assistant_continue_message( assistant_continue_message: Optional[ Union[str, ChatCompletionAssistantMessage] - ] = None + ] = None, ) -> ChatCompletionAssistantMessage: if assistant_continue_message and isinstance(assistant_continue_message, str): return ChatCompletionAssistantMessage( @@ -3462,7 +3468,13 @@ def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]: for _, value in defs_copy.items(): unpack_defs(value, defs_copy) unpack_defs(parameters, defs_copy) - tool_input_schema = BedrockToolInputSchemaBlock(json=parameters) + tool_input_schema = BedrockToolInputSchemaBlock( + json=BedrockToolJsonSchemaBlock( + type=parameters.get("type", ""), + properties=parameters.get("properties", {}), + required=parameters.get("required", []), + ) + ) tool_spec = BedrockToolSpecBlock( inputSchema=tool_input_schema, name=name, description=description ) diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index fe3f2e1b5f..4acc7dfc08 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -125,8 +125,19 @@ class ConverseResponseBlock(TypedDict): usage: ConverseTokenUsageBlock +class ToolJsonArgsBlock(TypedDict, total=False): + type: str + description: str + + +class ToolJsonSchemaBlock(TypedDict, total=False): + type: Literal["object"] + properties: dict + required: List[str] + + class ToolInputSchemaBlock(TypedDict): - json: Optional[dict] + json: Optional[ToolJsonSchemaBlock] class ToolSpecBlock(TypedDict, total=False): diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py index bd3627f7d4..89f0949425 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -1037,6 +1037,7 @@ class BaseLLMChatTest(ABC): def test_function_calling_with_tool_response(self): from litellm.utils import supports_function_calling from litellm import completion + litellm._turn_on_debug() os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" litellm.model_cost = litellm.get_model_cost_map(url="") @@ -1056,6 +1057,7 @@ class BaseLLMChatTest(ABC): "name": "get_weather", "description": "Get the weather in a city", "parameters": { + "$id": "https://some/internal/name", "type": "object", "properties": { "city": { diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py index d6e8ed4ff8..64808f53cf 100644 --- a/tests/llm_translation/test_bedrock_completion.py +++ b/tests/llm_translation/test_bedrock_completion.py @@ -2226,6 +2226,28 @@ class TestBedrockConverseChatNormal(BaseLLMChatTest): """ pass +class TestBedrockConverseNovaTestSuite(BaseLLMChatTest): + def get_base_completion_call_args(self) -> dict: + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.add_known_models() + return { + "model": "bedrock/us.amazon.nova-lite-v1:0", + "aws_region_name": "us-east-1", + } + + def test_tool_call_no_arguments(self, tool_call_no_arguments): + """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" + pass + + def test_multilingual_requests(self): + """ + Bedrock API raises a 400 BadRequest error when the request contains invalid utf-8 sequences. + + Todo: if litellm.modify_params is True ensure it's a valid utf-8 sequence + """ + pass + class TestBedrockRerank(BaseLLMRerankTest): def get_custom_llm_provider(self) -> litellm.LlmProviders: