fix(common_utils.py): handle CRIS-only (cross-region inference only) models

Fixes https://github.com/BerriAI/litellm/issues/9161#issuecomment-2734905153
This commit is contained in:
Krrish Dholakia 2025-03-18 23:35:43 -07:00
parent 5a327da78e
commit db3a65d52a
2 changed files with 39 additions and 8 deletions

View file

@ -336,13 +336,7 @@ class BedrockModelInfo(BaseLLMModelInfo):
return model return model
@staticmethod @staticmethod
def get_base_model(model: str) -> str: def get_non_litellm_routing_model_name(model: str) -> str:
"""
Get the base model from the given model name.
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
if model.startswith("bedrock/"): if model.startswith("bedrock/"):
model = model.split("/", 1)[1] model = model.split("/", 1)[1]
@ -352,6 +346,18 @@ class BedrockModelInfo(BaseLLMModelInfo):
if model.startswith("invoke/"): if model.startswith("invoke/"):
model = model.split("/", 1)[1] model = model.split("/", 1)[1]
return model
@staticmethod
def get_base_model(model: str) -> str:
"""
Get the base model from the given model name.
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model)
model = BedrockModelInfo.extract_model_name_from_arn(model) model = BedrockModelInfo.extract_model_name_from_arn(model)
potential_region = model.split(".", 1)[0] potential_region = model.split(".", 1)[0]
@ -386,12 +392,16 @@ class BedrockModelInfo(BaseLLMModelInfo):
Get the bedrock route for the given model. Get the bedrock route for the given model.
""" """
base_model = BedrockModelInfo.get_base_model(model) base_model = BedrockModelInfo.get_base_model(model)
alt_model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model)
if "invoke/" in model: if "invoke/" in model:
return "invoke" return "invoke"
elif "converse_like" in model: elif "converse_like" in model:
return "converse_like" return "converse_like"
elif "converse/" in model: elif "converse/" in model:
return "converse" return "converse"
elif base_model in litellm.bedrock_converse_models: elif (
base_model in litellm.bedrock_converse_models
or alt_model in litellm.bedrock_converse_models
):
return "converse" return "converse"
return "invoke" return "invoke"

View file

@ -0,0 +1,21 @@
import json
import os
import sys
import pytest
from fastapi.testclient import TestClient
# Make the `litellm` package importable when the test is run from its own
# directory. The relative path is resolved against the current working
# directory (not this file's location).
sys.path.insert(
0, os.path.abspath("../../../..")
) # Prepends the directory four levels above the CWD to the import search path
from litellm.llms.bedrock.common_utils import BedrockModelInfo
def test_deepseek_cris():
    """Check that a region-prefixed DeepSeek R1 model id routes to "converse"."""
    # The model id keeps both the "bedrock/" routing prefix and the "us."
    # prefix; the route lookup is expected to resolve it to the converse API.
    route = BedrockModelInfo.get_bedrock_route(model="bedrock/us.deepseek.r1-v1:0")
    assert route == "converse"