(Feat) - Add /bedrock/invoke support for all Anthropic models (#8383)

* use anthropic transformation for bedrock/invoke

* use anthropic transforms for bedrock invoke claude

* TestBedrockInvokeClaudeJson

* add AmazonAnthropicClaudeStreamDecoder

* pass bedrock_invoke_provider to make_call

* fix _get_base_bedrock_model

* fix get_bedrock_route

* fix bedrock routing

* fixes for bedrock invoke

* test_all_model_configs

* fix AWSEventStreamDecoder linting

* fix code qa

* test_bedrock_get_base_model

* test_get_model_info_bedrock_models

* test_bedrock_base_model_helper

* test_bedrock_route_detection
Ishaan Jaff, 2025-02-07 22:41:11 -08:00 (committed by GitHub)
parent 1dd3713f1a
commit b242c66a3b
15 changed files with 386 additions and 262 deletions
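
As a usage sketch (not part of the diff), the new route is exercised through the existing `litellm.completion` API; the model id below is the one used in the new unit test, and AWS credentials/region are assumed to come from the usual environment variables:

import litellm

# Route an Anthropic model through the Bedrock /invoke endpoint added by this commit
response = litellm.completion(
    model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "Hello from the invoke route"}],
)
print(response.choices[0].message.content)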


@@ -34,6 +34,17 @@ class BaseLLMModelInfo(ABC):
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         pass
 
+    @staticmethod
+    @abstractmethod
+    def get_base_model(model: str) -> Optional[str]:
+        """
+        Returns the base model name from the given model name.
+
+        Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
+        This function will return `anthropic.claude-3-opus-20240229-v1:0`
+        """
+        pass
+
 
 def _dict_to_response_format_helper(
     response_format: dict, ref_template: Optional[str] = None


@@ -33,14 +33,7 @@ from litellm.types.llms.openai import (
 from litellm.types.utils import ModelResponse, Usage
 from litellm.utils import add_dummy_tool, has_tool_call_blocks
 
-from ..common_utils import (
-    AmazonBedrockGlobalConfig,
-    BedrockError,
-    get_bedrock_tool_name,
-)
-
-global_config = AmazonBedrockGlobalConfig()
-all_global_regions = global_config.get_all_regions()
+from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name
 
 
 class AmazonConverseConfig(BaseConfig):
@@ -104,7 +97,7 @@ class AmazonConverseConfig(BaseConfig):
         ]
 
         ## Filter out 'cross-region' from model name
-        base_model = self._get_base_model(model)
+        base_model = BedrockModelInfo.get_base_model(model)
 
         if (
             base_model.startswith("anthropic")
@@ -341,9 +334,9 @@ class AmazonConverseConfig(BaseConfig):
         if "top_k" in inference_params:
             inference_params["topK"] = inference_params.pop("top_k")
         return InferenceConfig(**inference_params)
 
     def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
-        base_model = self._get_base_model(model)
+        base_model = BedrockModelInfo.get_base_model(model)
 
         val_top_k = None
         if "topK" in inference_params:
@@ -352,11 +345,11 @@ class AmazonConverseConfig(BaseConfig):
             val_top_k = inference_params.pop("top_k")
 
         if val_top_k:
-            if (base_model.startswith("anthropic")):
+            if base_model.startswith("anthropic"):
                 return {"top_k": val_top_k}
             if base_model.startswith("amazon.nova"):
-                return {'inferenceConfig': {"topK": val_top_k}}
+                return {"inferenceConfig": {"topK": val_top_k}}
 
         return {}
 
     def _transform_request_helper(
@@ -393,15 +386,25 @@ class AmazonConverseConfig(BaseConfig):
         ) + ["top_k"]
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
-        total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
+        total_supported_params = (
+            supported_converse_params
+            + supported_tool_call_params
+            + supported_guardrail_params
+        )
         inference_params.pop("json_mode", None)  # used for handling json_schema
 
         # keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
-        additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
-        inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}
+        additional_request_params = {
+            k: v for k, v in inference_params.items() if k not in total_supported_params
+        }
+        inference_params = {
+            k: v for k, v in inference_params.items() if k in total_supported_params
+        }
 
         # Only set the topK value in for models that support it
-        additional_request_params.update(self._handle_top_k_value(model, inference_params))
+        additional_request_params.update(
+            self._handle_top_k_value(model, inference_params)
+        )
 
         bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
             inference_params.pop("tools", [])
@@ -679,41 +682,6 @@ class AmazonConverseConfig(BaseConfig):
 
         return model_response
 
-    def _supported_cross_region_inference_region(self) -> List[str]:
-        """
-        Abbreviations of regions AWS Bedrock supports for cross region inference
-        """
-        return ["us", "eu", "apac"]
-
-    def _get_base_model(self, model: str) -> str:
-        """
-        Get the base model from the given model name.
-        Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
-        AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
-        """
-        if model.startswith("bedrock/"):
-            model = model.split("/", 1)[1]
-
-        if model.startswith("converse/"):
-            model = model.split("/", 1)[1]
-
-        potential_region = model.split(".", 1)[0]
-
-        alt_potential_region = model.split("/", 1)[
-            0
-        ]  # in model cost map we store regional information like `/us-west-2/bedrock-model`
-
-        if potential_region in self._supported_cross_region_inference_region():
-            return model.split(".", 1)[1]
-        elif (
-            alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
-        ):
-            return model.split("/", 1)[1]
-
-        return model
-
     def get_error_class(
         self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
     ) -> BaseLLMException:


@@ -40,6 +40,9 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     parse_xml_params,
     prompt_factory,
 )
+from litellm.llms.anthropic.chat.handler import (
+    ModelResponseIterator as AnthropicModelResponseIterator,
+)
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
@@ -177,6 +180,7 @@ async def make_call(
     logging_obj: Logging,
     fake_stream: bool = False,
     json_mode: Optional[bool] = False,
+    bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
 ):
     try:
         if client is None:
@@ -214,6 +218,14 @@ async def make_call(
             completion_stream: Any = MockResponseIterator(
                 model_response=model_response, json_mode=json_mode
             )
+        elif bedrock_invoke_provider == "anthropic":
+            decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
+                model=model,
+                sync_stream=False,
+            )
+            completion_stream = decoder.aiter_bytes(
+                response.aiter_bytes(chunk_size=1024)
+            )
         else:
             decoder = AWSEventStreamDecoder(model=model)
             completion_stream = decoder.aiter_bytes(
@@ -248,6 +260,7 @@ def make_sync_call(
     logging_obj: Logging,
     fake_stream: bool = False,
     json_mode: Optional[bool] = False,
+    bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
 ):
     try:
         if client is None:
@@ -283,6 +296,12 @@ def make_sync_call(
             completion_stream: Any = MockResponseIterator(
                 model_response=model_response, json_mode=json_mode
             )
+        elif bedrock_invoke_provider == "anthropic":
+            decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
+                model=model,
+                sync_stream=True,
+            )
+            completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
         else:
             decoder = AWSEventStreamDecoder(model=model)
             completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
@@ -1323,7 +1342,7 @@ class AWSEventStreamDecoder:
             text = chunk_data.get("completions")[0].get("data").get("text")  # type: ignore
             is_finished = True
             finish_reason = "stop"
-        ######## bedrock.anthropic mappings ###############
+        ######## converse bedrock.anthropic mappings ###############
         elif (
             "contentBlockIndex" in chunk_data
             or "stopReason" in chunk_data
@@ -1429,6 +1448,27 @@ class AWSEventStreamDecoder:
             return chunk.decode()  # type: ignore[no-any-return]
 
 
+class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
+    def __init__(
+        self,
+        model: str,
+        sync_stream: bool,
+    ) -> None:
+        """
+        Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models
+
+        The only difference between AWSEventStreamDecoder and AmazonAnthropicClaudeStreamDecoder is the `chunk_parser` method
+        """
+        super().__init__(model=model)
+        self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
+            streaming_response=None,
+            sync_stream=sync_stream,
+        )
+
+    def _chunk_parser(self, chunk_data: dict) -> GChunk:
+        return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data)
+
+
 class MockResponseIterator:  # for returning ai21 streaming responses
     def __init__(self, model_response, json_mode: Optional[bool] = False):
         self.model_response = model_response
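
A hedged sketch of the decoder dispatch added above: when the invoke provider is `anthropic`, the new `AmazonAnthropicClaudeStreamDecoder` (which defers chunk parsing to the Anthropic `ModelResponseIterator`) is chosen instead of the generic `AWSEventStreamDecoder`. The import path below is inferred from the file this hunk touches and may differ:

# Illustrative only; mirrors the branch inside make_call / make_sync_call.
from litellm.llms.bedrock.chat.invoke_handler import (  # assumed module path
    AWSEventStreamDecoder,
    AmazonAnthropicClaudeStreamDecoder,
)

def pick_decoder(model: str, bedrock_invoke_provider=None, sync_stream: bool = True):
    if bedrock_invoke_provider == "anthropic":
        # Anthropic invoke responses are parsed with the Anthropic chunk parser
        return AmazonAnthropicClaudeStreamDecoder(model=model, sync_stream=sync_stream)
    # all other providers keep the generic AWS event-stream decoder
    return AWSEventStreamDecoder(model=model)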


@@ -1,61 +1,34 @@
-import types
-from typing import List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional
+
+import httpx
+
+import litellm
+from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
+    AmazonInvokeConfig,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
 
 
-class AmazonAnthropicClaude3Config:
+class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
     """
     Reference:
         https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
         https://docs.anthropic.com/claude/docs/models-overview#model-comparison
 
     Supported Params for the Amazon / Anthropic Claude 3 models:
-    - `max_tokens` Required (integer) max tokens. Default is 4096
-    - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
-    - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
-    - `temperature` Optional (float) The amount of randomness injected into the response
-    - `top_p` Optional (float) Use nucleus sampling.
-    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
-    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
     """
 
-    max_tokens: Optional[int] = 4096  # Opus, Sonnet, and Haiku default
-    anthropic_version: Optional[str] = "bedrock-2023-05-31"
-    system: Optional[str] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    top_k: Optional[int] = None
-    stop_sequences: Optional[List[str]] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-        anthropic_version: Optional[str] = None,
-    ) -> None:
-        locals_ = locals().copy()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
+    anthropic_version: str = "bedrock-2023-05-31"
+
+    def get_supported_openai_params(self, model: str):
         return [
             "max_tokens",
             "max_completion_tokens",
@@ -68,7 +41,13 @@ class AmazonAnthropicClaude3Config:
             "extra_headers",
         ]
 
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ):
         for param, value in non_default_params.items():
             if param == "max_tokens" or param == "max_completion_tokens":
                 optional_params["max_tokens"] = value
@@ -83,3 +62,53 @@ class AmazonAnthropicClaude3Config:
             if param == "top_p":
                 optional_params["top_p"] = value
         return optional_params
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        _anthropic_request = litellm.AnthropicConfig().transform_request(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            headers=headers,
+        )
+
+        _anthropic_request.pop("model", None)
+        if "anthropic_version" not in _anthropic_request:
+            _anthropic_request["anthropic_version"] = self.anthropic_version
+
+        return _anthropic_request
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        return litellm.AnthropicConfig().transform_response(
+            model=model,
+            raw_response=raw_response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            request_data=request_data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+            api_key=api_key,
+            json_mode=json_mode,
+        )
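
For intuition, a minimal sketch of what the new `transform_request` yields, assuming `litellm.AnthropicConfig` produces a standard Anthropic messages body: per the diff above, the `model` key is dropped and `anthropic_version` defaults to `bedrock-2023-05-31`; the remaining payload fields come from `AnthropicConfig` and are not asserted here:

import litellm

body = litellm.AmazonAnthropicClaude3Config().transform_request(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "ping"}],
    optional_params={"max_tokens": 32},
    litellm_params={},
    headers={},
)
# Per the diff: no "model" key, and anthropic_version injected when absent.
assert "model" not in body
assert body.get("anthropic_version") == "bedrock-2023-05-31"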


@@ -2,7 +2,6 @@ import copy
 import json
 import time
 import urllib.parse
-import uuid
 from functools import partial
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
 
@@ -13,11 +12,7 @@ from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
 from litellm.litellm_core_utils.prompt_templates.factory import (
     cohere_message_pt,
-    construct_tool_use_system_prompt,
-    contains_tag,
     custom_prompt,
-    extract_between_tags,
-    parse_xml_params,
     prompt_factory,
 )
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
@@ -194,7 +189,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
             for k, v in inference_params.items()
             if k not in self.aws_authentication_params
         }
-        json_schemas: dict = {}
         request_data: dict = {}
         if provider == "cohere":
             if model.startswith("cohere.command-r"):
@@ -223,57 +217,13 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
             )
             request_data = {"prompt": prompt, **inference_params}
         elif provider == "anthropic":
-            if model.startswith("anthropic.claude-3"):
-                # Separate system prompt from rest of message
-                system_prompt_idx: list[int] = []
-                system_messages: list[str] = []
-                for idx, message in enumerate(messages):
-                    if message["role"] == "system" and isinstance(
-                        message["content"], str
-                    ):
-                        system_messages.append(message["content"])
-                        system_prompt_idx.append(idx)
-                if len(system_prompt_idx) > 0:
-                    inference_params["system"] = "\n".join(system_messages)
-                    messages = [
-                        i for j, i in enumerate(messages) if j not in system_prompt_idx
-                    ]
-                # Format rest of message according to anthropic guidelines
-                messages = prompt_factory(
-                    model=model, messages=messages, custom_llm_provider="anthropic_xml"
-                )  # type: ignore
-                ## LOAD CONFIG
-                config = litellm.AmazonAnthropicClaude3Config.get_config()
-                for k, v in config.items():
-                    if (
-                        k not in inference_params
-                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
-                        inference_params[k] = v
-                ## Handle Tool Calling
-                if "tools" in inference_params:
-                    _is_function_call = True
-                    for tool in inference_params["tools"]:
-                        json_schemas[tool["function"]["name"]] = tool["function"].get(
-                            "parameters", None
-                        )
-                    tool_calling_system_prompt = construct_tool_use_system_prompt(
-                        tools=inference_params["tools"]
-                    )
-                    inference_params["system"] = (
-                        inference_params.get("system", "\n")
-                        + tool_calling_system_prompt
-                    )  # add the anthropic tool calling prompt to the system prompt
-                    inference_params.pop("tools")
-                request_data = {"messages": messages, **inference_params}
-            else:
-                ## LOAD CONFIG
-                config = litellm.AmazonAnthropicConfig.get_config()
-                for k, v in config.items():
-                    if (
-                        k not in inference_params
-                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
-                        inference_params[k] = v
-                request_data = {"prompt": prompt, **inference_params}
+            return litellm.AmazonAnthropicClaude3Config().transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                headers=headers,
+            )
         elif provider == "ai21":
             ## LOAD CONFIG
             config = litellm.AmazonAI21Config.get_config()
@@ -359,66 +309,19 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 completion_response["generations"][0]["finish_reason"]
             )
         elif provider == "anthropic":
-            if model.startswith("anthropic.claude-3"):
-                json_schemas: dict = {}
-                _is_function_call = False
-                ## Handle Tool Calling
-                if "tools" in optional_params:
-                    _is_function_call = True
-                    for tool in optional_params["tools"]:
-                        json_schemas[tool["function"]["name"]] = tool[
-                            "function"
-                        ].get("parameters", None)
-                outputText = completion_response.get("content")[0].get("text", None)
-                if outputText is not None and contains_tag(
-                    "invoke", outputText
-                ):  # OUTPUT PARSE FUNCTION CALL
-                    function_name = extract_between_tags("tool_name", outputText)[0]
-                    function_arguments_str = extract_between_tags(
-                        "invoke", outputText
-                    )[0].strip()
-                    function_arguments_str = (
-                        f"<invoke>{function_arguments_str}</invoke>"
-                    )
-                    function_arguments = parse_xml_params(
-                        function_arguments_str,
-                        json_schema=json_schemas.get(
-                            function_name, None
-                        ),  # check if we have a json schema for this function name)
-                    )
-                    _message = litellm.Message(
-                        tool_calls=[
-                            {
-                                "id": f"call_{uuid.uuid4()}",
-                                "type": "function",
-                                "function": {
-                                    "name": function_name,
-                                    "arguments": json.dumps(function_arguments),
-                                },
-                            }
-                        ],
-                        content=None,
-                    )
-                    model_response.choices[0].message = _message  # type: ignore
-                    model_response._hidden_params["original_response"] = (
-                        outputText  # allow user to access raw anthropic tool calling response
-                    )
-                model_response.choices[0].finish_reason = map_finish_reason(
-                    completion_response.get("stop_reason", "")
-                )
-                _usage = litellm.Usage(
-                    prompt_tokens=completion_response["usage"]["input_tokens"],
-                    completion_tokens=completion_response["usage"]["output_tokens"],
-                    total_tokens=completion_response["usage"]["input_tokens"]
-                    + completion_response["usage"]["output_tokens"],
-                )
-                setattr(model_response, "usage", _usage)
-            else:
-                outputText = completion_response["completion"]
-
-                model_response.choices[0].finish_reason = completion_response[
-                    "stop_reason"
-                ]
+            return litellm.AmazonAnthropicClaude3Config().transform_response(
+                model=model,
+                raw_response=raw_response,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                request_data=request_data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+                api_key=api_key,
+                json_mode=json_mode,
+            )
         elif provider == "ai21":
             outputText = (
                 completion_response.get("completions")[0].get("data").get("text")
@@ -536,6 +439,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 messages=messages,
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
+                bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
             ),
             model=model,
             custom_llm_provider="bedrock",
@@ -569,6 +473,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 messages=messages,
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
+                bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
             ),
             model=model,
             custom_llm_provider="bedrock",
@@ -594,10 +499,14 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         """
         Helper function to get the bedrock provider from the model
 
-        handles 2 scenarions:
-        1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
-        2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
+        handles 3 scenarions:
+        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
         """
+        if model.startswith("invoke/"):
+            model = model.replace("invoke/", "", 1)
+
         _split_model = model.split(".")[0]
         if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
             return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
@@ -640,9 +549,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         else:
             modelId = model
 
+        modelId = modelId.replace("invoke/", "", 1)
         if provider == "llama" and "llama/" in modelId:
             modelId = self._get_model_id_for_llama_like_model(modelId)
 
         return modelId
 
     def _get_aws_region_name(self, optional_params: dict) -> str:
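
The docstring above pins down the expected behaviour of `get_bedrock_invoke_provider`; a small sketch restating those three cases, with the import path taken from the new import added earlier in this commit:

from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)

cfg = AmazonInvokeConfig()
# Cases 1-3 from the docstring of get_bedrock_invoke_provider
assert (
    cfg.get_bedrock_invoke_provider("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0")
    == "anthropic"
)
assert (
    cfg.get_bedrock_invoke_provider("anthropic.claude-3-5-sonnet-20240620-v1:0")
    == "anthropic"
)
assert (
    cfg.get_bedrock_invoke_provider(
        "llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
    )
    == "llama"
)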


@@ -3,11 +3,12 @@ Common utilities used across bedrock chat/embedding/image generation
 """
 
 import os
-from typing import List, Optional, Union
+from typing import List, Literal, Optional, Union
 
 import httpx
 
 import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret
@@ -310,3 +311,68 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:
             response_tool_name
         ]
     return response_tool_name
+
+
+class BedrockModelInfo(BaseLLMModelInfo):
+
+    global_config = AmazonBedrockGlobalConfig()
+    all_global_regions = global_config.get_all_regions()
+
+    @staticmethod
+    def get_base_model(model: str) -> str:
+        """
+        Get the base model from the given model name.
+        Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+        AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+        """
+        if model.startswith("bedrock/"):
+            model = model.split("/", 1)[1]
+
+        if model.startswith("converse/"):
+            model = model.split("/", 1)[1]
+
+        if model.startswith("invoke/"):
+            model = model.split("/", 1)[1]
+
+        potential_region = model.split(".", 1)[0]
+
+        alt_potential_region = model.split("/", 1)[
+            0
+        ]  # in model cost map we store regional information like `/us-west-2/bedrock-model`
+
+        if (
+            potential_region
+            in BedrockModelInfo._supported_cross_region_inference_region()
+        ):
+            return model.split(".", 1)[1]
+        elif (
+            alt_potential_region in BedrockModelInfo.all_global_regions
+            and len(model.split("/", 1)) > 1
+        ):
+            return model.split("/", 1)[1]
+
+        return model
+
+    @staticmethod
+    def _supported_cross_region_inference_region() -> List[str]:
+        """
+        Abbreviations of regions AWS Bedrock supports for cross region inference
+        """
+        return ["us", "eu", "apac"]
+
+    @staticmethod
+    def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]:
+        """
+        Get the bedrock route for the given model.
+        """
+        base_model = BedrockModelInfo.get_base_model(model)
+        if "invoke/" in model:
+            return "invoke"
+        elif "converse_like" in model:
+            return "converse_like"
+        elif "converse/" in model:
+            return "converse"
+        elif base_model in litellm.bedrock_converse_models:
+            return "converse"
+        return "invoke"


@@ -344,6 +344,10 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
             or "https://api.openai.com/v1"
         )
 
+    @staticmethod
+    def get_base_model(model: str) -> str:
+        return model
+
     def get_model_response_iterator(
         self,
         streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],


@@ -29,3 +29,7 @@ class TopazModelInfo(BaseLLMModelInfo):
         return (
             api_base or get_secret_str("TOPAZ_API_BASE") or "https://api.topazlabs.com"
         )
+
+    @staticmethod
+    def get_base_model(model: str) -> str:
+        return model


@@ -68,6 +68,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     get_content_from_model_response,
 )
 from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.llms.bedrock.common_utils import BedrockModelInfo
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.realtime_api.main import _realtime_health_check
 from litellm.secret_managers.main import get_secret_str
@@ -2628,11 +2629,8 @@ def completion(  # type: ignore # noqa: PLR0915
                     aws_bedrock_client.meta.region_name
                 )
 
-            base_model = litellm.AmazonConverseConfig()._get_base_model(model)
-
-            if base_model in litellm.bedrock_converse_models or model.startswith(
-                "converse/"
-            ):
+            bedrock_route = BedrockModelInfo.get_bedrock_route(model)
+            if bedrock_route == "converse":
                 model = model.replace("converse/", "")
                 response = bedrock_converse_chat_completion.completion(
                     model=model,
@@ -2651,7 +2649,7 @@ def completion(  # type: ignore # noqa: PLR0915
                     client=client,
                     api_base=api_base,
                 )
-            elif "converse_like" in model:
+            elif bedrock_route == "converse_like":
                 model = model.replace("converse_like/", "")
                 response = base_llm_http_handler.completion(
                     model=model,


@@ -86,10 +86,10 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_s
 from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
     LiteLLMResponseObjectHandler,
     _handle_invalid_parallel_tool_calls,
+    _parse_content_for_reasoning,
     convert_to_model_response_object,
     convert_to_streaming_response,
     convert_to_streaming_response_async,
-    _parse_content_for_reasoning,
 )
 from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
 from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
@@ -111,6 +111,7 @@ from litellm.litellm_core_utils.token_counter import (
     calculate_img_tokens,
     get_modified_max_tokens,
 )
+from litellm.llms.bedrock.common_utils import BedrockModelInfo
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.router_utils.get_retry_from_policy import (
     get_num_retries_from_retry_policy,
@@ -3189,8 +3190,8 @@ def get_optional_params(  # noqa: PLR0915
             ),
         )
     elif custom_llm_provider == "bedrock":
-        base_model = litellm.AmazonConverseConfig()._get_base_model(model)
-        if base_model in litellm.bedrock_converse_models:
+        bedrock_route = BedrockModelInfo.get_bedrock_route(model)
+        if bedrock_route == "converse" or bedrock_route == "converse_like":
             optional_params = litellm.AmazonConverseConfig().map_openai_params(
                 model=model,
                 non_default_params=non_default_params,
@@ -3203,15 +3204,20 @@ def get_optional_params(  # noqa: PLR0915
                 messages=messages,
             )
-        elif "anthropic" in model:
-            if "aws_bedrock_client" in passed_params:  # deprecated boto3.invoke route.
-                if model.startswith("anthropic.claude-3"):
-                    optional_params = (
-                        litellm.AmazonAnthropicClaude3Config().map_openai_params(
-                            non_default_params=non_default_params,
-                            optional_params=optional_params,
-                        )
-                    )
+        elif "anthropic" in model and bedrock_route == "invoke":
+            if model.startswith("anthropic.claude-3"):
+                optional_params = (
+                    litellm.AmazonAnthropicClaude3Config().map_openai_params(
+                        non_default_params=non_default_params,
+                        optional_params=optional_params,
+                        model=model,
+                        drop_params=(
+                            drop_params
+                            if drop_params is not None and isinstance(drop_params, bool)
+                            else False
+                        ),
+                    )
+                )
             else:
                 optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                     non_default_params=non_default_params,
@@ -3972,8 +3978,16 @@ def _strip_stable_vertex_version(model_name) -> str:
     return re.sub(r"-\d+$", "", model_name)
 
 
-def _strip_bedrock_region(model_name) -> str:
-    return litellm.AmazonConverseConfig()._get_base_model(model_name)
+def _get_base_bedrock_model(model_name) -> str:
+    """
+    Get the base model from the given model name.
+    Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+    AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+    """
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
+
+    return BedrockModelInfo.get_base_model(model_name)
 
 
 def _strip_openai_finetune_model_name(model_name: str) -> str:
@@ -3994,8 +4008,8 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:
 
 def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
     if custom_llm_provider and custom_llm_provider == "bedrock":
-        strip_bedrock_region = _strip_bedrock_region(model_name=model)
-        return strip_bedrock_region
+        stripped_bedrock_model = _get_base_bedrock_model(model_name=model)
+        return stripped_bedrock_model
     elif custom_llm_provider and (
         custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
     ):
@@ -6066,24 +6080,23 @@ class ProviderConfigManager:
         elif litellm.LlmProviders.PETALS == provider:
             return litellm.PetalsConfig()
         elif litellm.LlmProviders.BEDROCK == provider:
-            base_model = litellm.AmazonConverseConfig()._get_base_model(model)
-            bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
-            if (
-                base_model in litellm.bedrock_converse_models
-                or "converse_like" in model
-            ):
+            bedrock_route = BedrockModelInfo.get_bedrock_route(model)
+            bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(
+                model
+            )
+            if bedrock_route == "converse" or bedrock_route == "converse_like":
                 return litellm.AmazonConverseConfig()
-            elif bedrock_provider == "amazon":  # amazon titan llms
+            elif bedrock_invoke_provider == "amazon":  # amazon titan llms
                 return litellm.AmazonTitanConfig()
             elif (
-                bedrock_provider == "meta" or bedrock_provider == "llama"
+                bedrock_invoke_provider == "meta" or bedrock_invoke_provider == "llama"
             ):  # amazon / meta llms
                 return litellm.AmazonLlamaConfig()
-            elif bedrock_provider == "ai21":  # ai21 llms
+            elif bedrock_invoke_provider == "ai21":  # ai21 llms
                 return litellm.AmazonAI21Config()
-            elif bedrock_provider == "cohere":  # cohere models on bedrock
+            elif bedrock_invoke_provider == "cohere":  # cohere models on bedrock
                 return litellm.AmazonCohereConfig()
-            elif bedrock_provider == "mistral":  # mistral models on bedrock
+            elif bedrock_invoke_provider == "mistral":  # mistral models on bedrock
                 return litellm.AmazonMistralConfig()
             else:
                 return litellm.AmazonInvokeConfig()
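
A condensed, illustrative restatement of the bedrock branch above (not repo code); it only re-expresses the dispatch the diff implements:

import litellm
from litellm.llms.bedrock.common_utils import BedrockModelInfo

def bedrock_config_for(model: str):
    # converse / converse_like models use the Converse config; everything else is
    # dispatched by the invoke provider parsed from the model name
    route = BedrockModelInfo.get_bedrock_route(model)
    invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
    if route in ("converse", "converse_like"):
        return litellm.AmazonConverseConfig()
    if invoke_provider == "amazon":
        return litellm.AmazonTitanConfig()
    if invoke_provider in ("meta", "llama"):
        return litellm.AmazonLlamaConfig()
    if invoke_provider == "ai21":
        return litellm.AmazonAI21Config()
    if invoke_provider == "cohere":
        return litellm.AmazonCohereConfig()
    if invoke_provider == "mistral":
        return litellm.AmazonMistralConfig()
    return litellm.AmazonInvokeConfig()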


@@ -1265,7 +1265,9 @@ def test_bedrock_cross_region_inference(model):
     ],
 )
 def test_bedrock_get_base_model(model, expected_base_model):
-    assert litellm.AmazonConverseConfig()._get_base_model(model) == expected_base_model
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
+
+    assert BedrockModelInfo.get_base_model(model) == expected_base_model
 
 
 from litellm.litellm_core_utils.prompt_templates.factory import (
@@ -1982,9 +1984,49 @@ def test_bedrock_mapped_converse_models():
 
 
 def test_bedrock_base_model_helper():
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
+
     model = "us.amazon.nova-pro-v1:0"
-    litellm.AmazonConverseConfig()._get_base_model(model)
-    assert model == "us.amazon.nova-pro-v1:0"
+    base_model = BedrockModelInfo.get_base_model(model)
+    assert base_model == "amazon.nova-pro-v1:0"
+
+    assert (
+        BedrockModelInfo.get_base_model(
+            "invoke/anthropic.claude-3-5-sonnet-20241022-v2:0"
+        )
+        == "anthropic.claude-3-5-sonnet-20241022-v2:0"
+    )
+
+
+@pytest.mark.parametrize(
+    "model,expected_route",
+    [
+        # Test explicit route prefixes
+        ("invoke/anthropic.claude-3-sonnet-20240229-v1:0", "invoke"),
+        ("converse/anthropic.claude-3-sonnet-20240229-v1:0", "converse"),
+        ("converse_like/anthropic.claude-3-sonnet-20240229-v1:0", "converse_like"),
+        # Test models in BEDROCK_CONVERSE_MODELS list
+        ("anthropic.claude-3-5-haiku-20241022-v1:0", "converse"),
+        ("anthropic.claude-v2", "converse"),
+        ("meta.llama3-70b-instruct-v1:0", "converse"),
+        ("mistral.mistral-large-2407-v1:0", "converse"),
+        # Test models with region prefixes
+        ("us.anthropic.claude-3-sonnet-20240229-v1:0", "converse"),
+        ("us.meta.llama3-70b-instruct-v1:0", "converse"),
+        # Test default case (should return "invoke")
+        ("amazon.titan-text-express-v1", "invoke"),
+        ("cohere.command-text-v14", "invoke"),
+        ("cohere.command-r-v1:0", "invoke"),
+    ],
+)
+def test_bedrock_route_detection(model, expected_route):
+    """Test all scenarios for BedrockModelInfo.get_bedrock_route"""
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
+
+    route = BedrockModelInfo.get_bedrock_route(model)
+    assert (
+        route == expected_route
+    ), f"Expected route '{expected_route}' for model '{model}', but got '{route}'"
 
 
 @pytest.mark.parametrize(


@@ -0,0 +1,28 @@
+from base_llm_unit_tests import BaseLLMChatTest
+import pytest
+import sys
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+
+
+class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        litellm._turn_on_debug()
+        return {
+            "model": "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
+        }
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pass
+
+    @pytest.fixture(autouse=True)
+    def skip_non_json_tests(self, request):
+        if not "json" in request.function.__name__.lower():
+            pytest.skip(
+                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
+            )


@@ -305,12 +305,16 @@ def test_all_model_configs():
     assert (
         "max_completion_tokens"
-        in AmazonAnthropicClaude3Config().get_supported_openai_params()
+        in AmazonAnthropicClaude3Config().get_supported_openai_params(
+            model="anthropic.claude-3-sonnet-20240229-v1:0"
+        )
     )
 
     assert AmazonAnthropicClaude3Config().map_openai_params(
         non_default_params={"max_completion_tokens": 10},
         optional_params={},
+        model="anthropic.claude-3-sonnet-20240229-v1:0",
+        drop_params=False,
     ) == {"max_tokens": 10}
 
     assert (


@@ -208,3 +208,11 @@ def test_nova_bedrock_converse():
     )
     assert custom_llm_provider == "bedrock"
     assert model == "amazon.nova-micro-v1:0"
+
+
+def test_bedrock_invoke_anthropic():
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
+    )
+    assert custom_llm_provider == "bedrock"
+    assert model == "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"


@@ -321,7 +321,7 @@ def test_get_model_info_bedrock_models():
     """
     Check for drift in base model info for bedrock models and regional model info for bedrock models.
     """
-    from litellm import AmazonConverseConfig
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
 
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
     litellm.model_cost = litellm.get_model_cost_map(url="")
@@ -337,7 +337,7 @@ def test_get_model_info_bedrock_models():
             if any(commitment in k for commitment in potential_commitments):
                 for commitment in potential_commitments:
                     k = k.replace(f"{commitment}/", "")
-            base_model = AmazonConverseConfig()._get_base_model(k)
+            base_model = BedrockModelInfo.get_base_model(k)
             base_model_info = litellm.model_cost[base_model]
             for base_model_key, base_model_value in base_model_info.items():
                 if base_model_key.startswith("supports_"):