mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Feat) - Add support for structured output on bedrock/nova
models + add util litellm.supports_tool_choice
(#8264)
* fix supports_tool_choice * TestBedrockNovaJson * use supports_tool_choice * fix supports_tool_choice * add supports_tool_choice param * script to add fields to model cost map * test_supports_tool_choice * test_supports_tool_choice * fix supports tool choice check * test_supports_tool_choice_simple_tests * fix supports_tool_choice check * fix supports_tool_choice bedrock * test_supports_tool_choice * test_supports_tool_choice * fix bedrock/eu-west-3/mistral.mistral-large-2402-v1:0 * ci/cd run again * test_supports_tool_choice_simple_tests * TestGoogleAIStudioGemini temp - remove to run ci/cd * test_aaalangfuse_logging_metadata * TestGoogleAIStudioGemini * test_check_provider_match * remove add param to map
This commit is contained in:
parent
c743475aba
commit
3a6349d871
10 changed files with 226 additions and 9 deletions
|
@ -227,3 +227,7 @@ class BaseConfig(ABC):
|
|||
json_mode: Optional[bool] = False,
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> Optional[str]:
|
||||
return None
|
||||
|
|
|
@ -68,6 +68,10 @@ class AmazonConverseConfig(BaseConfig):
|
|||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> Optional[str]:
|
||||
return "bedrock_converse"
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
|
@ -112,7 +116,9 @@ class AmazonConverseConfig(BaseConfig):
|
|||
):
|
||||
supported_params.append("tools")
|
||||
|
||||
if base_model.startswith("anthropic") or base_model.startswith("mistral"):
|
||||
if litellm.utils.supports_tool_choice(
|
||||
model=model, custom_llm_provider=self.custom_llm_provider
|
||||
):
|
||||
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||
supported_params.append("tool_choice")
|
||||
|
||||
|
@ -224,6 +230,9 @@ class AmazonConverseConfig(BaseConfig):
|
|||
schema_name=schema_name if schema_name != "" else "json_tool_call",
|
||||
)
|
||||
optional_params["tools"] = [_tool]
|
||||
if litellm.utils.supports_tool_choice(
|
||||
model=model, custom_llm_provider=self.custom_llm_provider
|
||||
):
|
||||
optional_params["tool_choice"] = ToolChoiceValuesBlock(
|
||||
tool=SpecificToolChoiceBlock(
|
||||
name=schema_name if schema_name != "" else "json_tool_call"
|
||||
|
|
|
@ -5996,7 +5996,8 @@
|
|||
"output_cost_per_token": 0.0000312,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true
|
||||
"supports_function_calling": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"amazon.nova-micro-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
|
|
|
@ -81,6 +81,7 @@ class ProviderSpecificModelInfo(TypedDict, total=False):
|
|||
supports_response_schema: Optional[bool]
|
||||
supports_vision: Optional[bool]
|
||||
supports_function_calling: Optional[bool]
|
||||
supports_tool_choice: Optional[bool]
|
||||
supports_assistant_prefill: Optional[bool]
|
||||
supports_prompt_caching: Optional[bool]
|
||||
supports_audio_input: Optional[bool]
|
||||
|
|
|
@ -1938,6 +1938,15 @@ def supports_function_calling(
|
|||
)
|
||||
|
||||
|
||||
def supports_tool_choice(model: str, custom_llm_provider: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Check if the given model supports `tool_choice` and return a boolean value.
|
||||
"""
|
||||
return _supports_factory(
|
||||
model=model, custom_llm_provider=custom_llm_provider, key="supports_tool_choice"
|
||||
)
|
||||
|
||||
|
||||
def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str) -> bool:
|
||||
"""
|
||||
Check if the given model supports function calling and return a boolean value.
|
||||
|
@ -4018,7 +4027,7 @@ def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str])
|
|||
"litellm_provider"
|
||||
].startswith("fireworks_ai"):
|
||||
return True
|
||||
elif custom_llm_provider == "bedrock" and model_info[
|
||||
elif custom_llm_provider.startswith("bedrock") and model_info[
|
||||
"litellm_provider"
|
||||
].startswith("bedrock"):
|
||||
return True
|
||||
|
@ -4189,6 +4198,7 @@ def _get_model_info_helper( # noqa: PLR0915
|
|||
supports_system_messages=None,
|
||||
supports_response_schema=None,
|
||||
supports_function_calling=None,
|
||||
supports_tool_choice=None,
|
||||
supports_assistant_prefill=None,
|
||||
supports_prompt_caching=None,
|
||||
supports_pdf_input=None,
|
||||
|
@ -4333,6 +4343,7 @@ def _get_model_info_helper( # noqa: PLR0915
|
|||
supports_function_calling=_model_info.get(
|
||||
"supports_function_calling", False
|
||||
),
|
||||
supports_tool_choice=_model_info.get("supports_tool_choice", False),
|
||||
supports_assistant_prefill=_model_info.get(
|
||||
"supports_assistant_prefill", False
|
||||
),
|
||||
|
@ -4411,6 +4422,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
|
|||
supports_response_schema: Optional[bool]
|
||||
supports_vision: Optional[bool]
|
||||
supports_function_calling: Optional[bool]
|
||||
supports_tool_choice: Optional[bool]
|
||||
supports_prompt_caching: Optional[bool]
|
||||
supports_audio_input: Optional[bool]
|
||||
supports_audio_output: Optional[bool]
|
||||
|
|
|
@ -5996,7 +5996,8 @@
|
|||
"output_cost_per_token": 0.0000312,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true
|
||||
"supports_function_calling": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"amazon.nova-micro-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
|
|
160
tests/litellm_utils_tests/test_supports_tool_choice.py
Normal file
160
tests/litellm_utils_tests/test_supports_tool_choice.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
# Add parent directory to system path
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
import litellm
|
||||
from litellm.utils import get_llm_provider, ProviderConfigManager, _check_provider_match
|
||||
from litellm import LlmProviders
|
||||
|
||||
|
||||
def test_supports_tool_choice_simple_tests():
|
||||
"""
|
||||
simple sanity checks
|
||||
"""
|
||||
assert litellm.utils.supports_tool_choice(model="gpt-4o") == True
|
||||
assert (
|
||||
litellm.utils.supports_tool_choice(
|
||||
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
)
|
||||
== True
|
||||
)
|
||||
assert (
|
||||
litellm.utils.supports_tool_choice(
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
assert (
|
||||
litellm.utils.supports_tool_choice(
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
custom_llm_provider="bedrock_converse",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
assert (
|
||||
litellm.utils.supports_tool_choice(model="us.amazon.nova-micro-v1:0") is False
|
||||
)
|
||||
assert (
|
||||
litellm.utils.supports_tool_choice(model="bedrock/us.amazon.nova-micro-v1:0")
|
||||
is False
|
||||
)
|
||||
assert (
|
||||
litellm.utils.supports_tool_choice(
|
||||
model="us.amazon.nova-micro-v1:0", custom_llm_provider="bedrock_converse"
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_check_provider_match():
|
||||
"""
|
||||
Test the _check_provider_match function for various provider scenarios
|
||||
"""
|
||||
# Test bedrock and bedrock_converse cases
|
||||
model_info = {"litellm_provider": "bedrock"}
|
||||
assert litellm.utils._check_provider_match(model_info, "bedrock") is True
|
||||
assert litellm.utils._check_provider_match(model_info, "bedrock_converse") is True
|
||||
|
||||
# Test bedrock_converse provider
|
||||
model_info = {"litellm_provider": "bedrock_converse"}
|
||||
assert litellm.utils._check_provider_match(model_info, "bedrock") is True
|
||||
assert litellm.utils._check_provider_match(model_info, "bedrock_converse") is True
|
||||
|
||||
# Test non-matching provider
|
||||
model_info = {"litellm_provider": "bedrock"}
|
||||
assert litellm.utils._check_provider_match(model_info, "openai") is False
|
||||
|
||||
|
||||
# Models that should be skipped during testing
|
||||
OLD_PROVIDERS = ["aleph_alpha", "palm"]
|
||||
SKIP_MODELS = ["azure/mistral", "azure/command-r", "jamba", "deepinfra", "mistral."]
|
||||
|
||||
# Bedrock models to block - organized by type
|
||||
BEDROCK_REGIONS = ["ap-northeast-1", "eu-central-1", "us-east-1", "us-west-2"]
|
||||
BEDROCK_COMMITMENTS = ["1-month-commitment", "6-month-commitment"]
|
||||
BEDROCK_MODELS = {
|
||||
"anthropic.claude-v1",
|
||||
"anthropic.claude-v2",
|
||||
"anthropic.claude-v2:1",
|
||||
"anthropic.claude-instant-v1",
|
||||
}
|
||||
|
||||
# Generate block_list dynamically
|
||||
block_list = set()
|
||||
for region in BEDROCK_REGIONS:
|
||||
for commitment in BEDROCK_COMMITMENTS:
|
||||
for model in BEDROCK_MODELS:
|
||||
block_list.add(f"bedrock/{region}/{commitment}/{model}")
|
||||
block_list.add(f"bedrock/{region}/{model}")
|
||||
|
||||
# Add Cohere models
|
||||
for commitment in BEDROCK_COMMITMENTS:
|
||||
block_list.add(f"bedrock/*/{commitment}/cohere.command-text-v14")
|
||||
block_list.add(f"bedrock/*/{commitment}/cohere.command-light-text-v14")
|
||||
|
||||
print("block_list", block_list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supports_tool_choice():
|
||||
"""
|
||||
Test that litellm.utils.supports_tool_choice() returns the correct value
|
||||
for all models in model_prices_and_context_window.json.
|
||||
|
||||
The test:
|
||||
1. Loads model pricing data
|
||||
2. Iterates through each model
|
||||
3. Checks if tool_choice support matches the model's supported parameters
|
||||
"""
|
||||
# Load model prices
|
||||
litellm._turn_on_debug()
|
||||
with open("./model_prices_and_context_window.json", "r") as f:
|
||||
model_prices = json.load(f)
|
||||
litellm.model_cost = model_prices
|
||||
config_manager = ProviderConfigManager()
|
||||
|
||||
for model_name, model_info in model_prices.items():
|
||||
print(f"testing model: {model_name}")
|
||||
|
||||
# Skip certain models
|
||||
if (
|
||||
model_name == "sample_spec"
|
||||
or model_info.get("mode") != "chat"
|
||||
or any(skip in model_name for skip in SKIP_MODELS)
|
||||
or any(provider in model_name for provider in OLD_PROVIDERS)
|
||||
or model_info["litellm_provider"] in OLD_PROVIDERS
|
||||
or model_name in block_list
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
model, provider, _, _ = get_llm_provider(model=model_name)
|
||||
except Exception as e:
|
||||
print(f"\033[91mERROR for {model_name}: {e}\033[0m")
|
||||
continue
|
||||
|
||||
# Get provider config and supported params
|
||||
print("LLM provider", provider)
|
||||
provider_enum = LlmProviders(provider)
|
||||
config = config_manager.get_provider_chat_config(model, provider_enum)
|
||||
supported_params = config.get_supported_openai_params(model)
|
||||
print("supported_params", supported_params)
|
||||
|
||||
# Check tool_choice support
|
||||
supports_tool_choice_result = litellm.utils.supports_tool_choice(
|
||||
model=model_name, custom_llm_provider=provider
|
||||
)
|
||||
tool_choice_in_params = "tool_choice" in supported_params
|
||||
|
||||
assert supports_tool_choice_result == tool_choice_in_params, (
|
||||
f"Tool choice support mismatch for {model_name}:\n"
|
||||
f"supports_tool_choice() returned: {supports_tool_choice_result}\n"
|
||||
f"tool_choice in supported params: {tool_choice_in_params}"
|
||||
)
|
|
@ -768,6 +768,7 @@ def test_bedrock_system_prompt(system, model):
|
|||
def test_bedrock_claude_3_tool_calling():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
litellm._turn_on_debug()
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
|
|
28
tests/llm_translation/test_bedrock_nova_json.py
Normal file
28
tests/llm_translation/test_bedrock_nova_json.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
from base_llm_unit_tests import BaseLLMChatTest
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
|
||||
class TestBedrockNovaJson(BaseLLMChatTest):
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
litellm._turn_on_debug()
|
||||
return {
|
||||
"model": "bedrock/converse/us.amazon.nova-micro-v1:0",
|
||||
}
|
||||
|
||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||
pass
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_non_json_tests(self, request):
|
||||
if not "json" in request.function.__name__.lower():
|
||||
pytest.skip(
|
||||
f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
|
||||
)
|
|
@ -8,7 +8,7 @@ from typing import List, Dict, Any
|
|||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
) # Adds the parent directory to the system-path
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue