Fix calling claude via invoke route + response_format support for claude on invoke route (#8908)

* fix(anthropic_claude3_transformation.py): fix amazon anthropic claude 3 tool calling transformation on invoke route

move to using anthropic config as base

* fix(utils.py): expose anthropic config via providerconfigmanager

* fix(llm_http_handler.py): support json mode on async completion calls

* fix(invoke_handler/make_call): support json mode for anthropic called via bedrock invoke

* fix(anthropic/): handle 'response_format: {"type": "text"}' + migrate amazon claude 3 invoke config to inherit from anthropic config

Prevents an error when passing in 'response_format: {"type": "text"}'

* test: fix test

* fix(utils.py): fix base invoke provider check

* fix(anthropic_claude3_transformation.py): don't pass 'stream' param

* fix: fix linting errors

* fix(converse_transformation.py): handle response_format type=text for converse
Krish Dholakia 2025-02-28 17:56:26 -08:00 committed by GitHub
parent 8f86959c32
commit a65bfab697
18 changed files with 444 additions and 139 deletions
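
For context, this is the call pattern the fix targets. A minimal sketch, assuming AWS credentials are configured in the environment; the model string matches the proxy config updated in this commit:

```python
import litellm

# Claude on the Bedrock *invoke* route. With this change,
# response_format={"type": "text"} is treated as a no-op instead of erroring,
# and json_mode is threaded through for streaming/async invoke calls.
response = litellm.completion(
    model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "hi"}],
    response_format={"type": "text"},
)
print(response.choices[0].message.content)
```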

View file

@ -53,6 +53,7 @@ from litellm.constants import (
cohere_embedding_models,
bedrock_embedding_models,
known_tokenizer_config,
BEDROCK_INVOKE_PROVIDERS_LITERAL,
)
from litellm.types.guardrails import GuardrailItem
from litellm.proxy._types import (
@ -361,17 +362,7 @@ BEDROCK_CONVERSE_MODELS = [
"meta.llama3-2-11b-instruct-v1:0",
"meta.llama3-2-90b-instruct-v1:0",
]
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
"cohere",
"anthropic",
"mistral",
"amazon",
"meta",
"llama",
"ai21",
"nova",
"deepseek_r1",
]
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []

View file

@ -1,4 +1,4 @@
from typing import List
from typing import List, Literal
ROUTER_MAX_FALLBACKS = 5
DEFAULT_BATCH_SIZE = 512
@ -320,6 +320,17 @@ baseten_models: List = [
"31dxrj3",
] # FALCON 7B # WizardLM # Mosaic ML
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
"cohere",
"anthropic",
"mistral",
"amazon",
"meta",
"llama",
"ai21",
"nova",
"deepseek_r1",
]
open_ai_embedding_models: List = ["text-embedding-ada-002"]
cohere_embedding_models: List = [
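
Moving `BEDROCK_INVOKE_PROVIDERS_LITERAL` into `litellm.constants` lets lower-level modules (e.g. the base AWS handler below) validate providers without importing the top-level `litellm` package. A small illustration:

```python
from typing import get_args

from litellm.constants import BEDROCK_INVOKE_PROVIDERS_LITERAL

# Runtime membership check against the Literal's allowed values.
assert "anthropic" in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL)
assert "openai" not in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL)
```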

View file

@ -308,7 +308,6 @@ class AnthropicConfig(BaseConfig):
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
@ -342,6 +341,10 @@ class AnthropicConfig(BaseConfig):
optional_params["top_p"] = value
if param == "response_format" and isinstance(value, dict):
ignore_response_format_types = ["text"]
if value["type"] in ignore_response_format_types: # value is a no-op
continue
json_schema: Optional[dict] = None
if "response_schema" in value:
json_schema = value["response_schema"]
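
The net effect: `{"type": "text"}` is skipped instead of being translated into a forced tool call. A rough check of the mapping (the model string here is illustrative):

```python
from litellm.llms.anthropic.chat.transformation import AnthropicConfig

params = AnthropicConfig().map_openai_params(
    non_default_params={"response_format": {"type": "text"}},
    optional_params={},
    model="claude-3-5-sonnet-20240620",
    drop_params=False,
)
# No tools / tool_choice should be injected for the "text" type.
assert "tools" not in params and "tool_choice" not in params
```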

View file

@ -317,6 +317,7 @@ class BaseConfig(ABC):
data: dict,
messages: list,
client: Optional[AsyncHTTPHandler] = None,
json_mode: Optional[bool] = None,
) -> CustomStreamWrapper:
raise NotImplementedError
@ -330,6 +331,7 @@ class BaseConfig(ABC):
data: dict,
messages: list,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
json_mode: Optional[bool] = None,
) -> CustomStreamWrapper:
raise NotImplementedError

View file

@ -2,13 +2,14 @@ import hashlib
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast, get_args
import httpx
from pydantic import BaseModel
from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache
from litellm.constants import BEDROCK_INVOKE_PROVIDERS_LITERAL
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.secret_managers.main import get_secret
@ -223,6 +224,60 @@ class BaseAWSLLM:
# Catch any unexpected errors and return None
return None
@staticmethod
def _get_provider_from_model_path(
model_path: str,
) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
"""
Helper function to get the provider from a model path with format: provider/model-name
Args:
model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
Returns:
Optional[str]: The provider name, or None if no valid provider found
"""
parts = model_path.split("/")
if len(parts) >= 1:
provider = parts[0]
if provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
return None
@staticmethod
def get_bedrock_invoke_provider(
model: str,
) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
"""
Helper function to get the bedrock provider from the model
handles 4 scenarios:
1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
"""
if model.startswith("invoke/"):
model = model.replace("invoke/", "", 1)
_split_model = model.split(".")[0]
if _split_model in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
# If not a known provider, check for pattern with two slashes
provider = BaseAWSLLM._get_provider_from_model_path(model)
if provider is not None:
return provider
# check if provider == "nova"
if "nova" in model:
return "nova"
else:
for provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
if provider in model:
return provider
return None
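
The docstring scenarios above translate roughly to the following (import path assumed to be `litellm.llms.bedrock.base_aws_llm`):

```python
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM  # assumed path

assert BaseAWSLLM.get_bedrock_invoke_provider(
    "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"
) == "anthropic"
assert BaseAWSLLM.get_bedrock_invoke_provider(
    "anthropic.claude-3-5-sonnet-20240620-v1:0"
) == "anthropic"
assert BaseAWSLLM.get_bedrock_invoke_provider(
    "llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
) == "llama"
assert BaseAWSLLM.get_bedrock_invoke_provider("us.amazon.nova-pro-v1:0") == "nova"
```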
def _get_aws_region_name(
self, optional_params: dict, model: Optional[str] = None
) -> str:

View file

@ -206,7 +206,12 @@ class AmazonConverseConfig(BaseConfig):
messages: Optional[List[AllMessageValues]] = None,
) -> dict:
for param, value in non_default_params.items():
if param == "response_format":
if param == "response_format" and isinstance(value, dict):
ignore_response_format_types = ["text"]
if value["type"] in ignore_response_format_types: # value is a no-op
continue
json_schema: Optional[dict] = None
schema_name: str = ""
if "response_schema" in value:

View file

@ -226,6 +226,7 @@ async def make_call(
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
model=model,
sync_stream=False,
json_mode=json_mode,
)
completion_stream = decoder.aiter_bytes(
response.aiter_bytes(chunk_size=1024)
@ -311,6 +312,7 @@ def make_sync_call(
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
model=model,
sync_stream=True,
json_mode=json_mode,
)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
elif bedrock_invoke_provider == "deepseek_r1":
@ -1149,27 +1151,6 @@ class BedrockLLM(BaseAWSLLM):
)
return streaming_response
@staticmethod
def get_bedrock_invoke_provider(
model: str,
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
"""
Helper function to get the bedrock provider from the model
handles 2 scenarios:
1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
"""
_split_model = model.split(".")[0]
if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
# If not a known provider, check for pattern with two slashes
provider = BedrockLLM._get_provider_from_model_path(model)
if provider is not None:
return provider
return None
@staticmethod
def _get_provider_from_model_path(
model_path: str,
@ -1524,6 +1505,7 @@ class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
self,
model: str,
sync_stream: bool,
json_mode: Optional[bool] = None,
) -> None:
"""
Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models
@ -1534,6 +1516,7 @@ class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
streaming_response=None,
sync_stream=sync_stream,
json_mode=json_mode,
)
def _chunk_parser(self, chunk_data: dict) -> ModelResponseStream:
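
With `json_mode` now reaching the Anthropic stream decoder, structured output plus streaming should work on the invoke route. A hedged usage sketch, assuming AWS credentials in the environment and litellm's pydantic `response_format` support; the `Answer` model is made up:

```python
import litellm
from pydantic import BaseModel


class Answer(BaseModel):  # hypothetical schema for illustration
    answer: str


response = litellm.completion(
    model="bedrock/invoke/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    messages=[{"role": "user", "content": "Reply as JSON with an 'answer' field."}],
    response_format=Answer,
    stream=True,
)
for chunk in response:
    print(chunk)
```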

View file

@ -3,8 +3,10 @@ from typing import Optional
import litellm
from .base_invoke_transformation import AmazonInvokeConfig
class AmazonAnthropicConfig:
class AmazonAnthropicConfig(AmazonInvokeConfig):
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
@ -57,9 +59,7 @@ class AmazonAnthropicConfig:
and v is not None
}
def get_supported_openai_params(
self,
):
def get_supported_openai_params(self, model: str):
return [
"max_tokens",
"max_completion_tokens",
@ -69,7 +69,13 @@ class AmazonAnthropicConfig:
"stream",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
):
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens_to_sample"] = value

View file

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, List, Optional
import httpx
import litellm
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
@ -17,7 +17,7 @@ else:
LiteLLMLoggingObj = Any
class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
class AmazonAnthropicClaude3Config(AmazonInvokeConfig, AnthropicConfig):
"""
Reference:
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
@ -28,18 +28,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
anthropic_version: str = "bedrock-2023-05-31"
def get_supported_openai_params(self, model: str):
return [
"max_tokens",
"max_completion_tokens",
"tools",
"tool_choice",
"stream",
"stop",
"temperature",
"top_p",
"extra_headers",
]
def get_supported_openai_params(self, model: str) -> List[str]:
return AnthropicConfig.get_supported_openai_params(self, model)
def map_openai_params(
self,
@ -47,21 +37,14 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
optional_params: dict,
model: str,
drop_params: bool,
):
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
return optional_params
) -> dict:
return AnthropicConfig.map_openai_params(
self,
non_default_params,
optional_params,
model,
drop_params,
)
def transform_request(
self,
@ -71,7 +54,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
litellm_params: dict,
headers: dict,
) -> dict:
_anthropic_request = litellm.AnthropicConfig().transform_request(
_anthropic_request = AnthropicConfig.transform_request(
self,
model=model,
messages=messages,
optional_params=optional_params,
@ -80,6 +64,7 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
)
_anthropic_request.pop("model", None)
_anthropic_request.pop("stream", None)
if "anthropic_version" not in _anthropic_request:
_anthropic_request["anthropic_version"] = self.anthropic_version
@ -99,7 +84,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
return litellm.AnthropicConfig().transform_response(
return AnthropicConfig.transform_response(
self,
model=model,
raw_response=raw_response,
model_response=model_response,
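
Since `AmazonAnthropicClaude3Config` now inherits from both `AmazonInvokeConfig` and `AnthropicConfig`, the shared methods are dispatched explicitly (`AnthropicConfig.transform_request(self, ...)`) rather than via `super()`, keeping the invoke-specific signing/URL logic first in the MRO while reusing the Anthropic request/response translation. A toy sketch of the pattern:

```python
class InvokeBase:
    def transform(self) -> str:
        return "invoke"


class AnthropicBase:
    def transform(self) -> str:
        return "anthropic"


class Claude3(InvokeBase, AnthropicBase):
    def transform(self) -> str:
        # Explicitly reuse the Anthropic behaviour even though InvokeBase is
        # first in the MRO, mirroring AnthropicConfig.transform_request(self, ...).
        return AnthropicBase.transform(self)


assert Claude3().transform() == "anthropic"
```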

View file

@ -3,7 +3,7 @@ import json
import time
import urllib.parse
from functools import partial
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
import httpx
@ -461,6 +461,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
data: dict,
messages: list,
client: Optional[AsyncHTTPHandler] = None,
json_mode: Optional[bool] = None,
) -> CustomStreamWrapper:
streaming_response = CustomStreamWrapper(
completion_stream=None,
@ -475,6 +476,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
logging_obj=logging_obj,
fake_stream=True if "ai21" in api_base else False,
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
json_mode=json_mode,
),
model=model,
custom_llm_provider="bedrock",
@ -493,6 +495,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
data: dict,
messages: list,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
json_mode: Optional[bool] = None,
) -> CustomStreamWrapper:
if client is None or isinstance(client, AsyncHTTPHandler):
client = _get_httpx_client(params={})
@ -509,6 +512,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
logging_obj=logging_obj,
fake_stream=True if "ai21" in api_base else False,
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
json_mode=json_mode,
),
model=model,
custom_llm_provider="bedrock",
@ -527,56 +531,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
"""
return False
@staticmethod
def get_bedrock_invoke_provider(
model: str,
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
"""
Helper function to get the bedrock provider from the model
handles 4 scenarios:
1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
"""
if model.startswith("invoke/"):
model = model.replace("invoke/", "", 1)
_split_model = model.split(".")[0]
if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
# If not a known provider, check for pattern with two slashes
provider = AmazonInvokeConfig._get_provider_from_model_path(model)
if provider is not None:
return provider
# check if provider == "nova"
if "nova" in model:
return "nova"
return None
@staticmethod
def _get_provider_from_model_path(
model_path: str,
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
"""
Helper function to get the provider from a model path with format: provider/model-name
Args:
model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
Returns:
Optional[str]: The provider name, or None if no valid provider found
"""
parts = model_path.split("/")
if len(parts) >= 1:
provider = parts[0]
if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
return None
def get_bedrock_model_id(
self,
optional_params: dict,

View file

@ -159,6 +159,7 @@ class BaseLLMHTTPHandler:
encoding: Any,
api_key: Optional[str] = None,
client: Optional[AsyncHTTPHandler] = None,
json_mode: bool = False,
):
if client is None:
async_httpx_client = get_async_httpx_client(
@ -190,6 +191,7 @@ class BaseLLMHTTPHandler:
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
json_mode=json_mode,
)
def completion(
@ -211,6 +213,7 @@ class BaseLLMHTTPHandler:
headers: Optional[dict] = {},
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
json_mode: bool = optional_params.pop("json_mode", False)
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
@ -286,6 +289,7 @@ class BaseLLMHTTPHandler:
else None
),
litellm_params=litellm_params,
json_mode=json_mode,
)
else:
@ -309,6 +313,7 @@ class BaseLLMHTTPHandler:
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
json_mode=json_mode,
)
if stream is True:
@ -327,6 +332,7 @@ class BaseLLMHTTPHandler:
data=data,
messages=messages,
client=client,
json_mode=json_mode,
)
completion_stream, headers = self.make_sync_call(
provider_config=provider_config,
@ -380,6 +386,7 @@ class BaseLLMHTTPHandler:
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
json_mode=json_mode,
)
def make_sync_call(
@ -453,6 +460,7 @@ class BaseLLMHTTPHandler:
litellm_params: dict,
fake_stream: bool = False,
client: Optional[AsyncHTTPHandler] = None,
json_mode: Optional[bool] = None,
):
if provider_config.has_custom_stream_wrapper is True:
return provider_config.get_async_custom_stream_wrapper(
@ -464,6 +472,7 @@ class BaseLLMHTTPHandler:
data=data,
messages=messages,
client=client,
json_mode=json_mode,
)
completion_stream, _response_headers = await self.make_async_call_stream_helper(
@ -720,7 +729,7 @@ class BaseLLMHTTPHandler:
api_base: Optional[str] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,

File diff suppressed because one or more lines are too long

View file

@ -16,7 +16,10 @@ model_list:
api_key: os.environ/COHERE_API_KEY
- model_name: bedrock-claude-3-7
litellm_params:
model: bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0
model: bedrock/invoke/us.anthropic.claude-3-7-sonnet-20250219-v1:0
- model_name: bedrock-claude-3-5-sonnet
litellm_params:
model: bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0
litellm_settings:
callbacks: ["langfuse"]

View file

@ -3240,6 +3240,7 @@ def get_optional_params( # noqa: PLR0915
)
elif custom_llm_provider == "bedrock":
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
bedrock_base_model = BedrockModelInfo.get_base_model(model)
if bedrock_route == "converse" or bedrock_route == "converse_like":
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
@ -3253,8 +3254,9 @@ def get_optional_params( # noqa: PLR0915
messages=messages,
)
elif "anthropic" in model and bedrock_route == "invoke":
if model.startswith("anthropic.claude-3"):
elif "anthropic" in bedrock_base_model and bedrock_route == "invoke":
if bedrock_base_model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
@ -3267,10 +3269,17 @@ def get_optional_params( # noqa: PLR0915
),
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif provider_config is not None:
optional_params = provider_config.map_openai_params(
@ -6158,13 +6167,19 @@ class ProviderConfigManager:
elif litellm.LlmProviders.BEDROCK == provider:
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(
model
model=model
)
base_model = BedrockModelInfo.get_base_model(model)
if bedrock_route == "converse" or bedrock_route == "converse_like":
return litellm.AmazonConverseConfig()
elif bedrock_invoke_provider == "amazon": # amazon titan llms
return litellm.AmazonTitanConfig()
elif bedrock_invoke_provider == "anthropic":
if base_model.startswith("anthropic.claude-3"):
return litellm.AmazonAnthropicClaude3Config()
else:
return litellm.AmazonAnthropicConfig()
elif (
bedrock_invoke_provider == "meta" or bedrock_invoke_provider == "llama"
): # amazon / meta llms
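
Config resolution is now keyed off the base model, so region/inference-profile prefixes still land on the Claude 3 invoke config. Roughly (model string taken from the new test below):

```python
import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    model="invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
    provider=litellm.LlmProviders.BEDROCK,
)
# Claude 3.x on the invoke route resolves to the Claude 3 invoke config.
assert isinstance(config, litellm.AmazonAnthropicClaude3Config)
```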

View file

@ -18,6 +18,7 @@ from litellm.utils import (
CustomStreamWrapper,
get_supported_openai_params,
get_optional_params,
ProviderConfigManager,
)
from typing import Union
@ -247,12 +248,45 @@ class BaseLLMChatTest(ABC):
response_format=response_format,
)
print(response)
print(f"response={response}")
# OpenAI guarantees that the JSON schema is returned in the content
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
assert response.choices[0].message.content is not None
def test_response_format_type_text(self):
"""
Test that the response format type text does not lead to tool calls
"""
from litellm import LlmProviders
base_completion_call_args = self.get_base_completion_call_args()
litellm.set_verbose = True
_, provider, _, _ = litellm.get_llm_provider(
model=base_completion_call_args["model"]
)
provider_config = ProviderConfigManager.get_provider_chat_config(
base_completion_call_args["model"], LlmProviders(provider)
)
print(f"provider_config={provider_config}")
translated_params = provider_config.map_openai_params(
non_default_params={"response_format": {"type": "text"}},
optional_params={},
model=base_completion_call_args["model"],
drop_params=False,
)
assert "tool_choice" not in translated_params
assert (
"tools" not in translated_params
), f"Got tools={translated_params['tools']}, expected no tools"
print(f"translated_params={translated_params}")
@pytest.mark.flaky(retries=6, delay=1)
def test_json_response_pydantic_obj(self):
litellm.set_verbose = True

View file

@ -22,17 +22,9 @@ class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
pass
@pytest.fixture(autouse=True)
def skip_non_json_tests(self, request):
if not "json" in request.function.__name__.lower():
pytest.skip(
f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
)
class TestBedrockInvokeNovaJson(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
litellm._turn_on_debug()
return {
"model": "bedrock/invoke/us.amazon.nova-micro-v1:0",
}

View file

@ -319,12 +319,15 @@ def test_all_model_configs():
) == {"max_tokens": 10}
assert (
"max_completion_tokens" in AmazonAnthropicConfig().get_supported_openai_params()
"max_completion_tokens"
in AmazonAnthropicConfig().get_supported_openai_params(model="")
)
assert AmazonAnthropicConfig().map_openai_params(
non_default_params={"max_completion_tokens": 10},
optional_params={},
model="",
drop_params=False,
) == {"max_tokens_to_sample": 10}
from litellm.llms.databricks.chat.handler import DatabricksConfig

View file

@ -1114,3 +1114,257 @@ def test_anthropic_thinking_param(model, expected_thinking):
assert "thinking" in optional_params
else:
assert "thinking" not in optional_params
def test_bedrock_invoke_anthropic_max_tokens():
passed_params = {
"model": "invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
"functions": None,
"function_call": None,
"temperature": 0.8,
"top_p": None,
"n": 1,
"stream": False,
"stream_options": None,
"stop": None,
"max_tokens": None,
"max_completion_tokens": 1024,
"modalities": None,
"prediction": None,
"audio": None,
"presence_penalty": None,
"frequency_penalty": None,
"logit_bias": None,
"user": None,
"custom_llm_provider": "bedrock",
"response_format": {"type": "text"},
"seed": None,
"tools": [
{
"type": "function",
"function": {
"name": "generate_plan",
"description": "Generate a plan to execute the task using only the tools outlined in your context.",
"input_schema": {
"type": "object",
"properties": {
"steps": {
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {
"type": "string",
"description": "The type of step to execute",
},
"tool_name": {
"type": "string",
"description": "The name of the tool to use for this step",
},
"tool_input": {
"type": "object",
"description": "The input to pass to the tool. Make sure this complies with the schema for the tool.",
},
"tool_output": {
"type": "object",
"description": "(Optional) The output from the tool if needed for future steps. Make sure this complies with the schema for the tool.",
},
},
"required": ["type"],
},
}
},
},
},
},
{
"type": "function",
"function": {
"name": "generate_wire_tool",
"description": "Create a wire transfer with complete wire instructions",
"input_schema": {
"type": "object",
"properties": {
"company_id": {
"type": "integer",
"description": "The ID of the company receiving the investment",
},
"investment_id": {
"type": "integer",
"description": "The ID of the investment memo",
},
"dollar_amount": {
"type": "number",
"description": "The amount to wire in USD",
},
"wiring_instructions": {
"type": "object",
"description": "Complete bank account and routing information for the wire",
"properties": {
"account_name": {
"type": "string",
"description": "Name on the bank account",
},
"address_1": {
"type": "string",
"description": "Primary address line",
},
"address_2": {
"type": "string",
"description": "Secondary address line (optional)",
},
"city": {"type": "string"},
"state": {"type": "string"},
"zip": {"type": "string"},
"country": {"type": "string", "default": "US"},
"bank_name": {"type": "string"},
"account_number": {"type": "string"},
"routing_number": {"type": "string"},
"account_type": {
"type": "string",
"enum": ["checking", "savings"],
"default": "checking",
},
"swift_code": {
"type": "string",
"description": "Required for international wires",
},
"iban": {
"type": "string",
"description": "Required for some international wires",
},
"bank_city": {"type": "string"},
"bank_state": {"type": "string"},
"bank_country": {"type": "string", "default": "US"},
"bank_to_bank_instructions": {
"type": "string",
"description": "Additional instructions for the bank (optional)",
},
"intermediary_bank_name": {
"type": "string",
"description": "Name of intermediary bank if required (optional)",
},
},
"required": [
"account_name",
"address_1",
"country",
"bank_name",
"account_number",
"routing_number",
"account_type",
"bank_country",
],
},
},
"required": [
"company_id",
"investment_id",
"dollar_amount",
"wiring_instructions",
],
},
},
},
{
"type": "function",
"function": {
"name": "search_companies",
"description": "Search for companies by name or other criteria to get their IDs",
"input_schema": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Name or part of name to search for",
},
"batch": {
"type": "string",
"description": 'Optional batch filter (e.g., "W21", "S22")',
},
"status": {
"type": "string",
"enum": [
"live",
"dead",
"adrift",
"exited",
"went_public",
"all",
],
"description": "Filter by company status",
"default": "live",
},
"limit": {
"type": "integer",
"description": "Maximum number of results to return",
"default": 10,
},
},
"required": ["query"],
},
"output_schema": {
"type": "object",
"properties": {
"status": {
"type": "string",
"description": "Success or error status",
},
"results": {
"type": "array",
"description": "List of companies matching the search criteria",
"items": {
"type": "object",
"properties": {
"id": {
"type": "integer",
"description": "Company ID to use in other API calls",
},
"name": {"type": "string"},
"batch": {"type": "string"},
"status": {"type": "string"},
"valuation": {"type": "string"},
"url": {"type": "string"},
"description": {"type": "string"},
"founders": {"type": "string"},
},
},
},
"results_count": {
"type": "integer",
"description": "Number of companies returned",
},
"total_matches": {
"type": "integer",
"description": "Total number of matches found",
},
},
},
},
},
],
"tool_choice": None,
"max_retries": 0,
"logprobs": None,
"top_logprobs": None,
"extra_headers": None,
"api_version": None,
"parallel_tool_calls": None,
"drop_params": True,
"reasoning_effort": None,
"additional_drop_params": None,
"messages": [
{
"role": "system",
"content": "You are an AI assistant that helps prepare a wire for a pro rata investment.",
},
{"role": "user", "content": [{"type": "text", "text": "hi"}]},
],
"thinking": None,
"kwargs": {},
}
optional_params = get_optional_params(**passed_params)
print(f"optional_params: {optional_params}")
assert "max_tokens_to_sample" not in optional_params
assert optional_params["max_tokens"] == 1024