add ai21 model test

This commit is contained in:
Ishaan Jaff 2024-09-02 12:14:13 -07:00
parent ae25c5695f
commit 263e283126
4 changed files with 24 additions and 2 deletions

View file

@ -363,6 +363,7 @@ vertex_llama3_models: List = []
vertex_ai_ai21_models: List = [] vertex_ai_ai21_models: List = []
vertex_mistral_models: List = [] vertex_mistral_models: List = []
ai21_models: List = [] ai21_models: List = []
ai21_chat_models: List = []
nlp_cloud_models: List = [] nlp_cloud_models: List = []
aleph_alpha_models: List = [] aleph_alpha_models: List = []
bedrock_models: List = [] bedrock_models: List = []
@ -415,7 +416,10 @@ for key, value in model_cost.items():
key = key.replace("vertex_ai/", "") key = key.replace("vertex_ai/", "")
vertex_ai_ai21_models.append(key) vertex_ai_ai21_models.append(key)
elif value.get("litellm_provider") == "ai21": elif value.get("litellm_provider") == "ai21":
ai21_models.append(key) if value.get("mode") == "chat":
ai21_chat_models.append(key)
else:
ai21_models.append(key)
elif value.get("litellm_provider") == "nlp_cloud": elif value.get("litellm_provider") == "nlp_cloud":
nlp_cloud_models.append(key) nlp_cloud_models.append(key)
elif value.get("litellm_provider") == "aleph_alpha": elif value.get("litellm_provider") == "aleph_alpha":
@ -644,6 +648,7 @@ model_list = (
+ vertex_chat_models + vertex_chat_models
+ vertex_text_models + vertex_text_models
+ ai21_models + ai21_models
+ ai21_chat_models
+ together_ai_models + together_ai_models
+ baseten_models + baseten_models
+ aleph_alpha_models + aleph_alpha_models

View file

@ -4481,7 +4481,7 @@ async def test_dynamic_azure_params(stream, sync_mode):
async def test_completion_ai21(): async def test_completion_ai21():
litellm.set_verbose = True litellm.set_verbose = True
response = await litellm.acompletion( response = await litellm.acompletion(
model="ai21_chat/jamba-1.5-large", model="jamba-1.5-large",
user="ishaan", user="ishaan",
tool_choice="auto", tool_choice="auto",
seed=123, seed=123,

View file

@ -68,3 +68,12 @@ def test_get_llm_provider_deepseek_custom_api_base():
assert api_base == "MY-FAKE-BASE" assert api_base == "MY-FAKE-BASE"
os.environ.pop("DEEPSEEK_API_BASE") os.environ.pop("DEEPSEEK_API_BASE")
def test_get_llm_provider_ai21_chat():
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
model="jamba-1.5-large",
)
assert custom_llm_provider == "ai21_chat"
assert model == "jamba-1.5-large"
assert api_base == "https://api.ai21.com/studio/v1"

View file

@ -4958,6 +4958,14 @@ def get_llm_provider(
## ai21 ## ai21
elif model in litellm.ai21_models: elif model in litellm.ai21_models:
custom_llm_provider = "ai21" custom_llm_provider = "ai21"
elif model in litellm.ai21_chat_models:
custom_llm_provider = "ai21_chat"
api_base = (
api_base
or get_secret("AI21_API_BASE")
or "https://api.ai21.com/studio/v1"
) # type: ignore
dynamic_api_key = api_key or get_secret("AI21_API_KEY")
## aleph_alpha ## aleph_alpha
elif model in litellm.aleph_alpha_models: elif model in litellm.aleph_alpha_models:
custom_llm_provider = "aleph_alpha" custom_llm_provider = "aleph_alpha"