mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Litellm dev 01 01 2025 p3 (#7503)
* fix(utils.py): add new validate tool choice helper function Prevents https://github.com/BerriAI/litellm/issues/7483 * fix(main.py): add tool choice validation on .completion() prevents user error like - https://github.com/BerriAI/litellm/issues/7483 * fix(utils.py): fix return val of tool choice validation logic
This commit is contained in:
parent
98cba7ba3f
commit
b3611ace41
3 changed files with 51 additions and 0 deletions
|
@ -88,6 +88,7 @@ from litellm.utils import (
|
||||||
supports_httpx_timeout,
|
supports_httpx_timeout,
|
||||||
token_counter,
|
token_counter,
|
||||||
validate_chat_completion_messages,
|
validate_chat_completion_messages,
|
||||||
|
validate_chat_completion_tool_choice,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ._logging import verbose_logger
|
from ._logging import verbose_logger
|
||||||
|
@ -847,6 +848,8 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
raise ValueError("model param not passed in.")
|
raise ValueError("model param not passed in.")
|
||||||
# validate messages
|
# validate messages
|
||||||
messages = validate_chat_completion_messages(messages=messages)
|
messages = validate_chat_completion_messages(messages=messages)
|
||||||
|
# validate tool_choice
|
||||||
|
tool_choice = validate_chat_completion_tool_choice(tool_choice=tool_choice)
|
||||||
######### unpacking kwargs #####################
|
######### unpacking kwargs #####################
|
||||||
args = locals()
|
args = locals()
|
||||||
api_base = kwargs.get("api_base", None)
|
api_base = kwargs.get("api_base", None)
|
||||||
|
|
|
@ -6058,6 +6058,34 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def validate_chat_completion_tool_choice(
|
||||||
|
tool_choice: Optional[Union[dict, str]]
|
||||||
|
) -> Optional[Union[dict, str]]:
|
||||||
|
"""
|
||||||
|
Confirm the tool choice is passed in the OpenAI format.
|
||||||
|
|
||||||
|
Prevents user errors like: https://github.com/BerriAI/litellm/issues/7483
|
||||||
|
"""
|
||||||
|
from litellm.types.llms.openai import (
|
||||||
|
ChatCompletionToolChoiceObjectParam,
|
||||||
|
ChatCompletionToolChoiceStringValues,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tool_choice is None:
|
||||||
|
return tool_choice
|
||||||
|
elif isinstance(tool_choice, str):
|
||||||
|
return tool_choice
|
||||||
|
elif isinstance(tool_choice, dict):
|
||||||
|
if tool_choice.get("type") is None or tool_choice.get("function") is None:
|
||||||
|
raise Exception(
|
||||||
|
f"Invalid tool choice, tool_choice={tool_choice}. Please ensure tool_choice follows the OpenAI spec"
|
||||||
|
)
|
||||||
|
return tool_choice
|
||||||
|
raise Exception(
|
||||||
|
f"Invalid tool choice, tool_choice={tool_choice}. Got={type(tool_choice)}. Expecting str, or dict. Please ensure tool_choice follows the OpenAI tool_choice spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProviderConfigManager:
|
class ProviderConfigManager:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_provider_chat_config( # noqa: PLR0915
|
def get_provider_chat_config( # noqa: PLR0915
|
||||||
|
|
|
@ -1085,6 +1085,26 @@ def test_validate_chat_completion_user_messages(messages, expected_bool):
|
||||||
validate_chat_completion_user_messages(messages=messages)
|
validate_chat_completion_user_messages(messages=messages)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"tool_choice, expected_bool",
|
||||||
|
[
|
||||||
|
({"type": "function", "function": {"name": "get_current_weather"}}, True),
|
||||||
|
({"type": "tool", "name": "get_current_weather"}, False),
|
||||||
|
(None, True),
|
||||||
|
("auto", True),
|
||||||
|
("required", True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_validate_chat_completion_tool_choice(tool_choice, expected_bool):
|
||||||
|
from litellm.utils import validate_chat_completion_tool_choice
|
||||||
|
|
||||||
|
if expected_bool:
|
||||||
|
validate_chat_completion_tool_choice(tool_choice=tool_choice)
|
||||||
|
else:
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
validate_chat_completion_tool_choice(tool_choice=tool_choice)
|
||||||
|
|
||||||
|
|
||||||
def test_models_by_provider():
|
def test_models_by_provider():
|
||||||
"""
|
"""
|
||||||
Make sure all providers from model map are in the valid providers list
|
Make sure all providers from model map are in the valid providers list
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue