Litellm dev 01 01 2025 p3 (#7503)

* fix(utils.py): add new validate tool choice helper function

Prevents https://github.com/BerriAI/litellm/issues/7483

* fix(main.py): add tool choice validation on .completion()

Prevents user errors like https://github.com/BerriAI/litellm/issues/7483

* fix(utils.py): fix return value of tool choice validation logic
Krish Dholakia, 2025-01-01 22:12:15 -08:00 (committed by GitHub)
parent 98cba7ba3f
commit b3611ace41
3 changed files with 51 additions and 0 deletions
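For context: judging from the test cases added below, the user error this guards against is passing a tool_choice dict in a non-OpenAI shape (e.g. the Anthropic style). A minimal sketch of the two shapes, reusing the get_current_weather name from the tests:

# OpenAI shape -- accepted by the new validator
tool_choice = {"type": "function", "function": {"name": "get_current_weather"}}

# Anthropic-style shape -- now rejected up front, before any provider call
tool_choice = {"type": "tool", "name": "get_current_weather"}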

litellm/main.py

@@ -88,6 +88,7 @@ from litellm.utils import (
     supports_httpx_timeout,
     token_counter,
     validate_chat_completion_messages,
+    validate_chat_completion_tool_choice,
 )
 
 from ._logging import verbose_logger
@@ -847,6 +848,8 @@ def completion( # type: ignore # noqa: PLR0915
         raise ValueError("model param not passed in.")
     # validate messages
     messages = validate_chat_completion_messages(messages=messages)
+    # validate tool_choice
+    tool_choice = validate_chat_completion_tool_choice(tool_choice=tool_choice)
     ######### unpacking kwargs #####################
     args = locals()
     api_base = kwargs.get("api_base", None)
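With this hook in place, a malformed tool_choice fails fast inside completion() before a request is built. A hedged usage sketch (model, messages, and tool definition are illustrative placeholders, not taken from the commit):

import litellm

# Raises immediately: the dict tool_choice below lacks the "function" key
# required by the OpenAI spec, so validation fails before any network call.
litellm.completion(
    model="gpt-4o-mini",  # placeholder model
    messages=[{"role": "user", "content": "What's the weather in SF?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ],
    tool_choice={"type": "tool", "name": "get_current_weather"},
)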

litellm/utils.py

@@ -6058,6 +6058,34 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
     return messages
 
 
+def validate_chat_completion_tool_choice(
+    tool_choice: Optional[Union[dict, str]]
+) -> Optional[Union[dict, str]]:
+    """
+    Confirm the tool choice is passed in the OpenAI format.
+
+    Prevents user errors like: https://github.com/BerriAI/litellm/issues/7483
+    """
+    from litellm.types.llms.openai import (
+        ChatCompletionToolChoiceObjectParam,
+        ChatCompletionToolChoiceStringValues,
+    )
+
+    if tool_choice is None:
+        return tool_choice
+    elif isinstance(tool_choice, str):
+        return tool_choice
+    elif isinstance(tool_choice, dict):
+        if tool_choice.get("type") is None or tool_choice.get("function") is None:
+            raise Exception(
+                f"Invalid tool choice, tool_choice={tool_choice}. Please ensure tool_choice follows the OpenAI spec"
+            )
+        return tool_choice
+    raise Exception(
+        f"Invalid tool choice, tool_choice={tool_choice}. Got={type(tool_choice)}. Expecting str, or dict. Please ensure tool_choice follows the OpenAI tool_choice spec"
+    )
+
+
 class ProviderConfigManager:
     @staticmethod
     def get_provider_chat_config( # noqa: PLR0915
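Taken on its own, the helper behaves as below (a quick sketch against the code above; all inputs mirror the parametrized tests):

from litellm.utils import validate_chat_completion_tool_choice

# None and OpenAI string values pass through unchanged.
assert validate_chat_completion_tool_choice(tool_choice=None) is None
assert validate_chat_completion_tool_choice(tool_choice="auto") == "auto"

# A well-formed OpenAI dict is returned as-is.
ok = {"type": "function", "function": {"name": "get_current_weather"}}
assert validate_chat_completion_tool_choice(tool_choice=ok) == ok

# A dict missing "type" or "function" raises with a pointer to the OpenAI spec.
try:
    validate_chat_completion_tool_choice(
        tool_choice={"type": "tool", "name": "get_current_weather"}
    )
except Exception as e:
    print(e)  # Invalid tool choice ... Please ensure tool_choice follows the OpenAI spec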

test_utils.py

@@ -1085,6 +1085,26 @@ def test_validate_chat_completion_user_messages(messages, expected_bool):
             validate_chat_completion_user_messages(messages=messages)
 
 
+@pytest.mark.parametrize(
+    "tool_choice, expected_bool",
+    [
+        ({"type": "function", "function": {"name": "get_current_weather"}}, True),
+        ({"type": "tool", "name": "get_current_weather"}, False),
+        (None, True),
+        ("auto", True),
+        ("required", True),
+    ],
+)
+def test_validate_chat_completion_tool_choice(tool_choice, expected_bool):
+    from litellm.utils import validate_chat_completion_tool_choice
+
+    if expected_bool:
+        validate_chat_completion_tool_choice(tool_choice=tool_choice)
+    else:
+        with pytest.raises(Exception):
+            validate_chat_completion_tool_choice(tool_choice=tool_choice)
+
+
 def test_models_by_provider():
     """
     Make sure all providers from model map are in the valid providers list