mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(common_utils.py): handle cris only model
Fixes https://github.com/BerriAI/litellm/issues/9161#issuecomment-2734905153
This commit is contained in:
parent
c101fe9b5d
commit
9adad381b4
2 changed files with 39 additions and 8 deletions
|
@ -336,13 +336,7 @@ class BedrockModelInfo(BaseLLMModelInfo):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_base_model(model: str) -> str:
|
def get_non_litellm_routing_model_name(model: str) -> str:
|
||||||
"""
|
|
||||||
Get the base model from the given model name.
|
|
||||||
|
|
||||||
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
|
||||||
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
|
||||||
"""
|
|
||||||
if model.startswith("bedrock/"):
|
if model.startswith("bedrock/"):
|
||||||
model = model.split("/", 1)[1]
|
model = model.split("/", 1)[1]
|
||||||
|
|
||||||
|
@ -352,6 +346,18 @@ class BedrockModelInfo(BaseLLMModelInfo):
|
||||||
if model.startswith("invoke/"):
|
if model.startswith("invoke/"):
|
||||||
model = model.split("/", 1)[1]
|
model = model.split("/", 1)[1]
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_base_model(model: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the base model from the given model name.
|
||||||
|
|
||||||
|
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||||
|
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||||
|
"""
|
||||||
|
|
||||||
|
model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model)
|
||||||
model = BedrockModelInfo.extract_model_name_from_arn(model)
|
model = BedrockModelInfo.extract_model_name_from_arn(model)
|
||||||
|
|
||||||
potential_region = model.split(".", 1)[0]
|
potential_region = model.split(".", 1)[0]
|
||||||
|
@ -386,12 +392,16 @@ class BedrockModelInfo(BaseLLMModelInfo):
|
||||||
Get the bedrock route for the given model.
|
Get the bedrock route for the given model.
|
||||||
"""
|
"""
|
||||||
base_model = BedrockModelInfo.get_base_model(model)
|
base_model = BedrockModelInfo.get_base_model(model)
|
||||||
|
alt_model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model)
|
||||||
if "invoke/" in model:
|
if "invoke/" in model:
|
||||||
return "invoke"
|
return "invoke"
|
||||||
elif "converse_like" in model:
|
elif "converse_like" in model:
|
||||||
return "converse_like"
|
return "converse_like"
|
||||||
elif "converse/" in model:
|
elif "converse/" in model:
|
||||||
return "converse"
|
return "converse"
|
||||||
elif base_model in litellm.bedrock_converse_models:
|
elif (
|
||||||
|
base_model in litellm.bedrock_converse_models
|
||||||
|
or alt_model in litellm.bedrock_converse_models
|
||||||
|
):
|
||||||
return "converse"
|
return "converse"
|
||||||
return "invoke"
|
return "invoke"
|
||||||
|
|
21
tests/litellm/llms/bedrock/test_bedrock_common_utils.py
Normal file
21
tests/litellm/llms/bedrock/test_bedrock_common_utils.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
|
||||||
|
from litellm.llms.bedrock.common_utils import BedrockModelInfo
|
||||||
|
|
||||||
|
|
||||||
|
def test_deepseek_cris():
|
||||||
|
bedrock_model_info = BedrockModelInfo
|
||||||
|
bedrock_route = bedrock_model_info.get_bedrock_route(
|
||||||
|
model="bedrock/us.deepseek.r1-v1:0"
|
||||||
|
)
|
||||||
|
assert bedrock_route == "converse"
|
Loading…
Add table
Add a link
Reference in a new issue