mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Feat) - Add support for structured output on bedrock/nova
models + add util litellm.supports_tool_choice
(#8264)
* fix supports_tool_choice * TestBedrockNovaJson * use supports_tool_choice * fix supports_tool_choice * add supports_tool_choice param * script to add fields to model cost map * test_supports_tool_choice * test_supports_tool_choice * fix supports tool choice check * test_supports_tool_choice_simple_tests * fix supports_tool_choice check * fix supports_tool_choice bedrock * test_supports_tool_choice * test_supports_tool_choice * fix bedrock/eu-west-3/mistral.mistral-large-2402-v1:0 * ci/cd run again * test_supports_tool_choice_simple_tests * TestGoogleAIStudioGemini temp - remove to run ci/cd * test_aaalangfuse_logging_metadata * TestGoogleAIStudioGemini * test_check_provider_match * remove add param to map
This commit is contained in:
parent
c743475aba
commit
3a6349d871
10 changed files with 226 additions and 9 deletions
|
@ -227,3 +227,7 @@ class BaseConfig(ABC):
|
||||||
json_mode: Optional[bool] = False,
|
json_mode: Optional[bool] = False,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def custom_llm_provider(self) -> Optional[str]:
|
||||||
|
return None
|
||||||
|
|
|
@ -68,6 +68,10 @@ class AmazonConverseConfig(BaseConfig):
|
||||||
if key != "self" and value is not None:
|
if key != "self" and value is not None:
|
||||||
setattr(self.__class__, key, value)
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def custom_llm_provider(self) -> Optional[str]:
|
||||||
|
return "bedrock_converse"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls):
|
def get_config(cls):
|
||||||
return {
|
return {
|
||||||
|
@ -112,7 +116,9 @@ class AmazonConverseConfig(BaseConfig):
|
||||||
):
|
):
|
||||||
supported_params.append("tools")
|
supported_params.append("tools")
|
||||||
|
|
||||||
if base_model.startswith("anthropic") or base_model.startswith("mistral"):
|
if litellm.utils.supports_tool_choice(
|
||||||
|
model=model, custom_llm_provider=self.custom_llm_provider
|
||||||
|
):
|
||||||
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||||
supported_params.append("tool_choice")
|
supported_params.append("tool_choice")
|
||||||
|
|
||||||
|
@ -224,6 +230,9 @@ class AmazonConverseConfig(BaseConfig):
|
||||||
schema_name=schema_name if schema_name != "" else "json_tool_call",
|
schema_name=schema_name if schema_name != "" else "json_tool_call",
|
||||||
)
|
)
|
||||||
optional_params["tools"] = [_tool]
|
optional_params["tools"] = [_tool]
|
||||||
|
if litellm.utils.supports_tool_choice(
|
||||||
|
model=model, custom_llm_provider=self.custom_llm_provider
|
||||||
|
):
|
||||||
optional_params["tool_choice"] = ToolChoiceValuesBlock(
|
optional_params["tool_choice"] = ToolChoiceValuesBlock(
|
||||||
tool=SpecificToolChoiceBlock(
|
tool=SpecificToolChoiceBlock(
|
||||||
name=schema_name if schema_name != "" else "json_tool_call"
|
name=schema_name if schema_name != "" else "json_tool_call"
|
||||||
|
|
|
@ -5996,7 +5996,8 @@
|
||||||
"output_cost_per_token": 0.0000312,
|
"output_cost_per_token": 0.0000312,
|
||||||
"litellm_provider": "bedrock",
|
"litellm_provider": "bedrock",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_function_calling": true
|
"supports_function_calling": true,
|
||||||
|
"supports_tool_choice": true
|
||||||
},
|
},
|
||||||
"amazon.nova-micro-v1:0": {
|
"amazon.nova-micro-v1:0": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
|
|
|
@ -81,6 +81,7 @@ class ProviderSpecificModelInfo(TypedDict, total=False):
|
||||||
supports_response_schema: Optional[bool]
|
supports_response_schema: Optional[bool]
|
||||||
supports_vision: Optional[bool]
|
supports_vision: Optional[bool]
|
||||||
supports_function_calling: Optional[bool]
|
supports_function_calling: Optional[bool]
|
||||||
|
supports_tool_choice: Optional[bool]
|
||||||
supports_assistant_prefill: Optional[bool]
|
supports_assistant_prefill: Optional[bool]
|
||||||
supports_prompt_caching: Optional[bool]
|
supports_prompt_caching: Optional[bool]
|
||||||
supports_audio_input: Optional[bool]
|
supports_audio_input: Optional[bool]
|
||||||
|
|
|
@ -1938,6 +1938,15 @@ def supports_function_calling(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def supports_tool_choice(model: str, custom_llm_provider: Optional[str] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the given model supports `tool_choice` and return a boolean value.
|
||||||
|
"""
|
||||||
|
return _supports_factory(
|
||||||
|
model=model, custom_llm_provider=custom_llm_provider, key="supports_tool_choice"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str) -> bool:
|
def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the given model supports function calling and return a boolean value.
|
Check if the given model supports function calling and return a boolean value.
|
||||||
|
@ -4018,7 +4027,7 @@ def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str])
|
||||||
"litellm_provider"
|
"litellm_provider"
|
||||||
].startswith("fireworks_ai"):
|
].startswith("fireworks_ai"):
|
||||||
return True
|
return True
|
||||||
elif custom_llm_provider == "bedrock" and model_info[
|
elif custom_llm_provider.startswith("bedrock") and model_info[
|
||||||
"litellm_provider"
|
"litellm_provider"
|
||||||
].startswith("bedrock"):
|
].startswith("bedrock"):
|
||||||
return True
|
return True
|
||||||
|
@ -4189,6 +4198,7 @@ def _get_model_info_helper( # noqa: PLR0915
|
||||||
supports_system_messages=None,
|
supports_system_messages=None,
|
||||||
supports_response_schema=None,
|
supports_response_schema=None,
|
||||||
supports_function_calling=None,
|
supports_function_calling=None,
|
||||||
|
supports_tool_choice=None,
|
||||||
supports_assistant_prefill=None,
|
supports_assistant_prefill=None,
|
||||||
supports_prompt_caching=None,
|
supports_prompt_caching=None,
|
||||||
supports_pdf_input=None,
|
supports_pdf_input=None,
|
||||||
|
@ -4333,6 +4343,7 @@ def _get_model_info_helper( # noqa: PLR0915
|
||||||
supports_function_calling=_model_info.get(
|
supports_function_calling=_model_info.get(
|
||||||
"supports_function_calling", False
|
"supports_function_calling", False
|
||||||
),
|
),
|
||||||
|
supports_tool_choice=_model_info.get("supports_tool_choice", False),
|
||||||
supports_assistant_prefill=_model_info.get(
|
supports_assistant_prefill=_model_info.get(
|
||||||
"supports_assistant_prefill", False
|
"supports_assistant_prefill", False
|
||||||
),
|
),
|
||||||
|
@ -4411,6 +4422,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
|
||||||
supports_response_schema: Optional[bool]
|
supports_response_schema: Optional[bool]
|
||||||
supports_vision: Optional[bool]
|
supports_vision: Optional[bool]
|
||||||
supports_function_calling: Optional[bool]
|
supports_function_calling: Optional[bool]
|
||||||
|
supports_tool_choice: Optional[bool]
|
||||||
supports_prompt_caching: Optional[bool]
|
supports_prompt_caching: Optional[bool]
|
||||||
supports_audio_input: Optional[bool]
|
supports_audio_input: Optional[bool]
|
||||||
supports_audio_output: Optional[bool]
|
supports_audio_output: Optional[bool]
|
||||||
|
|
|
@ -5996,7 +5996,8 @@
|
||||||
"output_cost_per_token": 0.0000312,
|
"output_cost_per_token": 0.0000312,
|
||||||
"litellm_provider": "bedrock",
|
"litellm_provider": "bedrock",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_function_calling": true
|
"supports_function_calling": true,
|
||||||
|
"supports_tool_choice": true
|
||||||
},
|
},
|
||||||
"amazon.nova-micro-v1:0": {
|
"amazon.nova-micro-v1:0": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
|
|
160
tests/litellm_utils_tests/test_supports_tool_choice.py
Normal file
160
tests/litellm_utils_tests/test_supports_tool_choice.py
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Add parent directory to system path
|
||||||
|
sys.path.insert(0, os.path.abspath("../.."))
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.utils import get_llm_provider, ProviderConfigManager, _check_provider_match
|
||||||
|
from litellm import LlmProviders
|
||||||
|
|
||||||
|
|
||||||
|
def test_supports_tool_choice_simple_tests():
|
||||||
|
"""
|
||||||
|
simple sanity checks
|
||||||
|
"""
|
||||||
|
assert litellm.utils.supports_tool_choice(model="gpt-4o") == True
|
||||||
|
assert (
|
||||||
|
litellm.utils.supports_tool_choice(
|
||||||
|
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
|
||||||
|
)
|
||||||
|
== True
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
litellm.utils.supports_tool_choice(
|
||||||
|
model="anthropic.claude-3-sonnet-20240229-v1:0"
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
litellm.utils.supports_tool_choice(
|
||||||
|
model="anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
custom_llm_provider="bedrock_converse",
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
litellm.utils.supports_tool_choice(model="us.amazon.nova-micro-v1:0") is False
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
litellm.utils.supports_tool_choice(model="bedrock/us.amazon.nova-micro-v1:0")
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
litellm.utils.supports_tool_choice(
|
||||||
|
model="us.amazon.nova-micro-v1:0", custom_llm_provider="bedrock_converse"
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_provider_match():
|
||||||
|
"""
|
||||||
|
Test the _check_provider_match function for various provider scenarios
|
||||||
|
"""
|
||||||
|
# Test bedrock and bedrock_converse cases
|
||||||
|
model_info = {"litellm_provider": "bedrock"}
|
||||||
|
assert litellm.utils._check_provider_match(model_info, "bedrock") is True
|
||||||
|
assert litellm.utils._check_provider_match(model_info, "bedrock_converse") is True
|
||||||
|
|
||||||
|
# Test bedrock_converse provider
|
||||||
|
model_info = {"litellm_provider": "bedrock_converse"}
|
||||||
|
assert litellm.utils._check_provider_match(model_info, "bedrock") is True
|
||||||
|
assert litellm.utils._check_provider_match(model_info, "bedrock_converse") is True
|
||||||
|
|
||||||
|
# Test non-matching provider
|
||||||
|
model_info = {"litellm_provider": "bedrock"}
|
||||||
|
assert litellm.utils._check_provider_match(model_info, "openai") is False
|
||||||
|
|
||||||
|
|
||||||
|
# Models that should be skipped during testing
|
||||||
|
OLD_PROVIDERS = ["aleph_alpha", "palm"]
|
||||||
|
SKIP_MODELS = ["azure/mistral", "azure/command-r", "jamba", "deepinfra", "mistral."]
|
||||||
|
|
||||||
|
# Bedrock models to block - organized by type
|
||||||
|
BEDROCK_REGIONS = ["ap-northeast-1", "eu-central-1", "us-east-1", "us-west-2"]
|
||||||
|
BEDROCK_COMMITMENTS = ["1-month-commitment", "6-month-commitment"]
|
||||||
|
BEDROCK_MODELS = {
|
||||||
|
"anthropic.claude-v1",
|
||||||
|
"anthropic.claude-v2",
|
||||||
|
"anthropic.claude-v2:1",
|
||||||
|
"anthropic.claude-instant-v1",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate block_list dynamically
|
||||||
|
block_list = set()
|
||||||
|
for region in BEDROCK_REGIONS:
|
||||||
|
for commitment in BEDROCK_COMMITMENTS:
|
||||||
|
for model in BEDROCK_MODELS:
|
||||||
|
block_list.add(f"bedrock/{region}/{commitment}/{model}")
|
||||||
|
block_list.add(f"bedrock/{region}/{model}")
|
||||||
|
|
||||||
|
# Add Cohere models
|
||||||
|
for commitment in BEDROCK_COMMITMENTS:
|
||||||
|
block_list.add(f"bedrock/*/{commitment}/cohere.command-text-v14")
|
||||||
|
block_list.add(f"bedrock/*/{commitment}/cohere.command-light-text-v14")
|
||||||
|
|
||||||
|
print("block_list", block_list)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_supports_tool_choice():
|
||||||
|
"""
|
||||||
|
Test that litellm.utils.supports_tool_choice() returns the correct value
|
||||||
|
for all models in model_prices_and_context_window.json.
|
||||||
|
|
||||||
|
The test:
|
||||||
|
1. Loads model pricing data
|
||||||
|
2. Iterates through each model
|
||||||
|
3. Checks if tool_choice support matches the model's supported parameters
|
||||||
|
"""
|
||||||
|
# Load model prices
|
||||||
|
litellm._turn_on_debug()
|
||||||
|
with open("./model_prices_and_context_window.json", "r") as f:
|
||||||
|
model_prices = json.load(f)
|
||||||
|
litellm.model_cost = model_prices
|
||||||
|
config_manager = ProviderConfigManager()
|
||||||
|
|
||||||
|
for model_name, model_info in model_prices.items():
|
||||||
|
print(f"testing model: {model_name}")
|
||||||
|
|
||||||
|
# Skip certain models
|
||||||
|
if (
|
||||||
|
model_name == "sample_spec"
|
||||||
|
or model_info.get("mode") != "chat"
|
||||||
|
or any(skip in model_name for skip in SKIP_MODELS)
|
||||||
|
or any(provider in model_name for provider in OLD_PROVIDERS)
|
||||||
|
or model_info["litellm_provider"] in OLD_PROVIDERS
|
||||||
|
or model_name in block_list
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
model, provider, _, _ = get_llm_provider(model=model_name)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\033[91mERROR for {model_name}: {e}\033[0m")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get provider config and supported params
|
||||||
|
print("LLM provider", provider)
|
||||||
|
provider_enum = LlmProviders(provider)
|
||||||
|
config = config_manager.get_provider_chat_config(model, provider_enum)
|
||||||
|
supported_params = config.get_supported_openai_params(model)
|
||||||
|
print("supported_params", supported_params)
|
||||||
|
|
||||||
|
# Check tool_choice support
|
||||||
|
supports_tool_choice_result = litellm.utils.supports_tool_choice(
|
||||||
|
model=model_name, custom_llm_provider=provider
|
||||||
|
)
|
||||||
|
tool_choice_in_params = "tool_choice" in supported_params
|
||||||
|
|
||||||
|
assert supports_tool_choice_result == tool_choice_in_params, (
|
||||||
|
f"Tool choice support mismatch for {model_name}:\n"
|
||||||
|
f"supports_tool_choice() returned: {supports_tool_choice_result}\n"
|
||||||
|
f"tool_choice in supported params: {tool_choice_in_params}"
|
||||||
|
)
|
|
@ -768,6 +768,7 @@ def test_bedrock_system_prompt(system, model):
|
||||||
def test_bedrock_claude_3_tool_calling():
|
def test_bedrock_claude_3_tool_calling():
|
||||||
try:
|
try:
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
litellm._turn_on_debug()
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
|
|
28
tests/llm_translation/test_bedrock_nova_json.py
Normal file
28
tests/llm_translation/test_bedrock_nova_json.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
from base_llm_unit_tests import BaseLLMChatTest
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
|
class TestBedrockNovaJson(BaseLLMChatTest):
|
||||||
|
def get_base_completion_call_args(self) -> dict:
|
||||||
|
litellm._turn_on_debug()
|
||||||
|
return {
|
||||||
|
"model": "bedrock/converse/us.amazon.nova-micro-v1:0",
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||||
|
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def skip_non_json_tests(self, request):
|
||||||
|
if not "json" in request.function.__name__.lower():
|
||||||
|
pytest.skip(
|
||||||
|
f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
|
||||||
|
)
|
|
@ -8,7 +8,7 @@ from typing import List, Dict, Any
|
||||||
|
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system-path
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue