(Feat) - Add /bedrock/invoke support for all Anthropic models (#8383)

* use anthropic transformation for bedrock/invoke

* use anthropic transforms for bedrock invoke claude

* TestBedrockInvokeClaudeJson

* add AmazonAnthropicClaudeStreamDecoder

* pass bedrock_invoke_provider to make_call

* fix _get_base_bedrock_model

* fix get_bedrock_route

* fix bedrock routing

* fixes for bedrock invoke

* test_all_model_configs

* fix AWSEventStreamDecoder linting

* fix code qa

* test_bedrock_get_base_model

* test_get_model_info_bedrock_models

* test_bedrock_base_model_helper

* test_bedrock_route_detection
This commit is contained in:
Ishaan Jaff 2025-02-07 22:41:11 -08:00 committed by GitHub
parent 1dd3713f1a
commit b242c66a3b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 386 additions and 262 deletions

View file

@ -86,10 +86,10 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_s
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
LiteLLMResponseObjectHandler,
_handle_invalid_parallel_tool_calls,
_parse_content_for_reasoning,
convert_to_model_response_object,
convert_to_streaming_response,
convert_to_streaming_response_async,
_parse_content_for_reasoning,
)
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
@ -111,6 +111,7 @@ from litellm.litellm_core_utils.token_counter import (
calculate_img_tokens,
get_modified_max_tokens,
)
from litellm.llms.bedrock.common_utils import BedrockModelInfo
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.router_utils.get_retry_from_policy import (
get_num_retries_from_retry_policy,
@ -3189,8 +3190,8 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "bedrock":
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
if base_model in litellm.bedrock_converse_models:
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
if bedrock_route == "converse" or bedrock_route == "converse_like":
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
@ -3203,15 +3204,20 @@ def get_optional_params( # noqa: PLR0915
messages=messages,
)
elif "anthropic" in model:
if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif "anthropic" in model and bedrock_route == "invoke":
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
@ -3972,8 +3978,16 @@ def _strip_stable_vertex_version(model_name) -> str:
return re.sub(r"-\d+$", "", model_name)
def _strip_bedrock_region(model_name) -> str:
return litellm.AmazonConverseConfig()._get_base_model(model_name)
def _get_base_bedrock_model(model_name) -> str:
"""
Get the base model from the given model name.
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
from litellm.llms.bedrock.common_utils import BedrockModelInfo
return BedrockModelInfo.get_base_model(model_name)
def _strip_openai_finetune_model_name(model_name: str) -> str:
@ -3994,8 +4008,8 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:
def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
if custom_llm_provider and custom_llm_provider == "bedrock":
strip_bedrock_region = _strip_bedrock_region(model_name=model)
return strip_bedrock_region
stripped_bedrock_model = _get_base_bedrock_model(model_name=model)
return stripped_bedrock_model
elif custom_llm_provider and (
custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
):
@ -6066,24 +6080,23 @@ class ProviderConfigManager:
elif litellm.LlmProviders.PETALS == provider:
return litellm.PetalsConfig()
elif litellm.LlmProviders.BEDROCK == provider:
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
if (
base_model in litellm.bedrock_converse_models
or "converse_like" in model
):
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(
model
)
if bedrock_route == "converse" or bedrock_route == "converse_like":
return litellm.AmazonConverseConfig()
elif bedrock_provider == "amazon": # amazon titan llms
elif bedrock_invoke_provider == "amazon": # amazon titan llms
return litellm.AmazonTitanConfig()
elif (
bedrock_provider == "meta" or bedrock_provider == "llama"
bedrock_invoke_provider == "meta" or bedrock_invoke_provider == "llama"
): # amazon / meta llms
return litellm.AmazonLlamaConfig()
elif bedrock_provider == "ai21": # ai21 llms
elif bedrock_invoke_provider == "ai21": # ai21 llms
return litellm.AmazonAI21Config()
elif bedrock_provider == "cohere": # cohere models on bedrock
elif bedrock_invoke_provider == "cohere": # cohere models on bedrock
return litellm.AmazonCohereConfig()
elif bedrock_provider == "mistral": # mistral models on bedrock
elif bedrock_invoke_provider == "mistral": # mistral models on bedrock
return litellm.AmazonMistralConfig()
else:
return litellm.AmazonInvokeConfig()