diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py
index ca9b6f92fa..76da38481d 100644
--- a/litellm/llms/base_llm/chat/transformation.py
+++ b/litellm/llms/base_llm/chat/transformation.py
@@ -227,3 +227,7 @@ class BaseConfig(ABC):
         json_mode: Optional[bool] = False,
     ) -> Any:
         pass
+
+    @property
+    def custom_llm_provider(self) -> Optional[str]:
+        return None
diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py
index cab4a413d1..60527381f6 100644
--- a/litellm/llms/bedrock/chat/converse_transformation.py
+++ b/litellm/llms/bedrock/chat/converse_transformation.py
@@ -68,6 +68,10 @@ class AmazonConverseConfig(BaseConfig):
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)
 
+    @property
+    def custom_llm_provider(self) -> Optional[str]:
+        return "bedrock_converse"
+
     @classmethod
     def get_config(cls):
         return {
@@ -112,7 +116,9 @@ class AmazonConverseConfig(BaseConfig):
         ):
             supported_params.append("tools")
 
-        if base_model.startswith("anthropic") or base_model.startswith("mistral"):
+        if litellm.utils.supports_tool_choice(
+            model=model, custom_llm_provider=self.custom_llm_provider
+        ):
             # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
             supported_params.append("tool_choice")
 
@@ -224,11 +230,14 @@ class AmazonConverseConfig(BaseConfig):
                 schema_name=schema_name if schema_name != "" else "json_tool_call",
             )
             optional_params["tools"] = [_tool]
-            optional_params["tool_choice"] = ToolChoiceValuesBlock(
-                tool=SpecificToolChoiceBlock(
-                    name=schema_name if schema_name != "" else "json_tool_call"
+            if litellm.utils.supports_tool_choice(
+                model=model, custom_llm_provider=self.custom_llm_provider
+            ):
+                optional_params["tool_choice"] = ToolChoiceValuesBlock(
+                    tool=SpecificToolChoiceBlock(
+                        name=schema_name if schema_name != "" else "json_tool_call"
+                    )
                 )
-            )
             optional_params["json_mode"] = True
             if non_default_params.get("stream", False) is True:
                 optional_params["fake_stream"] = True
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 4f58423d12..987ef948a5 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -5996,7 +5996,8 @@
         "output_cost_per_token": 0.0000312,
         "litellm_provider": "bedrock",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_tool_choice": true
     },
     "amazon.nova-micro-v1:0": {
         "max_tokens": 4096,
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 16dbbd0912..a2e5448dd1 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -81,6 +81,7 @@ class ProviderSpecificModelInfo(TypedDict, total=False):
     supports_response_schema: Optional[bool]
     supports_vision: Optional[bool]
     supports_function_calling: Optional[bool]
+    supports_tool_choice: Optional[bool]
     supports_assistant_prefill: Optional[bool]
     supports_prompt_caching: Optional[bool]
     supports_audio_input: Optional[bool]
diff --git a/litellm/utils.py b/litellm/utils.py
index 443bfb5afd..08383d4d6a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1938,6 +1938,15 @@ def supports_function_calling(
     )
 
 
+def supports_tool_choice(model: str, custom_llm_provider: Optional[str] = None) -> bool:
+    """
+    Check if the given model supports `tool_choice` and return a boolean value.
+    """
+    return _supports_factory(
+        model=model, custom_llm_provider=custom_llm_provider, key="supports_tool_choice"
+    )
+
+
 def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
@@ -4018,7 +4027,7 @@ def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str])
         "litellm_provider"
     ].startswith("fireworks_ai"):
         return True
-    elif custom_llm_provider == "bedrock" and model_info[
+    elif custom_llm_provider.startswith("bedrock") and model_info[
         "litellm_provider"
     ].startswith("bedrock"):
         return True
@@ -4189,6 +4198,7 @@ def _get_model_info_helper(  # noqa: PLR0915
             supports_system_messages=None,
             supports_response_schema=None,
             supports_function_calling=None,
+            supports_tool_choice=None,
             supports_assistant_prefill=None,
             supports_prompt_caching=None,
             supports_pdf_input=None,
@@ -4333,6 +4343,7 @@ def _get_model_info_helper(  # noqa: PLR0915
             supports_function_calling=_model_info.get(
                 "supports_function_calling", False
             ),
+            supports_tool_choice=_model_info.get("supports_tool_choice", False),
             supports_assistant_prefill=_model_info.get(
                 "supports_assistant_prefill", False
             ),
@@ -4411,6 +4422,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
         supports_response_schema: Optional[bool]
         supports_vision: Optional[bool]
         supports_function_calling: Optional[bool]
+        supports_tool_choice: Optional[bool]
         supports_prompt_caching: Optional[bool]
         supports_audio_input: Optional[bool]
         supports_audio_output: Optional[bool]
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 4f58423d12..987ef948a5 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -5996,7 +5996,8 @@
         "output_cost_per_token": 0.0000312,
         "litellm_provider": "bedrock",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_tool_choice": true
     },
     "amazon.nova-micro-v1:0": {
         "max_tokens": 4096,
diff --git a/tests/litellm_utils_tests/test_supports_tool_choice.py b/tests/litellm_utils_tests/test_supports_tool_choice.py
new file mode 100644
index 0000000000..98d8727ce0
--- /dev/null
+++ b/tests/litellm_utils_tests/test_supports_tool_choice.py
@@ -0,0 +1,160 @@
+import json
+import os
+import sys
+from unittest.mock import patch
+import pytest
+
+# Add parent directory to system path
+sys.path.insert(0, os.path.abspath("../.."))
+
+import litellm
+from litellm.utils import get_llm_provider, ProviderConfigManager, _check_provider_match
+from litellm import LlmProviders
+
+
+def test_supports_tool_choice_simple_tests():
+    """
+    simple sanity checks
+    """
+    assert litellm.utils.supports_tool_choice(model="gpt-4o") == True
+    assert (
+        litellm.utils.supports_tool_choice(
+            model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
+        )
+        == True
+    )
+    assert (
+        litellm.utils.supports_tool_choice(
+            model="anthropic.claude-3-sonnet-20240229-v1:0"
+        )
+        is True
+    )
+
+    assert (
+        litellm.utils.supports_tool_choice(
+            model="anthropic.claude-3-sonnet-20240229-v1:0",
+            custom_llm_provider="bedrock_converse",
+        )
+        is True
+    )
+
+    assert (
+        litellm.utils.supports_tool_choice(model="us.amazon.nova-micro-v1:0") is False
+    )
+    assert (
+        litellm.utils.supports_tool_choice(model="bedrock/us.amazon.nova-micro-v1:0")
+        is False
+    )
+    assert (
+        litellm.utils.supports_tool_choice(
+            model="us.amazon.nova-micro-v1:0", custom_llm_provider="bedrock_converse"
+        )
+        is False
+    )
+
+
+def test_check_provider_match():
+    """
+    Test the _check_provider_match function for various provider scenarios
+    """
+    # Test bedrock and bedrock_converse cases
+    model_info = {"litellm_provider": "bedrock"}
+    assert litellm.utils._check_provider_match(model_info, "bedrock") is True
+    assert litellm.utils._check_provider_match(model_info, "bedrock_converse") is True
+
+    # Test bedrock_converse provider
+    model_info = {"litellm_provider": "bedrock_converse"}
+    assert litellm.utils._check_provider_match(model_info, "bedrock") is True
+    assert litellm.utils._check_provider_match(model_info, "bedrock_converse") is True
+
+    # Test non-matching provider
+    model_info = {"litellm_provider": "bedrock"}
+    assert litellm.utils._check_provider_match(model_info, "openai") is False
+
+
+# Models that should be skipped during testing
+OLD_PROVIDERS = ["aleph_alpha", "palm"]
+SKIP_MODELS = ["azure/mistral", "azure/command-r", "jamba", "deepinfra", "mistral."]
+
+# Bedrock models to block - organized by type
+BEDROCK_REGIONS = ["ap-northeast-1", "eu-central-1", "us-east-1", "us-west-2"]
+BEDROCK_COMMITMENTS = ["1-month-commitment", "6-month-commitment"]
+BEDROCK_MODELS = {
+    "anthropic.claude-v1",
+    "anthropic.claude-v2",
+    "anthropic.claude-v2:1",
+    "anthropic.claude-instant-v1",
+}
+
+# Generate block_list dynamically
+block_list = set()
+for region in BEDROCK_REGIONS:
+    for commitment in BEDROCK_COMMITMENTS:
+        for model in BEDROCK_MODELS:
+            block_list.add(f"bedrock/{region}/{commitment}/{model}")
+            block_list.add(f"bedrock/{region}/{model}")
+
+# Add Cohere models
+for commitment in BEDROCK_COMMITMENTS:
+    block_list.add(f"bedrock/*/{commitment}/cohere.command-text-v14")
+    block_list.add(f"bedrock/*/{commitment}/cohere.command-light-text-v14")
+
+print("block_list", block_list)
+
+
+@pytest.mark.asyncio
+async def test_supports_tool_choice():
+    """
+    Test that litellm.utils.supports_tool_choice() returns the correct value
+    for all models in model_prices_and_context_window.json.
+
+    The test:
+    1. Loads model pricing data
+    2. Iterates through each model
+    3. Checks if tool_choice support matches the model's supported parameters
+    """
+    # Load model prices
+    litellm._turn_on_debug()
+    with open("./model_prices_and_context_window.json", "r") as f:
+        model_prices = json.load(f)
+    litellm.model_cost = model_prices
+    config_manager = ProviderConfigManager()
+
+    for model_name, model_info in model_prices.items():
+        print(f"testing model: {model_name}")
+
+        # Skip certain models
+        if (
+            model_name == "sample_spec"
+            or model_info.get("mode") != "chat"
+            or any(skip in model_name for skip in SKIP_MODELS)
+            or any(provider in model_name for provider in OLD_PROVIDERS)
+            or model_info["litellm_provider"] in OLD_PROVIDERS
+            or model_name in block_list
+        ):
+            continue
+
+        try:
+            model, provider, _, _ = get_llm_provider(model=model_name)
+        except Exception as e:
+            print(f"\033[91mERROR for {model_name}: {e}\033[0m")
+            continue
+
+        # Get provider config and supported params
+        print("LLM provider", provider)
+        provider_enum = LlmProviders(provider)
+        config = config_manager.get_provider_chat_config(model, provider_enum)
+        supported_params = config.get_supported_openai_params(model)
+        print("supported_params", supported_params)
+
+        # Check tool_choice support
+        supports_tool_choice_result = litellm.utils.supports_tool_choice(
+            model=model_name, custom_llm_provider=provider
+        )
+        tool_choice_in_params = "tool_choice" in supported_params
+
+        assert supports_tool_choice_result == tool_choice_in_params, (
+            f"Tool choice support mismatch for {model_name}:\n"
+            f"supports_tool_choice() returned: {supports_tool_choice_result}\n"
+            f"tool_choice in supported params: {tool_choice_in_params}"
+        )
diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py
index f9f6bdef58..508a80a915 100644
--- a/tests/llm_translation/test_bedrock_completion.py
+++ b/tests/llm_translation/test_bedrock_completion.py
@@ -768,6 +768,7 @@ def test_bedrock_system_prompt(system, model):
 def test_bedrock_claude_3_tool_calling():
     try:
         litellm.set_verbose = True
+        litellm._turn_on_debug()
         tools = [
             {
                 "type": "function",
diff --git a/tests/llm_translation/test_bedrock_nova_json.py b/tests/llm_translation/test_bedrock_nova_json.py
new file mode 100644
index 0000000000..6fd7a3d3c3
--- /dev/null
+++ b/tests/llm_translation/test_bedrock_nova_json.py
@@ -0,0 +1,28 @@
+from base_llm_unit_tests import BaseLLMChatTest
+import pytest
+import sys
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+
+
+class TestBedrockNovaJson(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        litellm._turn_on_debug()
+        return {
+            "model": "bedrock/converse/us.amazon.nova-micro-v1:0",
+        }
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pass
+
+    @pytest.fixture(autouse=True)
+    def skip_non_json_tests(self, request):
+        if not "json" in request.function.__name__.lower():
+            pytest.skip(
+                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
+            )
diff --git a/tests/local_testing/test_get_model_info.py b/tests/local_testing/test_get_model_info.py
index edf9183ad2..6747efce3a 100644
--- a/tests/local_testing/test_get_model_info.py
+++ b/tests/local_testing/test_get_model_info.py
@@ -8,7 +8,7 @@ from typing import List, Dict, Any
 
 sys.path.insert(
     0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
+)  # Adds the parent directory to the system-path
 import pytest
 import litellm