diff --git a/litellm/main.py b/litellm/main.py index 9f08f9f26c..846eb9f2dd 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -88,6 +88,7 @@ from litellm.utils import ( supports_httpx_timeout, token_counter, validate_chat_completion_messages, + validate_chat_completion_tool_choice, ) from ._logging import verbose_logger @@ -847,6 +848,8 @@ def completion( # type: ignore # noqa: PLR0915 raise ValueError("model param not passed in.") # validate messages messages = validate_chat_completion_messages(messages=messages) + # validate tool_choice + tool_choice = validate_chat_completion_tool_choice(tool_choice=tool_choice) ######### unpacking kwargs ##################### args = locals() api_base = kwargs.get("api_base", None) diff --git a/litellm/utils.py b/litellm/utils.py index 7e9287dc89..1fd79cebae 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6058,6 +6058,34 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]): return messages +def validate_chat_completion_tool_choice( + tool_choice: Optional[Union[dict, str]] +) -> Optional[Union[dict, str]]: + """ + Confirm the tool choice is passed in the OpenAI format. + + Prevents user errors like: https://github.com/BerriAI/litellm/issues/7483 + """ + from litellm.types.llms.openai import ( + ChatCompletionToolChoiceObjectParam, + ChatCompletionToolChoiceStringValues, + ) + + if tool_choice is None: + return tool_choice + elif isinstance(tool_choice, str): + return tool_choice + elif isinstance(tool_choice, dict): + if tool_choice.get("type") is None or tool_choice.get("function") is None: + raise Exception( + f"Invalid tool choice, tool_choice={tool_choice}. Please ensure tool_choice follows the OpenAI spec" + ) + return tool_choice + raise Exception( + f"Invalid tool choice, tool_choice={tool_choice}. Got={type(tool_choice)}. Expecting str, or dict. Please ensure tool_choice follows the OpenAI tool_choice spec" ) + + class ProviderConfigManager: @staticmethod def get_provider_chat_config( # noqa: PLR0915 diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index 76970a7435..d07ca29b1c 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -1085,6 +1085,26 @@ def test_validate_chat_completion_user_messages(messages, expected_bool): validate_chat_completion_user_messages(messages=messages) +@pytest.mark.parametrize( + "tool_choice, expected_bool", + [ + ({"type": "function", "function": {"name": "get_current_weather"}}, True), + ({"type": "tool", "name": "get_current_weather"}, False), + (None, True), + ("auto", True), + ("required", True), + ], +) +def test_validate_chat_completion_tool_choice(tool_choice, expected_bool): + from litellm.utils import validate_chat_completion_tool_choice + + if expected_bool: + validate_chat_completion_tool_choice(tool_choice=tool_choice) + else: + with pytest.raises(Exception): + validate_chat_completion_tool_choice(tool_choice=tool_choice) + + def test_models_by_provider(): """ Make sure all providers from model map are in the valid providers list