diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py index 54be359897..4677a579ed 100644 --- a/litellm/llms/bedrock/common_utils.py +++ b/litellm/llms/bedrock/common_utils.py @@ -336,13 +336,7 @@ class BedrockModelInfo(BaseLLMModelInfo): return model @staticmethod - def get_base_model(model: str) -> str: - """ - Get the base model from the given model name. - - Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1" - AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1" - """ + def get_non_litellm_routing_model_name(model: str) -> str: if model.startswith("bedrock/"): model = model.split("/", 1)[1] @@ -352,6 +346,18 @@ class BedrockModelInfo(BaseLLMModelInfo): if model.startswith("invoke/"): model = model.split("/", 1)[1] + return model + + @staticmethod + def get_base_model(model: str) -> str: + """ + Get the base model from the given model name. + + Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1" + AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1" + """ + + model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model) model = BedrockModelInfo.extract_model_name_from_arn(model) potential_region = model.split(".", 1)[0] @@ -386,12 +392,16 @@ class BedrockModelInfo(BaseLLMModelInfo): Get the bedrock route for the given model. """ base_model = BedrockModelInfo.get_base_model(model) + alt_model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model) if "invoke/" in model: return "invoke" elif "converse_like" in model: return "converse_like" elif "converse/" in model: return "converse" - elif base_model in litellm.bedrock_converse_models: + elif ( + base_model in litellm.bedrock_converse_models + or alt_model in litellm.bedrock_converse_models + ): return "converse" return "invoke" diff --git a/tests/litellm/llms/bedrock/test_bedrock_common_utils.py b/tests/litellm/llms/bedrock/test_bedrock_common_utils.py new file mode 100644 index 0000000000..f66ed21cf7 --- /dev/null +++ b/tests/litellm/llms/bedrock/test_bedrock_common_utils.py @@ -0,0 +1,21 @@ +import json +import os +import sys + +import pytest +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../../..") +) # Adds the parent directory to the system path + + +from litellm.llms.bedrock.common_utils import BedrockModelInfo + + +def test_deepseek_cris(): + bedrock_model_info = BedrockModelInfo + bedrock_route = bedrock_model_info.get_bedrock_route( + model="bedrock/us.deepseek.r1-v1:0" + ) + assert bedrock_route == "converse"