(Feat) - Add support for structured output on bedrock/nova models + add util litellm.supports_tool_choice (#8264)

* fix supports_tool_choice

* TestBedrockNovaJson

* use supports_tool_choice

* fix supports_tool_choice

* add supports_tool_choice param

* script to add fields to model cost map

* test_supports_tool_choice

* test_supports_tool_choice

* fix supports tool choice check

* test_supports_tool_choice_simple_tests

* fix supports_tool_choice check

* fix supports_tool_choice bedrock

* test_supports_tool_choice

* test_supports_tool_choice

* fix bedrock/eu-west-3/mistral.mistral-large-2402-v1:0

* ci/cd run again

* test_supports_tool_choice_simple_tests

* TestGoogleAIStudioGemini temp - remove to run ci/cd

* test_aaalangfuse_logging_metadata

* TestGoogleAIStudioGemini

* test_check_provider_match

* remove add param to map
Ishaan Jaff 2025-02-04 21:47:16 -08:00 committed by GitHub
parent c743475aba
commit 3a6349d871
10 changed files with 226 additions and 9 deletions
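Taken together, the diff threads a new `supports_tool_choice` capability flag from the model cost map through `get_model_info` and exposes it as a util in `litellm.utils`. A minimal usage sketch, based on the tests added in this commit (assumes the bundled model cost map, which litellm loads by default):

    import litellm

    # expected True per the new tests below
    print(litellm.utils.supports_tool_choice(model="gpt-4o"))

    # bedrock models resolve with or without the "bedrock/" prefix, and
    # "bedrock_converse" now matches "bedrock"-prefixed cost map entries
    print(litellm.utils.supports_tool_choice(
        model="anthropic.claude-3-sonnet-20240229-v1:0",
        custom_llm_provider="bedrock_converse",
    ))

    # expected False: Nova models do not take a Converse ToolChoice block
    print(litellm.utils.supports_tool_choice(model="us.amazon.nova-micro-v1:0"))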


@@ -227,3 +227,7 @@ class BaseConfig(ABC):
         json_mode: Optional[bool] = False,
     ) -> Any:
         pass
+
+    @property
+    def custom_llm_provider(self) -> Optional[str]:
+        return None


@@ -68,6 +68,10 @@ class AmazonConverseConfig(BaseConfig):
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)
 
+    @property
+    def custom_llm_provider(self) -> Optional[str]:
+        return "bedrock_converse"
+
     @classmethod
     def get_config(cls):
         return {
@@ -112,7 +116,9 @@ class AmazonConverseConfig(BaseConfig):
         ):
             supported_params.append("tools")
 
-        if base_model.startswith("anthropic") or base_model.startswith("mistral"):
+        if litellm.utils.supports_tool_choice(
+            model=model, custom_llm_provider=self.custom_llm_provider
+        ):
             # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
             supported_params.append("tool_choice")
@@ -224,6 +230,9 @@ class AmazonConverseConfig(BaseConfig):
                 schema_name=schema_name if schema_name != "" else "json_tool_call",
             )
             optional_params["tools"] = [_tool]
-            optional_params["tool_choice"] = ToolChoiceValuesBlock(
-                tool=SpecificToolChoiceBlock(
-                    name=schema_name if schema_name != "" else "json_tool_call"
+            if litellm.utils.supports_tool_choice(
+                model=model, custom_llm_provider=self.custom_llm_provider
+            ):
+                optional_params["tool_choice"] = ToolChoiceValuesBlock(
+                    tool=SpecificToolChoiceBlock(
+                        name=schema_name if schema_name != "" else "json_tool_call"
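With the checks above in place, JSON mode on a Nova model goes through the synthetic `json_tool_call` tool while the `toolChoice` block is left unset, since `supports_tool_choice` is false for Nova. A hedged end-to-end sketch (assumes valid AWS credentials in the environment; the model name is the one exercised by the new `TestBedrockNovaJson` test below):

    import litellm

    # response_format drives the json_tool_call path in AmazonConverseConfig;
    # for Nova, no tool_choice is sent because supports_tool_choice is false
    response = litellm.completion(
        model="bedrock/converse/us.amazon.nova-micro-v1:0",
        messages=[{"role": "user", "content": "Return two colors as a JSON object."}],
        response_format={"type": "json_object"},
    )
    print(response.choices[0].message.content)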


@@ -5996,7 +5996,8 @@
         "output_cost_per_token": 0.0000312,
         "litellm_provider": "bedrock",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_tool_choice": true
     },
     "amazon.nova-micro-v1:0": {
         "max_tokens": 4096,


@@ -81,6 +81,7 @@ class ProviderSpecificModelInfo(TypedDict, total=False):
     supports_response_schema: Optional[bool]
     supports_vision: Optional[bool]
     supports_function_calling: Optional[bool]
+    supports_tool_choice: Optional[bool]
     supports_assistant_prefill: Optional[bool]
     supports_prompt_caching: Optional[bool]
     supports_audio_input: Optional[bool]


@@ -1938,6 +1938,15 @@ def supports_function_calling(
     )
 
 
+def supports_tool_choice(model: str, custom_llm_provider: Optional[str] = None) -> bool:
+    """
+    Check if the given model supports `tool_choice` and return a boolean value.
+    """
+    return _supports_factory(
+        model=model, custom_llm_provider=custom_llm_provider, key="supports_tool_choice"
+    )
+
+
 def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
@@ -4018,7 +4027,7 @@ def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str])
         "litellm_provider"
     ].startswith("fireworks_ai"):
         return True
-    elif custom_llm_provider == "bedrock" and model_info[
+    elif custom_llm_provider.startswith("bedrock") and model_info[
         "litellm_provider"
     ].startswith("bedrock"):
         return True
@@ -4189,6 +4198,7 @@ def _get_model_info_helper(  # noqa: PLR0915
            supports_system_messages=None,
            supports_response_schema=None,
            supports_function_calling=None,
+           supports_tool_choice=None,
            supports_assistant_prefill=None,
            supports_prompt_caching=None,
            supports_pdf_input=None,
@@ -4333,6 +4343,7 @@ def _get_model_info_helper(  # noqa: PLR0915
                supports_function_calling=_model_info.get(
                    "supports_function_calling", False
                ),
+               supports_tool_choice=_model_info.get("supports_tool_choice", False),
                supports_assistant_prefill=_model_info.get(
                    "supports_assistant_prefill", False
                ),
@@ -4411,6 +4422,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
            supports_response_schema: Optional[bool]
            supports_vision: Optional[bool]
            supports_function_calling: Optional[bool]
+           supports_tool_choice: Optional[bool]
            supports_prompt_caching: Optional[bool]
            supports_audio_input: Optional[bool]
            supports_audio_output: Optional[bool]
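The new field also surfaces through `get_model_info`, next to the other capability flags (a small sketch; values depend on the loaded cost map):

    import litellm

    info = litellm.get_model_info(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
    )
    print(info["supports_function_calling"])  # True in the bundled map
    print(info["supports_tool_choice"])       # True after this change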


@@ -5996,7 +5996,8 @@
         "output_cost_per_token": 0.0000312,
         "litellm_provider": "bedrock",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_tool_choice": true
     },
     "amazon.nova-micro-v1:0": {
         "max_tokens": 4096,


@@ -0,0 +1,160 @@
+import json
+import os
+import sys
+from unittest.mock import patch
+
+import pytest
+
+# Add parent directory to system path
+sys.path.insert(0, os.path.abspath("../.."))
+
+import litellm
+from litellm.utils import get_llm_provider, ProviderConfigManager, _check_provider_match
+from litellm import LlmProviders
+
+
+def test_supports_tool_choice_simple_tests():
+    """
+    simple sanity checks
+    """
+    assert litellm.utils.supports_tool_choice(model="gpt-4o") == True
+    assert (
+        litellm.utils.supports_tool_choice(
+            model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
+        )
+        == True
+    )
+    assert (
+        litellm.utils.supports_tool_choice(
+            model="anthropic.claude-3-sonnet-20240229-v1:0"
+        )
+        is True
+    )
+    assert (
+        litellm.utils.supports_tool_choice(
+            model="anthropic.claude-3-sonnet-20240229-v1:0",
+            custom_llm_provider="bedrock_converse",
+        )
+        is True
+    )
+    assert (
+        litellm.utils.supports_tool_choice(model="us.amazon.nova-micro-v1:0") is False
+    )
+    assert (
+        litellm.utils.supports_tool_choice(model="bedrock/us.amazon.nova-micro-v1:0")
+        is False
+    )
+    assert (
+        litellm.utils.supports_tool_choice(
+            model="us.amazon.nova-micro-v1:0", custom_llm_provider="bedrock_converse"
+        )
+        is False
+    )
+
+
+def test_check_provider_match():
+    """
+    Test the _check_provider_match function for various provider scenarios
+    """
+    # Test bedrock and bedrock_converse cases
+    model_info = {"litellm_provider": "bedrock"}
+    assert litellm.utils._check_provider_match(model_info, "bedrock") is True
+    assert litellm.utils._check_provider_match(model_info, "bedrock_converse") is True
+
+    # Test bedrock_converse provider
+    model_info = {"litellm_provider": "bedrock_converse"}
+    assert litellm.utils._check_provider_match(model_info, "bedrock") is True
+    assert litellm.utils._check_provider_match(model_info, "bedrock_converse") is True
+
+    # Test non-matching provider
+    model_info = {"litellm_provider": "bedrock"}
+    assert litellm.utils._check_provider_match(model_info, "openai") is False
+
+
+# Models that should be skipped during testing
+OLD_PROVIDERS = ["aleph_alpha", "palm"]
+SKIP_MODELS = ["azure/mistral", "azure/command-r", "jamba", "deepinfra", "mistral."]
+
+# Bedrock models to block - organized by type
+BEDROCK_REGIONS = ["ap-northeast-1", "eu-central-1", "us-east-1", "us-west-2"]
+BEDROCK_COMMITMENTS = ["1-month-commitment", "6-month-commitment"]
+BEDROCK_MODELS = {
+    "anthropic.claude-v1",
+    "anthropic.claude-v2",
+    "anthropic.claude-v2:1",
+    "anthropic.claude-instant-v1",
+}
+
+# Generate block_list dynamically
+block_list = set()
+for region in BEDROCK_REGIONS:
+    for commitment in BEDROCK_COMMITMENTS:
+        for model in BEDROCK_MODELS:
+            block_list.add(f"bedrock/{region}/{commitment}/{model}")
+            block_list.add(f"bedrock/{region}/{model}")
+
+# Add Cohere models
+for commitment in BEDROCK_COMMITMENTS:
+    block_list.add(f"bedrock/*/{commitment}/cohere.command-text-v14")
+    block_list.add(f"bedrock/*/{commitment}/cohere.command-light-text-v14")
+
+print("block_list", block_list)
+
+
+@pytest.mark.asyncio
+async def test_supports_tool_choice():
+    """
+    Test that litellm.utils.supports_tool_choice() returns the correct value
+    for all models in model_prices_and_context_window.json.
+
+    The test:
+    1. Loads model pricing data
+    2. Iterates through each model
+    3. Checks if tool_choice support matches the model's supported parameters
+    """
+    # Load model prices
+    litellm._turn_on_debug()
+    with open("./model_prices_and_context_window.json", "r") as f:
+        model_prices = json.load(f)
+    litellm.model_cost = model_prices
+
+    config_manager = ProviderConfigManager()
+
+    for model_name, model_info in model_prices.items():
+        print(f"testing model: {model_name}")
+
+        # Skip certain models
+        if (
+            model_name == "sample_spec"
+            or model_info.get("mode") != "chat"
+            or any(skip in model_name for skip in SKIP_MODELS)
+            or any(provider in model_name for provider in OLD_PROVIDERS)
+            or model_info["litellm_provider"] in OLD_PROVIDERS
+            or model_name in block_list
+        ):
+            continue
+
+        try:
+            model, provider, _, _ = get_llm_provider(model=model_name)
+        except Exception as e:
+            print(f"\033[91mERROR for {model_name}: {e}\033[0m")
+            continue
+
+        # Get provider config and supported params
+        print("LLM provider", provider)
+        provider_enum = LlmProviders(provider)
+        config = config_manager.get_provider_chat_config(model, provider_enum)
+        supported_params = config.get_supported_openai_params(model)
+        print("supported_params", supported_params)
+
+        # Check tool_choice support
+        supports_tool_choice_result = litellm.utils.supports_tool_choice(
+            model=model_name, custom_llm_provider=provider
+        )
+        tool_choice_in_params = "tool_choice" in supported_params
+
+        assert supports_tool_choice_result == tool_choice_in_params, (
+            f"Tool choice support mismatch for {model_name}:\n"
+            f"supports_tool_choice() returned: {supports_tool_choice_result}\n"
+            f"tool_choice in supported params: {tool_choice_in_params}"
+        )
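Distilled to a single model, the invariant the loop asserts is that the provider config and the cost map agree on `tool_choice` support (same imports as the test; assumes the bundled cost map is loaded):

    model, provider, _, _ = get_llm_provider(
        model="anthropic.claude-3-sonnet-20240229-v1:0"
    )
    config = ProviderConfigManager().get_provider_chat_config(
        model, LlmProviders(provider)
    )
    in_params = "tool_choice" in config.get_supported_openai_params(model)
    assert in_params == litellm.utils.supports_tool_choice(
        model=model, custom_llm_provider=provider
    )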


@@ -768,6 +768,7 @@ def test_bedrock_system_prompt(system, model):
 def test_bedrock_claude_3_tool_calling():
     try:
         litellm.set_verbose = True
+        litellm._turn_on_debug()
         tools = [
             {
                 "type": "function",


@@ -0,0 +1,28 @@
+from base_llm_unit_tests import BaseLLMChatTest
+import pytest
+import sys
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+
+
+class TestBedrockNovaJson(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        litellm._turn_on_debug()
+        return {
+            "model": "bedrock/converse/us.amazon.nova-micro-v1:0",
+        }
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pass
+
+    @pytest.fixture(autouse=True)
+    def skip_non_json_tests(self, request):
+        if not "json" in request.function.__name__.lower():
+            pytest.skip(
+                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
+            )


@@ -8,7 +8,7 @@ from typing import List, Dict, Any
 sys.path.insert(
     0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
+)  # Adds the parent directory to the system-path
 import pytest
 import litellm