diff --git a/docs/my-website/docs/providers/ai21.md b/docs/my-website/docs/providers/ai21.md
index c0987b3126..294db3fcbc 100644
--- a/docs/my-website/docs/providers/ai21.md
+++ b/docs/my-website/docs/providers/ai21.md
@@ -1,8 +1,17 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
 # AI21
-LiteLLM supports j2-light, j2-mid and j2-ultra from [AI21](https://www.ai21.com/studio/pricing).
+LiteLLM supports j2-light, j2-mid and j2-ultra from [AI21](https://www.ai21.com/studio/pricing)
-They're available to use without a waitlist.
+
+:::tip
+
+**We support ALL AI21 models; just set `model=ai21/` as a prefix when sending litellm requests.**
+**See all litellm supported AI21 models [here](https://models.litellm.ai)**
+
+:::
 
 ### API KEYS
 ```python
@@ -10,6 +19,7 @@ import os
 os.environ["AI21_API_KEY"] = "your-api-key"
 ```
 
+## **LiteLLM Python SDK Usage**
 ### Sample Usage
 
 ```python
@@ -23,10 +33,89 @@ messages = [{"role": "user", "content": "Write me a poem about the blue sky"}]
 completion(model="j2-light", messages=messages)
 ```
 
-### AI21 Models
+
+
+## **LiteLLM Proxy Server Usage**
+
+Here's how to call an AI21 model with the LiteLLM Proxy Server.
+
+1. Modify the config.yaml
+
+  ```yaml
+  model_list:
+    - model_name: my-model
+      litellm_params:
+        model: ai21/ # add ai21/ prefix to route as ai21 provider
+        api_key: api-key # api key for your model
+  ```
+
+
+2. Start the proxy
+
+  ```bash
+  $ litellm --config /path/to/config.yaml
+  ```
+
+3. Send Request to LiteLLM Proxy Server
+
+  <Tabs>
+
+  <TabItem value="openai" label="OpenAI Python v1.0.0+">
+
+  ```python
+  import openai
+  client = openai.OpenAI(
+      api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
+      base_url="http://0.0.0.0:4000" # litellm-proxy-base url
+  )
+
+  response = client.chat.completions.create(
+      model="my-model",
+      messages = [
+          {
+              "role": "user",
+              "content": "what llm are you"
+          }
+      ],
+  )
+
+  print(response)
+  ```
+
+  </TabItem>
+
+  <TabItem value="curl" label="curl">
+
+  ```shell
+  curl --location 'http://0.0.0.0:4000/chat/completions' \
+      --header 'Authorization: Bearer sk-1234' \
+      --header 'Content-Type: application/json' \
+      --data '{
+      "model": "my-model",
+      "messages": [
+          {
+              "role": "user",
+              "content": "what llm are you"
+          }
+      ]
+  }'
+  ```
+
+  </TabItem>
+
+  </Tabs>
+
+:::tip
+
+**We support ALL AI21 models; just set `model=ai21/` as a prefix when sending litellm requests.**
+**See all litellm supported AI21 models [here](https://models.litellm.ai)**
+:::
+
+## AI21 Models
 
 | Model Name       | Function Call                              | Required OS Variables                |
 |------------------|--------------------------------------------|--------------------------------------|
+| jamba-1.5-mini   | `completion('jamba-1.5-mini', messages)`   | `os.environ['AI21_API_KEY']`         |
+| jamba-1.5-large  | `completion('jamba-1.5-large', messages)`  | `os.environ['AI21_API_KEY']`         |
 | j2-light         | `completion('j2-light', messages)`         | `os.environ['AI21_API_KEY']`         |
 | j2-mid           | `completion('j2-mid', messages)`           | `os.environ['AI21_API_KEY']`         |
 | j2-ultra         | `completion('j2-ultra', messages)`         | `os.environ['AI21_API_KEY']`         |
\ No newline at end of file
diff --git a/litellm/tests/test_get_llm_provider.py b/litellm/tests/test_get_llm_provider.py
index 921420f801..ebf4debd5b 100644
--- a/litellm/tests/test_get_llm_provider.py
+++ b/litellm/tests/test_get_llm_provider.py
@@ -77,3 +77,19 @@ def test_get_llm_provider_ai21_chat():
     assert custom_llm_provider == "ai21_chat"
     assert model == "jamba-1.5-large"
     assert api_base == "https://api.ai21.com/studio/v1"
+
+
+def test_get_llm_provider_ai21_chat_test2():
+    """
+    If the user prefixes the model with ai21/ but calls jamba-1.5-large, the provider should be ai21_chat.
+    """
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        model="ai21/jamba-1.5-large",
+    )
+
+    print("model=", model)
+    print("custom_llm_provider=", custom_llm_provider)
+    print("api_base=", api_base)
+    assert custom_llm_provider == "ai21_chat"
+    assert model == "jamba-1.5-large"
+    assert api_base == "https://api.ai21.com/studio/v1"
diff --git a/litellm/utils.py b/litellm/utils.py
index 85744c2dfb..c32a2736a6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4666,6 +4666,7 @@ def get_llm_provider(
     ):
         custom_llm_provider = model.split("/", 1)[0]
         model = model.split("/", 1)[1]
+
         if custom_llm_provider == "perplexity":
            # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai
            api_base = api_base or get_secret("PERPLEXITY_API_BASE") or "https://api.perplexity.ai"  # type: ignore
@@ -4712,13 +4713,16 @@ def get_llm_provider(
                 or "https://api.cerebras.ai/v1"
             )  # type: ignore
             dynamic_api_key = api_key or get_secret("CEREBRAS_API_KEY")
-        elif custom_llm_provider == "ai21_chat":
+        elif (custom_llm_provider == "ai21_chat") or (
+            custom_llm_provider == "ai21" and model in litellm.ai21_chat_models
+        ):
             api_base = (
                 api_base
                 or get_secret("AI21_API_BASE")
                 or "https://api.ai21.com/studio/v1"
             )  # type: ignore
             dynamic_api_key = api_key or get_secret("AI21_API_KEY")
+            custom_llm_provider = "ai21_chat"
         elif custom_llm_provider == "volcengine":
             # volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
             api_base = (
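
Below is a minimal sketch of how the routing change above can be exercised end to end from the Python SDK. The `get_llm_provider` call and its expected return values mirror the new test in this patch; the `completion` call is an assumption that a valid `AI21_API_KEY` and network access to AI21 are available, and the model name `ai21/jamba-1.5-large` is taken from the docs table.

```python
import os

import litellm

# Placeholder key; assumed to be replaced with a real AI21 API key.
os.environ["AI21_API_KEY"] = "your-api-key"

# With this change, an "ai21/"-prefixed Jamba model is re-routed to the
# OpenAI-compatible ai21_chat provider instead of the legacy ai21 provider.
model, provider, _dynamic_api_key, api_base = litellm.get_llm_provider(
    model="ai21/jamba-1.5-large",
)
assert provider == "ai21_chat"
assert model == "jamba-1.5-large"
assert api_base == "https://api.ai21.com/studio/v1"

# A regular completion call then goes through the same routing path.
response = litellm.completion(
    model="ai21/jamba-1.5-large",
    messages=[{"role": "user", "content": "Write me a poem about the blue sky"}],
)
print(response)
```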