(Feat) - Add /bedrock/invoke support for all Anthropic models (#8383)

* use anthropic transformation for bedrock/invoke

* use anthropic transforms for bedrock invoke claude

* TestBedrockInvokeClaudeJson

* add AmazonAnthropicClaudeStreamDecoder

* pass bedrock_invoke_provider to make_call

* fix _get_base_bedrock_model

* fix get_bedrock_route

* fix bedrock routing

* fixes for bedrock invoke

* test_all_model_configs

* fix AWSEventStreamDecoder linting

* fix code qa

* test_bedrock_get_base_model

* test_get_model_info_bedrock_models

* test_bedrock_base_model_helper

* test_bedrock_route_detection
Ishaan Jaff 2025-02-07 22:41:11 -08:00 committed by GitHub
parent 1dd3713f1a
commit b242c66a3b
15 changed files with 386 additions and 262 deletions
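Before the per-file diffs, a minimal usage sketch of what this change enables: calling an Anthropic Claude model over Bedrock's invoke API by prefixing the model with `invoke/`. The model ID is taken from the new test added in this commit; AWS credentials are assumed to be configured in the environment.

import litellm

# Route an Anthropic Claude model through Bedrock's /invoke API.
# The "bedrock/invoke/..." prefix is what this commit teaches the router
# to recognize; the Anthropic request/response transforms are reused.
response = litellm.completion(
    model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(response.choices[0].message.content)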

View file

@@ -34,6 +34,17 @@ class BaseLLMModelInfo(ABC):
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
pass
@staticmethod
@abstractmethod
def get_base_model(model: str) -> Optional[str]:
"""
Returns the base model name from the given model name.
Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
This function will return `anthropic.claude-3-opus-20240229-v1:0`
"""
pass
def _dict_to_response_format_helper(
response_format: dict, ref_template: Optional[str] = None

View file

@@ -33,14 +33,7 @@ from litellm.types.llms.openai import (
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import add_dummy_tool, has_tool_call_blocks
from ..common_utils import (
AmazonBedrockGlobalConfig,
BedrockError,
get_bedrock_tool_name,
)
global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()
from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name
class AmazonConverseConfig(BaseConfig):
@@ -104,7 +97,7 @@ class AmazonConverseConfig(BaseConfig):
]
## Filter out 'cross-region' from model name
base_model = self._get_base_model(model)
base_model = BedrockModelInfo.get_base_model(model)
if (
base_model.startswith("anthropic")
@@ -343,7 +336,7 @@ class AmazonConverseConfig(BaseConfig):
return InferenceConfig(**inference_params)
def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
base_model = self._get_base_model(model)
base_model = BedrockModelInfo.get_base_model(model)
val_top_k = None
if "topK" in inference_params:
@@ -352,10 +345,10 @@ class AmazonConverseConfig(BaseConfig):
val_top_k = inference_params.pop("top_k")
if val_top_k:
if (base_model.startswith("anthropic")):
if base_model.startswith("anthropic"):
return {"top_k": val_top_k}
if base_model.startswith("amazon.nova"):
return {'inferenceConfig': {"topK": val_top_k}}
return {"inferenceConfig": {"topK": val_top_k}}
return {}
@@ -393,15 +386,25 @@ class AmazonConverseConfig(BaseConfig):
) + ["top_k"]
supported_tool_call_params = ["tools", "tool_choice"]
supported_guardrail_params = ["guardrailConfig"]
total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
total_supported_params = (
supported_converse_params
+ supported_tool_call_params
+ supported_guardrail_params
)
inference_params.pop("json_mode", None) # used for handling json_schema
# keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}
additional_request_params = {
k: v for k, v in inference_params.items() if k not in total_supported_params
}
inference_params = {
k: v for k, v in inference_params.items() if k in total_supported_params
}
# Only set the topK value for models that support it
additional_request_params.update(self._handle_top_k_value(model, inference_params))
additional_request_params.update(
self._handle_top_k_value(model, inference_params)
)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", [])
@@ -679,41 +682,6 @@ class AmazonConverseConfig(BaseConfig):
return model_response
def _supported_cross_region_inference_region(self) -> List[str]:
"""
Abbreviations of regions AWS Bedrock supports for cross region inference
"""
return ["us", "eu", "apac"]
def _get_base_model(self, model: str) -> str:
"""
Get the base model from the given model name.
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
if model.startswith("bedrock/"):
model = model.split("/", 1)[1]
if model.startswith("converse/"):
model = model.split("/", 1)[1]
potential_region = model.split(".", 1)[0]
alt_potential_region = model.split("/", 1)[
0
] # in model cost map we store regional information like `/us-west-2/bedrock-model`
if potential_region in self._supported_cross_region_inference_region():
return model.split(".", 1)[1]
elif (
alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
):
return model.split("/", 1)[1]
return model
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:

View file

@@ -40,6 +40,9 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
parse_xml_params,
prompt_factory,
)
from litellm.llms.anthropic.chat.handler import (
ModelResponseIterator as AnthropicModelResponseIterator,
)
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@@ -177,6 +180,7 @@ async def make_call(
logging_obj: Logging,
fake_stream: bool = False,
json_mode: Optional[bool] = False,
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
):
try:
if client is None:
@@ -214,6 +218,14 @@ async def make_call(
completion_stream: Any = MockResponseIterator(
model_response=model_response, json_mode=json_mode
)
elif bedrock_invoke_provider == "anthropic":
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
model=model,
sync_stream=False,
)
completion_stream = decoder.aiter_bytes(
response.aiter_bytes(chunk_size=1024)
)
else:
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.aiter_bytes(
@@ -248,6 +260,7 @@ def make_sync_call(
logging_obj: Logging,
fake_stream: bool = False,
json_mode: Optional[bool] = False,
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
):
try:
if client is None:
@@ -283,6 +296,12 @@ def make_sync_call(
completion_stream: Any = MockResponseIterator(
model_response=model_response, json_mode=json_mode
)
elif bedrock_invoke_provider == "anthropic":
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
model=model,
sync_stream=True,
)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
else:
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
@@ -1323,7 +1342,7 @@ class AWSEventStreamDecoder:
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
is_finished = True
finish_reason = "stop"
######## bedrock.anthropic mappings ###############
######## converse bedrock.anthropic mappings ###############
elif (
"contentBlockIndex" in chunk_data
or "stopReason" in chunk_data
@@ -1429,6 +1448,27 @@ class AWSEventStreamDecoder:
return chunk.decode() # type: ignore[no-any-return]
class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
def __init__(
self,
model: str,
sync_stream: bool,
) -> None:
"""
Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models
The only difference between AWSEventStreamDecoder and AmazonAnthropicClaudeStreamDecoder is the `chunk_parser` method
"""
super().__init__(model=model)
self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
streaming_response=None,
sync_stream=sync_stream,
)
def _chunk_parser(self, chunk_data: dict) -> GChunk:
return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data)
class MockResponseIterator: # for returning ai21 streaming responses
def __init__(self, model_response, json_mode: Optional[bool] = False):
self.model_response = model_response
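To consolidate the streaming change earlier in this file into one readable path, here is a hedged sketch of the decoder selection added to make_call/make_sync_call. The import path and the helper name pick_invoke_stream_decoder are assumptions for illustration, not part of the diff.

from typing import Optional

from litellm.llms.bedrock.chat.invoke_handler import (  # module path assumed
    AWSEventStreamDecoder,
    AmazonAnthropicClaudeStreamDecoder,
)

def pick_invoke_stream_decoder(
    model: str,
    bedrock_invoke_provider: Optional[str],
    sync_stream: bool = True,
) -> AWSEventStreamDecoder:
    """Hypothetical helper mirroring the branch added above."""
    if bedrock_invoke_provider == "anthropic":
        # Claude invoke responses are parsed with Anthropic's chunk parser.
        return AmazonAnthropicClaudeStreamDecoder(model=model, sync_stream=sync_stream)
    # Every other invoke provider keeps the generic AWS event-stream decoder.
    return AWSEventStreamDecoder(model=model)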

View file

@@ -1,61 +1,34 @@
import types
from typing import List, Optional
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
import litellm
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonAnthropicClaude3Config:
class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
"""
Reference:
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
https://docs.anthropic.com/claude/docs/models-overview#model-comparison
Supported Params for the Amazon / Anthropic Claude 3 models:
- `max_tokens` Required (integer) max tokens. Default is 4096
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
"""
max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default
anthropic_version: Optional[str] = "bedrock-2023-05-31"
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
anthropic_version: str = "bedrock-2023-05-31"
def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
def get_supported_openai_params(self, model: str):
return [
"max_tokens",
"max_completion_tokens",
@@ -68,7 +41,13 @@ class AmazonAnthropicClaude3Config:
"extra_headers",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
):
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
@@ -83,3 +62,53 @@ class AmazonAnthropicClaude3Config:
if param == "top_p":
optional_params["top_p"] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
_anthropic_request = litellm.AnthropicConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
_anthropic_request.pop("model", None)
if "anthropic_version" not in _anthropic_request:
_anthropic_request["anthropic_version"] = self.anthropic_version
return _anthropic_request
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
return litellm.AnthropicConfig().transform_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
api_key=api_key,
json_mode=json_mode,
)
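A hedged sketch of what the delegation above produces for a simple request: an Anthropic-format body with `anthropic_version` defaulted and the top-level `model` key stripped (Bedrock carries the model in the URL, not the body). This assumes the shared Anthropic transform accepts these minimal params.

import litellm

config = litellm.AmazonAnthropicClaude3Config()
body = config.transform_request(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "ping"}],
    optional_params={"max_tokens": 64},
    litellm_params={},
    headers={},
)
# The invoke transform pops "model" and injects the Bedrock anthropic_version.
assert "model" not in body
assert body["anthropic_version"] == "bedrock-2023-05-31"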

View file

@@ -2,7 +2,6 @@ import copy
import json
import time
import urllib.parse
import uuid
from functools import partial
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
@@ -13,11 +12,7 @@ from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.litellm_core_utils.prompt_templates.factory import (
cohere_message_pt,
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
@@ -194,7 +189,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
for k, v in inference_params.items()
if k not in self.aws_authentication_params
}
json_schemas: dict = {}
request_data: dict = {}
if provider == "cohere":
if model.startswith("cohere.command-r"):
@@ -223,57 +217,13 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
)
request_data = {"prompt": prompt, **inference_params}
elif provider == "anthropic":
if model.startswith("anthropic.claude-3"):
# Separate system prompt from rest of message
system_prompt_idx: list[int] = []
system_messages: list[str] = []
for idx, message in enumerate(messages):
if message["role"] == "system" and isinstance(
message["content"], str
):
system_messages.append(message["content"])
system_prompt_idx.append(idx)
if len(system_prompt_idx) > 0:
inference_params["system"] = "\n".join(system_messages)
messages = [
i for j, i in enumerate(messages) if j not in system_prompt_idx
]
# Format rest of message according to anthropic guidelines
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic_xml"
) # type: ignore
## LOAD CONFIG
config = litellm.AmazonAnthropicClaude3Config.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
## Handle Tool Calling
if "tools" in inference_params:
_is_function_call = True
for tool in inference_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
return litellm.AmazonAnthropicClaude3Config().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=inference_params["tools"]
)
inference_params["system"] = (
inference_params.get("system", "\n")
+ tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
inference_params.pop("tools")
request_data = {"messages": messages, **inference_params}
else:
## LOAD CONFIG
config = litellm.AmazonAnthropicConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
request_data = {"prompt": prompt, **inference_params}
elif provider == "ai21":
## LOAD CONFIG
config = litellm.AmazonAI21Config.get_config()
@@ -359,66 +309,19 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
completion_response["generations"][0]["finish_reason"]
)
elif provider == "anthropic":
if model.startswith("anthropic.claude-3"):
json_schemas: dict = {}
_is_function_call = False
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
for tool in optional_params["tools"]:
json_schemas[tool["function"]["name"]] = tool[
"function"
].get("parameters", None)
outputText = completion_response.get("content")[0].get("text", None)
if outputText is not None and contains_tag(
"invoke", outputText
): # OUTPUT PARSE FUNCTION CALL
function_name = extract_between_tags("tool_name", outputText)[0]
function_arguments_str = extract_between_tags(
"invoke", outputText
)[0].strip()
function_arguments_str = (
f"<invoke>{function_arguments_str}</invoke>"
return litellm.AmazonAnthropicClaude3Config().transform_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
api_key=api_key,
json_mode=json_mode,
)
function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name)
)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
outputText # allow user to access raw anthropic tool calling response
)
model_response.choices[0].finish_reason = map_finish_reason(
completion_response.get("stop_reason", "")
)
_usage = litellm.Usage(
prompt_tokens=completion_response["usage"]["input_tokens"],
completion_tokens=completion_response["usage"]["output_tokens"],
total_tokens=completion_response["usage"]["input_tokens"]
+ completion_response["usage"]["output_tokens"],
)
setattr(model_response, "usage", _usage)
else:
outputText = completion_response["completion"]
model_response.choices[0].finish_reason = completion_response[
"stop_reason"
]
elif provider == "ai21":
outputText = (
completion_response.get("completions")[0].get("data").get("text")
@@ -536,6 +439,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
messages=messages,
logging_obj=logging_obj,
fake_stream=True if "ai21" in api_base else False,
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
),
model=model,
custom_llm_provider="bedrock",
@@ -569,6 +473,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
messages=messages,
logging_obj=logging_obj,
fake_stream=True if "ai21" in api_base else False,
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
),
model=model,
custom_llm_provider="bedrock",
@@ -594,10 +499,14 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
"""
Helper function to get the bedrock provider from the model
handles 2 scenarios:
1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
handles 3 scenarios:
1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
"""
if model.startswith("invoke/"):
model = model.replace("invoke/", "", 1)
_split_model = model.split(".")[0]
if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
@@ -640,9 +549,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
else:
modelId = model
modelId = modelId.replace("invoke/", "", 1)
if provider == "llama" and "llama/" in modelId:
modelId = self._get_model_id_for_llama_like_model(modelId)
return modelId
def _get_aws_region_name(self, optional_params: dict) -> str:
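The three detection scenarios from the docstring a few lines up, as a hedged sketch; it assumes get_bedrock_invoke_provider is reachable on an AmazonInvokeConfig instance, as in the handlers above.

import litellm

cfg = litellm.AmazonInvokeConfig()
# 1. explicit invoke/ prefix
assert cfg.get_bedrock_invoke_provider("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
# 2. bare model id
assert cfg.get_bedrock_invoke_provider("anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
# 3. imported-model ARN carrying a provider prefix
assert (
    cfg.get_bedrock_invoke_provider(
        "llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
    )
    == "llama"
)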

View file

@@ -3,11 +3,12 @@ Common utilities used across bedrock chat/embedding/image generation
"""
import os
from typing import List, Optional, Union
from typing import List, Literal, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret
@@ -310,3 +311,68 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:
response_tool_name
]
return response_tool_name
class BedrockModelInfo(BaseLLMModelInfo):
global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()
@staticmethod
def get_base_model(model: str) -> str:
"""
Get the base model from the given model name.
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
if model.startswith("bedrock/"):
model = model.split("/", 1)[1]
if model.startswith("converse/"):
model = model.split("/", 1)[1]
if model.startswith("invoke/"):
model = model.split("/", 1)[1]
potential_region = model.split(".", 1)[0]
alt_potential_region = model.split("/", 1)[
0
] # in model cost map we store regional information like `/us-west-2/bedrock-model`
if (
potential_region
in BedrockModelInfo._supported_cross_region_inference_region()
):
return model.split(".", 1)[1]
elif (
alt_potential_region in BedrockModelInfo.all_global_regions
and len(model.split("/", 1)) > 1
):
return model.split("/", 1)[1]
return model
@staticmethod
def _supported_cross_region_inference_region() -> List[str]:
"""
Abbreviations of regions AWS Bedrock supports for cross region inference
"""
return ["us", "eu", "apac"]
@staticmethod
def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]:
"""
Get the bedrock route for the given model.
"""
base_model = BedrockModelInfo.get_base_model(model)
if "invoke/" in model:
return "invoke"
elif "converse_like" in model:
return "converse_like"
elif "converse/" in model:
return "converse"
elif base_model in litellm.bedrock_converse_models:
return "converse"
return "invoke"

View file

@@ -344,6 +344,10 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
or "https://api.openai.com/v1"
)
@staticmethod
def get_base_model(model: str) -> str:
return model
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],

View file

@@ -29,3 +29,7 @@ class TopazModelInfo(BaseLLMModelInfo):
return (
api_base or get_secret_str("TOPAZ_API_BASE") or "https://api.topazlabs.com"
)
@staticmethod
def get_base_model(model: str) -> str:
return model

View file

@@ -68,6 +68,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_content_from_model_response,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.common_utils import BedrockModelInfo
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.realtime_api.main import _realtime_health_check
from litellm.secret_managers.main import get_secret_str
@@ -2628,11 +2629,8 @@ def completion( # type: ignore # noqa: PLR0915
aws_bedrock_client.meta.region_name
)
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
if base_model in litellm.bedrock_converse_models or model.startswith(
"converse/"
):
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
if bedrock_route == "converse":
model = model.replace("converse/", "")
response = bedrock_converse_chat_completion.completion(
model=model,
@@ -2651,7 +2649,7 @@ def completion( # type: ignore # noqa: PLR0915
client=client,
api_base=api_base,
)
elif "converse_like" in model:
elif bedrock_route == "converse_like":
model = model.replace("converse_like/", "")
response = base_llm_http_handler.completion(
model=model,

View file

@@ -86,10 +86,10 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_s
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
LiteLLMResponseObjectHandler,
_handle_invalid_parallel_tool_calls,
_parse_content_for_reasoning,
convert_to_model_response_object,
convert_to_streaming_response,
convert_to_streaming_response_async,
_parse_content_for_reasoning,
)
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
@@ -111,6 +111,7 @@ from litellm.litellm_core_utils.token_counter import (
calculate_img_tokens,
get_modified_max_tokens,
)
from litellm.llms.bedrock.common_utils import BedrockModelInfo
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.router_utils.get_retry_from_policy import (
get_num_retries_from_retry_policy,
@@ -3189,8 +3190,8 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "bedrock":
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
if base_model in litellm.bedrock_converse_models:
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
if bedrock_route == "converse" or bedrock_route == "converse_like":
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
@@ -3203,13 +3204,18 @@ def get_optional_params( # noqa: PLR0915
messages=messages,
)
elif "anthropic" in model:
if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
elif "anthropic" in model and bedrock_route == "invoke":
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
)
else:
@@ -3972,8 +3978,16 @@ def _strip_stable_vertex_version(model_name) -> str:
return re.sub(r"-\d+$", "", model_name)
def _strip_bedrock_region(model_name) -> str:
return litellm.AmazonConverseConfig()._get_base_model(model_name)
def _get_base_bedrock_model(model_name) -> str:
"""
Get the base model from the given model name.
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
from litellm.llms.bedrock.common_utils import BedrockModelInfo
return BedrockModelInfo.get_base_model(model_name)
def _strip_openai_finetune_model_name(model_name: str) -> str:
@@ -3994,8 +4008,8 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:
def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
if custom_llm_provider and custom_llm_provider == "bedrock":
strip_bedrock_region = _strip_bedrock_region(model_name=model)
return strip_bedrock_region
stripped_bedrock_model = _get_base_bedrock_model(model_name=model)
return stripped_bedrock_model
elif custom_llm_provider and (
custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
):
@@ -6066,24 +6080,23 @@ class ProviderConfigManager:
elif litellm.LlmProviders.PETALS == provider:
return litellm.PetalsConfig()
elif litellm.LlmProviders.BEDROCK == provider:
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
if (
base_model in litellm.bedrock_converse_models
or "converse_like" in model
):
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(
model
)
if bedrock_route == "converse" or bedrock_route == "converse_like":
return litellm.AmazonConverseConfig()
elif bedrock_provider == "amazon": # amazon titan llms
elif bedrock_invoke_provider == "amazon": # amazon titan llms
return litellm.AmazonTitanConfig()
elif (
bedrock_provider == "meta" or bedrock_provider == "llama"
bedrock_invoke_provider == "meta" or bedrock_invoke_provider == "llama"
): # amazon / meta llms
return litellm.AmazonLlamaConfig()
elif bedrock_provider == "ai21": # ai21 llms
elif bedrock_invoke_provider == "ai21": # ai21 llms
return litellm.AmazonAI21Config()
elif bedrock_provider == "cohere": # cohere models on bedrock
elif bedrock_invoke_provider == "cohere": # cohere models on bedrock
return litellm.AmazonCohereConfig()
elif bedrock_provider == "mistral": # mistral models on bedrock
elif bedrock_invoke_provider == "mistral": # mistral models on bedrock
return litellm.AmazonMistralConfig()
else:
return litellm.AmazonInvokeConfig()
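For reference, the dispatch above condensed into a hedged, standalone sketch. pick_bedrock_config is a hypothetical helper, not part of the diff; all config classes and helpers it uses appear in the hunk above.

import litellm
from litellm.llms.bedrock.common_utils import BedrockModelInfo

def pick_bedrock_config(model: str):
    """Hypothetical helper mirroring the ProviderConfigManager branch above."""
    route = BedrockModelInfo.get_bedrock_route(model)
    if route in ("converse", "converse_like"):
        return litellm.AmazonConverseConfig()
    provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
    invoke_configs = {
        "amazon": litellm.AmazonTitanConfig,   # amazon titan llms
        "meta": litellm.AmazonLlamaConfig,     # meta / llama llms
        "llama": litellm.AmazonLlamaConfig,
        "ai21": litellm.AmazonAI21Config,
        "cohere": litellm.AmazonCohereConfig,
        "mistral": litellm.AmazonMistralConfig,
    }
    config_cls = invoke_configs.get(provider, litellm.AmazonInvokeConfig)
    return config_cls()

# e.g. pick_bedrock_config("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0")
# falls back to AmazonInvokeConfig, while a Converse-supported model returns
# AmazonConverseConfig.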

View file

@@ -1265,7 +1265,9 @@ def test_bedrock_cross_region_inference(model):
],
)
def test_bedrock_get_base_model(model, expected_base_model):
assert litellm.AmazonConverseConfig()._get_base_model(model) == expected_base_model
from litellm.llms.bedrock.common_utils import BedrockModelInfo
assert BedrockModelInfo.get_base_model(model) == expected_base_model
from litellm.litellm_core_utils.prompt_templates.factory import (
@@ -1982,9 +1984,49 @@ def test_bedrock_mapped_converse_models():
def test_bedrock_base_model_helper():
from litellm.llms.bedrock.common_utils import BedrockModelInfo
model = "us.amazon.nova-pro-v1:0"
litellm.AmazonConverseConfig()._get_base_model(model)
assert model == "us.amazon.nova-pro-v1:0"
base_model = BedrockModelInfo.get_base_model(model)
assert base_model == "amazon.nova-pro-v1:0"
assert (
BedrockModelInfo.get_base_model(
"invoke/anthropic.claude-3-5-sonnet-20241022-v2:0"
)
== "anthropic.claude-3-5-sonnet-20241022-v2:0"
)
@pytest.mark.parametrize(
"model,expected_route",
[
# Test explicit route prefixes
("invoke/anthropic.claude-3-sonnet-20240229-v1:0", "invoke"),
("converse/anthropic.claude-3-sonnet-20240229-v1:0", "converse"),
("converse_like/anthropic.claude-3-sonnet-20240229-v1:0", "converse_like"),
# Test models in BEDROCK_CONVERSE_MODELS list
("anthropic.claude-3-5-haiku-20241022-v1:0", "converse"),
("anthropic.claude-v2", "converse"),
("meta.llama3-70b-instruct-v1:0", "converse"),
("mistral.mistral-large-2407-v1:0", "converse"),
# Test models with region prefixes
("us.anthropic.claude-3-sonnet-20240229-v1:0", "converse"),
("us.meta.llama3-70b-instruct-v1:0", "converse"),
# Test default case (should return "invoke")
("amazon.titan-text-express-v1", "invoke"),
("cohere.command-text-v14", "invoke"),
("cohere.command-r-v1:0", "invoke"),
],
)
def test_bedrock_route_detection(model, expected_route):
"""Test all scenarios for BedrockModelInfo.get_bedrock_route"""
from litellm.llms.bedrock.common_utils import BedrockModelInfo
route = BedrockModelInfo.get_bedrock_route(model)
assert (
route == expected_route
), f"Expected route '{expected_route}' for model '{model}', but got '{route}'"
@pytest.mark.parametrize(

View file

@@ -0,0 +1,28 @@
from base_llm_unit_tests import BaseLLMChatTest
import pytest
import sys
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
litellm._turn_on_debug()
return {
"model": "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
pass
@pytest.fixture(autouse=True)
def skip_non_json_tests(self, request):
if not "json" in request.function.__name__.lower():
pytest.skip(
f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
)

View file

@@ -305,12 +305,16 @@ def test_all_model_configs():
assert (
"max_completion_tokens"
in AmazonAnthropicClaude3Config().get_supported_openai_params()
in AmazonAnthropicClaude3Config().get_supported_openai_params(
model="anthropic.claude-3-sonnet-20240229-v1:0"
)
)
assert AmazonAnthropicClaude3Config().map_openai_params(
non_default_params={"max_completion_tokens": 10},
optional_params={},
model="anthropic.claude-3-sonnet-20240229-v1:0",
drop_params=False,
) == {"max_tokens": 10}
assert (

View file

@@ -208,3 +208,11 @@ def test_nova_bedrock_converse():
)
assert custom_llm_provider == "bedrock"
assert model == "amazon.nova-micro-v1:0"
def test_bedrock_invoke_anthropic():
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
)
assert custom_llm_provider == "bedrock"
assert model == "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"

View file

@@ -321,7 +321,7 @@ def test_get_model_info_bedrock_models():
"""
Check for drift in base model info for bedrock models and regional model info for bedrock models.
"""
from litellm import AmazonConverseConfig
from litellm.llms.bedrock.common_utils import BedrockModelInfo
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
@@ -337,7 +337,7 @@ def test_get_model_info_bedrock_models():
if any(commitment in k for commitment in potential_commitments):
for commitment in potential_commitments:
k = k.replace(f"{commitment}/", "")
base_model = AmazonConverseConfig()._get_base_model(k)
base_model = BedrockModelInfo.get_base_model(k)
base_model_info = litellm.model_cost[base_model]
for base_model_key, base_model_value in base_model_info.items():
if base_model_key.startswith("supports_"):