mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
add ai21 model test
This commit is contained in:
parent
ae25c5695f
commit
263e283126
4 changed files with 24 additions and 2 deletions
|
@ -363,6 +363,7 @@ vertex_llama3_models: List = []
|
||||||
vertex_ai_ai21_models: List = []
|
vertex_ai_ai21_models: List = []
|
||||||
vertex_mistral_models: List = []
|
vertex_mistral_models: List = []
|
||||||
ai21_models: List = []
|
ai21_models: List = []
|
||||||
|
ai21_chat_models: List = []
|
||||||
nlp_cloud_models: List = []
|
nlp_cloud_models: List = []
|
||||||
aleph_alpha_models: List = []
|
aleph_alpha_models: List = []
|
||||||
bedrock_models: List = []
|
bedrock_models: List = []
|
||||||
|
@ -415,7 +416,10 @@ for key, value in model_cost.items():
|
||||||
key = key.replace("vertex_ai/", "")
|
key = key.replace("vertex_ai/", "")
|
||||||
vertex_ai_ai21_models.append(key)
|
vertex_ai_ai21_models.append(key)
|
||||||
elif value.get("litellm_provider") == "ai21":
|
elif value.get("litellm_provider") == "ai21":
|
||||||
ai21_models.append(key)
|
if value.get("mode") == "chat":
|
||||||
|
ai21_chat_models.append(key)
|
||||||
|
else:
|
||||||
|
ai21_models.append(key)
|
||||||
elif value.get("litellm_provider") == "nlp_cloud":
|
elif value.get("litellm_provider") == "nlp_cloud":
|
||||||
nlp_cloud_models.append(key)
|
nlp_cloud_models.append(key)
|
||||||
elif value.get("litellm_provider") == "aleph_alpha":
|
elif value.get("litellm_provider") == "aleph_alpha":
|
||||||
|
@ -644,6 +648,7 @@ model_list = (
|
||||||
+ vertex_chat_models
|
+ vertex_chat_models
|
||||||
+ vertex_text_models
|
+ vertex_text_models
|
||||||
+ ai21_models
|
+ ai21_models
|
||||||
|
+ ai21_chat_models
|
||||||
+ together_ai_models
|
+ together_ai_models
|
||||||
+ baseten_models
|
+ baseten_models
|
||||||
+ aleph_alpha_models
|
+ aleph_alpha_models
|
||||||
|
|
|
@ -4481,7 +4481,7 @@ async def test_dynamic_azure_params(stream, sync_mode):
|
||||||
async def test_completion_ai21():
|
async def test_completion_ai21():
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
response = await litellm.acompletion(
|
response = await litellm.acompletion(
|
||||||
model="ai21_chat/jamba-1.5-large",
|
model="jamba-1.5-large",
|
||||||
user="ishaan",
|
user="ishaan",
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
seed=123,
|
seed=123,
|
||||||
|
|
|
@ -68,3 +68,12 @@ def test_get_llm_provider_deepseek_custom_api_base():
|
||||||
assert api_base == "MY-FAKE-BASE"
|
assert api_base == "MY-FAKE-BASE"
|
||||||
|
|
||||||
os.environ.pop("DEEPSEEK_API_BASE")
|
os.environ.pop("DEEPSEEK_API_BASE")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_llm_provider_ai21_chat():
|
||||||
|
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
||||||
|
model="jamba-1.5-large",
|
||||||
|
)
|
||||||
|
assert custom_llm_provider == "ai21_chat"
|
||||||
|
assert model == "jamba-1.5-large"
|
||||||
|
assert api_base == "https://api.ai21.com/studio/v1"
|
||||||
|
|
|
@ -4958,6 +4958,14 @@ def get_llm_provider(
|
||||||
## ai21
|
## ai21
|
||||||
elif model in litellm.ai21_models:
|
elif model in litellm.ai21_models:
|
||||||
custom_llm_provider = "ai21"
|
custom_llm_provider = "ai21"
|
||||||
|
elif model in litellm.ai21_chat_models:
|
||||||
|
custom_llm_provider = "ai21_chat"
|
||||||
|
api_base = (
|
||||||
|
api_base
|
||||||
|
or get_secret("AI21_API_BASE")
|
||||||
|
or "https://api.ai21.com/studio/v1"
|
||||||
|
) # type: ignore
|
||||||
|
dynamic_api_key = api_key or get_secret("AI21_API_KEY")
|
||||||
## aleph_alpha
|
## aleph_alpha
|
||||||
elif model in litellm.aleph_alpha_models:
|
elif model in litellm.aleph_alpha_models:
|
||||||
custom_llm_provider = "aleph_alpha"
|
custom_llm_provider = "aleph_alpha"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue