diff --git a/litellm/__init__.py b/litellm/__init__.py index 5e57d7678..6c3c44c21 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -363,6 +363,7 @@ vertex_llama3_models: List = [] vertex_ai_ai21_models: List = [] vertex_mistral_models: List = [] ai21_models: List = [] +ai21_chat_models: List = [] nlp_cloud_models: List = [] aleph_alpha_models: List = [] bedrock_models: List = [] @@ -415,7 +416,10 @@ for key, value in model_cost.items(): key = key.replace("vertex_ai/", "") vertex_ai_ai21_models.append(key) elif value.get("litellm_provider") == "ai21": - ai21_models.append(key) + if value.get("mode") == "chat": + ai21_chat_models.append(key) + else: + ai21_models.append(key) elif value.get("litellm_provider") == "nlp_cloud": nlp_cloud_models.append(key) elif value.get("litellm_provider") == "aleph_alpha": @@ -644,6 +648,7 @@ model_list = ( + vertex_chat_models + vertex_text_models + ai21_models + + ai21_chat_models + together_ai_models + baseten_models + aleph_alpha_models diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index b010bd602..f223a7e2b 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -4481,7 +4481,7 @@ async def test_dynamic_azure_params(stream, sync_mode): async def test_completion_ai21(): litellm.set_verbose = True response = await litellm.acompletion( - model="ai21_chat/jamba-1.5-large", + model="jamba-1.5-large", user="ishaan", tool_choice="auto", seed=123, diff --git a/litellm/tests/test_get_llm_provider.py b/litellm/tests/test_get_llm_provider.py index 5e1c1f4fe..921420f80 100644 --- a/litellm/tests/test_get_llm_provider.py +++ b/litellm/tests/test_get_llm_provider.py @@ -68,3 +68,12 @@ def test_get_llm_provider_deepseek_custom_api_base(): assert api_base == "MY-FAKE-BASE" os.environ.pop("DEEPSEEK_API_BASE") + + +def test_get_llm_provider_ai21_chat(): + model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( + model="jamba-1.5-large", + ) + assert custom_llm_provider == "ai21_chat" + assert model == "jamba-1.5-large" + assert api_base == "https://api.ai21.com/studio/v1" diff --git a/litellm/utils.py b/litellm/utils.py index 0c990ddea..85744c2df 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4958,6 +4958,14 @@ def get_llm_provider( ## ai21 elif model in litellm.ai21_models: custom_llm_provider = "ai21" + elif model in litellm.ai21_chat_models: + custom_llm_provider = "ai21_chat" + api_base = ( + api_base + or get_secret("AI21_API_BASE") + or "https://api.ai21.com/studio/v1" + ) # type: ignore + dynamic_api_key = api_key or get_secret("AI21_API_KEY") ## aleph_alpha elif model in litellm.aleph_alpha_models: custom_llm_provider = "aleph_alpha"