diff --git a/docs/my-website/docs/completion/supported.md b/docs/my-website/docs/completion/supported.md
index d0b10811d4..16e0cbe6b6 100644
--- a/docs/my-website/docs/completion/supported.md
+++ b/docs/my-website/docs/completion/supported.md
@@ -69,6 +69,14 @@ Here are some examples of supported models:
 | [google/flan-t5-xxl](https://huggingface.co/google/flan-t5-xxl) | `completion(model="google/flan-t5-xxl", messages=messages, custom_llm_provider="huggingface")` | `os.environ['HF_TOKEN']` |
 | [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) | `completion(model="google/flan-t5-large", messages=messages, custom_llm_provider="huggingface")` | `os.environ['HF_TOKEN']` |
+
+### AI21 Models
+
+| Model Name | Function Call | Required OS Variables |
+|------------------|--------------------------------------------|--------------------------------------|
+| j2-light | `completion('j2-light', messages)` | `os.environ['AI21_API_KEY']` |
+| j2-mid | `completion('j2-mid', messages)` | `os.environ['AI21_API_KEY']` |
+| j2-ultra | `completion('j2-ultra', messages)` | `os.environ['AI21_API_KEY']` |
 
 ### Cohere Models
 
 | Model Name | Function Call | Required OS Variables |
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 026afcf142..e3ce5fbec9 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -124,7 +124,14 @@ huggingface_models = [
     "meta-llama/Llama-2-70b-chat",
 ] # these have been tested on extensively. But by default all text2text-generation and text-generation models are supported by liteLLM. - https://docs.litellm.ai/docs/completion/supported
 
-model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + huggingface_models + vertex_chat_models + vertex_text_models
+ai21_models = [
+    "j2-ultra",
+    "j2-mid",
+    "j2-light"
+]
+
+model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + huggingface_models + vertex_chat_models + vertex_text_models + ai21_models
+
 
 ####### EMBEDDING MODELS ###################
 open_ai_embedding_models = [
diff --git a/litellm/main.py b/litellm/main.py
index d17bebd0d1..a695309efc 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -394,6 +394,29 @@ def completion(
 
         ## LOGGING
         logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+        ## RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = completion_response
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        response = model_response
+    elif model in litellm.ai21_models:
+        install_and_import("ai21")
+        import ai21
+        ai21.api_key = get_secret("AI21_API_KEY")
+
+        prompt = " ".join([message["content"] for message in messages])
+        ## LOGGING
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+
+        ai21_response = ai21.Completion.execute(
+            model=model,
+            prompt=prompt,
+        )
+        completion_response = ai21_response['completions'][0]['data']['text']
+
+        ## LOGGING
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+
         ## RESPONSE OBJECT
         model_response["choices"][0]["message"]["content"] = completion_response
         model_response["created"] = time.time()
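The `litellm/main.py` hunk above routes any model found in `litellm.ai21_models` to the AI21 SDK: it joins the chat messages into a single prompt, calls `ai21.Completion.execute`, and copies the returned text into the shared response object. Below is a minimal standalone sketch of that same call path, assuming only that the `ai21` package is installed and `AI21_API_KEY` is exported; the example messages are illustrative.

import os
import ai21

# Same auth step the new branch performs via get_secret("AI21_API_KEY")
ai21.api_key = os.environ["AI21_API_KEY"]

# The branch flattens all message contents into one prompt string
messages = [{"role": "user", "content": "Hey, how's it going?"}]
prompt = " ".join(message["content"] for message in messages)

ai21_response = ai21.Completion.execute(model="j2-light", prompt=prompt)

# The completion text is read from completions[0].data.text, as in the diff
print(ai21_response["completions"][0]["data"]["text"])
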
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index f639e327db..8cc1297f05 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -213,6 +213,16 @@ def test_completion_together_ai_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+#### Test AI21 ###################
+def test_completion_ai21():
+    model_name = "j2-light"
+    try:
+        response = completion(model=model_name, messages=messages)
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 # test config file with completion #
 # def test_completion_openai_config():
 #     try:
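For completeness, a usage sketch mirroring the new test: calling `completion` with any of the three j2 model names exercises the AI21 branch added in `litellm/main.py`. The key value below is a placeholder, not a real credential.

import os
from litellm import completion

os.environ["AI21_API_KEY"] = "your-ai21-api-key"  # placeholder key

messages = [{"role": "user", "content": "Hey, how's it going?"}]

# "j2-light", "j2-mid", and "j2-ultra" all resolve through litellm.ai21_models
response = completion(model="j2-ultra", messages=messages)
print(response["choices"][0]["message"]["content"])
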