diff --git a/docs/my-website/docs/providers/ai21.md b/docs/my-website/docs/providers/ai21.md
index c0987b3126..f331fcff9a 100644
--- a/docs/my-website/docs/providers/ai21.md
+++ b/docs/my-website/docs/providers/ai21.md
@@ -1,8 +1,17 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
 # AI21
 
-LiteLLM supports j2-light, j2-mid and j2-ultra from [AI21](https://www.ai21.com/studio/pricing).
+LiteLLM supports j2-light, j2-mid and j2-ultra from [AI21](https://www.ai21.com/studio/pricing)
 
-They're available to use without a waitlist.
+:::tip
+
+**We support ALL AI21 models, just set `model=ai21/<any-ai21-model>` as a prefix when sending litellm requests.**
+**See all litellm supported AI21 models [here](https://models.litellm.ai)**
+
+:::
 
 ### API KEYS
 ```python
@@ -10,6 +19,7 @@ import os
 os.environ["AI21_API_KEY"] = "your-api-key"
 ```
 
+## **LiteLLM Python SDK Usage**
 ### Sample Usage
 
 ```python
@@ -23,10 +33,177 @@ messages = [{"role": "user", "content": "Write me a poem about the blue sky"}]
 completion(model="j2-light", messages=messages)
 ```
 
-### AI21 Models
+
+
+## **LiteLLM Proxy Server Usage**
+
+Here's how to call an AI21 model with the LiteLLM Proxy Server
+
+1. Modify the config.yaml
+
+   ```yaml
+   model_list:
+     - model_name: my-model
+       litellm_params:
+         model: ai21/<your-model-name>  # add ai21/ prefix to route as ai21 provider
+         api_key: api-key               # your AI21 api key
+   ```
+
+2. Start the proxy
+
+   ```bash
+   $ litellm --config /path/to/config.yaml
+   ```
+
+3. Send Request to LiteLLM Proxy Server
+
+   <Tabs>
+
+   <TabItem value="openai" label="OpenAI Python v1.0.0+">
+
+   ```python
+   import openai
+   client = openai.OpenAI(
+       api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
+       base_url="http://0.0.0.0:4000" # litellm-proxy-base url
+   )
+
+   response = client.chat.completions.create(
+       model="my-model",
+       messages = [
+           {
+               "role": "user",
+               "content": "what llm are you"
+           }
+       ],
+   )
+
+   print(response)
+   ```
+   </TabItem>
+
+   <TabItem value="curl" label="curl">
+
+   ```shell
+   curl --location 'http://0.0.0.0:4000/chat/completions' \
+   --header 'Authorization: Bearer sk-1234' \
+   --header 'Content-Type: application/json' \
+   --data '{
+       "model": "my-model",
+       "messages": [
+           {
+               "role": "user",
+               "content": "what llm are you"
+           }
+       ]
+   }'
+   ```
+   </TabItem>
+
+   </Tabs>
+
+## Supported OpenAI Parameters
+
+| [param](../completion/input) | type | AI21 equivalent |
+|-------|-------------|------------------|
+| `tools` | **Optional[list]** | `tools` |
+| `response_format` | **Optional[dict]** | `response_format` |
+| `max_tokens` | **Optional[int]** | `max_tokens` |
+| `temperature` | **Optional[float]** | `temperature` |
+| `top_p` | **Optional[float]** | `top_p` |
+| `stop` | **Optional[Union[str, list]]** | `stop` |
+| `n` | **Optional[int]** | `n` |
+| `stream` | **Optional[bool]** | `stream` |
+| `seed` | **Optional[int]** | `seed` |
+| `tool_choice` | **Optional[str]** | `tool_choice` |
+| `user` | **Optional[str]** | `user` |
+
+## Supported AI21 Parameters
+
+| param | type | [AI21 equivalent](https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters) |
+|-----------|------|-------------|
+| `documents` | **Optional[List[Dict]]** | `documents` |
+
+## Passing AI21 Specific Parameters - `documents`
+
+LiteLLM allows you to pass all AI21 specific parameters to the `litellm.completion` function. Here is an example of how to pass the `documents` parameter.
+
+<Tabs>
+
+<TabItem value="sdk" label="LiteLLM Python SDK">
+
+```python
+response = await litellm.acompletion(
+    model="jamba-1.5-large",
+    messages=[{"role": "user", "content": "what does the document say"}],
+    documents = [
+        {
+            "content": "hello world",
+            "metadata": {
+                "source": "google",
+                "author": "ishaan"
+            }
+        }
+    ]
+)
+```
+</TabItem>
+
+<TabItem value="proxy" label="LiteLLM Proxy Server">
+
+```python
+import openai
+client = openai.OpenAI(
+    api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
+    base_url="http://0.0.0.0:4000" # litellm-proxy-base url
+)
+
+response = client.chat.completions.create(
+    model="my-model",
+    messages = [
+        {
+            "role": "user",
+            "content": "what llm are you"
+        }
+    ],
+    extra_body = {
+        "documents": [
+            {
+                "content": "hello world",
+                "metadata": {
+                    "source": "google",
+                    "author": "ishaan"
+                }
+            }
+        ]
+    }
+)
+
+print(response)
+```
+</TabItem>
+
+</Tabs>
+
+:::tip
+
+**We support ALL AI21 models, just set `model=ai21/<any-ai21-model>` as a prefix when sending litellm requests.**
+**See all litellm supported AI21 models [here](https://models.litellm.ai)**
+:::
+
+## AI21 Models
 
 | Model Name | Function Call | Required OS Variables |
 |------------------|--------------------------------------------|--------------------------------------|
+| jamba-1.5-mini | `completion('jamba-1.5-mini', messages)` | `os.environ['AI21_API_KEY']` |
+| jamba-1.5-large | `completion('jamba-1.5-large', messages)` | `os.environ['AI21_API_KEY']` |
 | j2-light | `completion('j2-light', messages)` | `os.environ['AI21_API_KEY']` |
 | j2-mid | `completion('j2-mid', messages)` | `os.environ['AI21_API_KEY']` |
-| j2-ultra | `completion('j2-ultra', messages)` | `os.environ['AI21_API_KEY']` |
\ No newline at end of file
+| j2-ultra | `completion('j2-ultra', messages)` | `os.environ['AI21_API_KEY']` |
+
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 496bb9db72..e220c89920 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -364,6 +364,7 @@ vertex_llama3_models: List = []
 vertex_ai_ai21_models: List = []
 vertex_mistral_models: List = []
 ai21_models: List = []
+ai21_chat_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
 bedrock_models: List = []
@@ -416,7 +417,10 @@ for key, value in model_cost.items():
         key = key.replace("vertex_ai/", "")
         vertex_ai_ai21_models.append(key)
     elif value.get("litellm_provider") == "ai21":
-        ai21_models.append(key)
+        if value.get("mode") == "chat":
+            ai21_chat_models.append(key)
+        else:
+            ai21_models.append(key)
     elif value.get("litellm_provider") == "nlp_cloud":
         nlp_cloud_models.append(key)
     elif value.get("litellm_provider") == "aleph_alpha":
@@ -456,6 +460,7 @@ openai_compatible_providers: List = [
     "groq",
     "nvidia_nim",
     "cerebras",
+    "ai21_chat",
     "volcengine",
     "codestral",
     "deepseek",
@@ -644,6 +649,7 @@ model_list = (
     + vertex_chat_models
     + vertex_text_models
     + ai21_models
+    + ai21_chat_models
     + together_ai_models
     + baseten_models
     + aleph_alpha_models
@@ -695,6 +701,7 @@ provider_list: List = [
     "groq",
     "nvidia_nim",
     "cerebras",
+    "ai21_chat",
     "volcengine",
     "codestral",
     "text-completion-codestral",
@@ -853,7 +860,8 @@ from .llms.predibase import PredibaseConfig
 from .llms.replicate import ReplicateConfig
 from .llms.cohere.completion import CohereConfig
 from .llms.clarifai import ClarifaiConfig
-from .llms.ai21 import AI21Config
+from .llms.AI21.completion import AI21Config
+from .llms.AI21.chat import AI21ChatConfig
 from .llms.together_ai import TogetherAIConfig
 from .llms.cloudflare import CloudflareConfig
 from .llms.palm import PalmConfig
@@ -919,6 +927,7 @@ from .llms.openai import (
 )
 from .llms.nvidia_nim import NvidiaNimConfig
 from .llms.cerebras.chat import CerebrasConfig
+from .llms.AI21.chat import AI21ChatConfig
 from .llms.fireworks_ai import FireworksAIConfig
 from .llms.volcengine import VolcEngineConfig
 from .llms.text_completion_codestral import MistralTextCompletionConfig
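The classification change above is easiest to see with a small example. The sketch below is not part of the diff and uses hypothetical cost-map entries; it simply mirrors the new branch that splits AI21 models into `ai21_chat_models` vs. `ai21_models` based on the `mode` field:

```python
# Illustrative sketch only - hypothetical model_cost entries, mirroring the
# branch added to litellm/__init__.py above.
model_cost = {
    "jamba-1.5-large": {"litellm_provider": "ai21", "mode": "chat"},
    "j2-ultra": {"litellm_provider": "ai21", "mode": "completion"},
}

ai21_models, ai21_chat_models = [], []
for key, value in model_cost.items():
    if value.get("litellm_provider") == "ai21":
        if value.get("mode") == "chat":
            ai21_chat_models.append(key)  # Jamba chat models -> ai21_chat provider
        else:
            ai21_models.append(key)       # legacy Jurassic models -> ai21 provider

print(ai21_chat_models)  # ['jamba-1.5-large']
print(ai21_models)       # ['j2-ultra']
```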
diff --git a/litellm/llms/AI21/chat.py b/litellm/llms/AI21/chat.py
new file mode 100644
index 0000000000..4eabaaa875
--- /dev/null
+++ b/litellm/llms/AI21/chat.py
@@ -0,0 +1,95 @@
+"""
+AI21 Chat Completions API
+
+this is OpenAI compatible - no translation needed / occurs
+"""
+
+import types
+from typing import Optional, Union
+
+
+class AI21ChatConfig:
+    """
+    Reference: https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters
+
+    Below are the parameters:
+    """
+
+    tools: Optional[list] = None
+    response_format: Optional[dict] = None
+    documents: Optional[list] = None
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    stop: Optional[Union[str, list]] = None
+    n: Optional[int] = None
+    stream: Optional[bool] = None
+    seed: Optional[int] = None
+    tool_choice: Optional[str] = None
+    user: Optional[str] = None
+
+    def __init__(
+        self,
+        tools: Optional[list] = None,
+        response_format: Optional[dict] = None,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        stop: Optional[Union[str, list]] = None,
+        n: Optional[int] = None,
+        stream: Optional[bool] = None,
+        seed: Optional[int] = None,
+        tool_choice: Optional[str] = None,
+        user: Optional[str] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self, model: str) -> list:
+        """
+        Get the supported OpenAI params for the given model
+
+        """
+
+        return [
+            "tools",
+            "response_format",
+            "max_tokens",
+            "temperature",
+            "top_p",
+            "stop",
+            "n",
+            "stream",
+            "seed",
+            "tool_choice",
+            "user",
+        ]
+
+    def map_openai_params(
+        self, model: str, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        supported_openai_params = self.get_supported_openai_params(model=model)
+        for param, value in non_default_params.items():
+            if param in supported_openai_params:
+                optional_params[param] = value
+        return optional_params
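A quick sketch (not part of the diff) of how the new `AI21ChatConfig.map_openai_params` is expected to behave: only parameters in the supported list are copied into `optional_params`, everything else is silently dropped. The import path assumes the module location introduced above.

```python
# Illustrative sketch only - assumes the module path added in this PR.
from litellm.llms.AI21.chat import AI21ChatConfig

optional_params = AI21ChatConfig().map_openai_params(
    model="jamba-1.5-large",
    non_default_params={
        "max_tokens": 100,
        "temperature": 0.2,
        "frequency_penalty": 0.5,  # not in the supported list, so it is dropped
    },
    optional_params={},
)

print(optional_params)  # {'max_tokens': 100, 'temperature': 0.2}
```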
"cerebras" + or custom_llm_provider == "ai21_chat" or custom_llm_provider == "volcengine" or custom_llm_provider == "codestral" or custom_llm_provider == "text-completion-codestral" @@ -1293,6 +1294,7 @@ def completion( or custom_llm_provider == "groq" or custom_llm_provider == "nvidia_nim" or custom_llm_provider == "cerebras" + or custom_llm_provider == "ai21_chat" or custom_llm_provider == "volcengine" or custom_llm_provider == "codestral" or custom_llm_provider == "deepseek" @@ -3143,6 +3145,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse: or custom_llm_provider == "groq" or custom_llm_provider == "nvidia_nim" or custom_llm_provider == "cerebras" + or custom_llm_provider == "ai21_chat" or custom_llm_provider == "volcengine" or custom_llm_provider == "deepseek" or custom_llm_provider == "fireworks_ai" @@ -3807,6 +3810,7 @@ async def atext_completion( or custom_llm_provider == "groq" or custom_llm_provider == "nvidia_nim" or custom_llm_provider == "cerebras" + or custom_llm_provider == "ai21_chat" or custom_llm_provider == "volcengine" or custom_llm_provider == "text-completion-codestral" or custom_llm_provider == "deepseek" diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 5dd7681ced..5fce984520 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -4481,3 +4481,23 @@ async def test_dynamic_azure_params(stream, sync_mode): except Exception as e: traceback.print_stack() raise e + + +@pytest.mark.asyncio() +@pytest.mark.flaky(retries=3, delay=1) +async def test_completion_ai21_chat(): + litellm.set_verbose = True + response = await litellm.acompletion( + model="jamba-1.5-large", + user="ishaan", + tool_choice="auto", + seed=123, + messages=[{"role": "user", "content": "what does the document say"}], + documents=[ + { + "content": "hello world", + "metadata": {"source": "google", "author": "ishaan"}, + } + ], + ) + pass diff --git a/litellm/tests/test_get_llm_provider.py b/litellm/tests/test_get_llm_provider.py index 5e1c1f4fec..ebf4debd5b 100644 --- a/litellm/tests/test_get_llm_provider.py +++ b/litellm/tests/test_get_llm_provider.py @@ -68,3 +68,28 @@ def test_get_llm_provider_deepseek_custom_api_base(): assert api_base == "MY-FAKE-BASE" os.environ.pop("DEEPSEEK_API_BASE") + + +def test_get_llm_provider_ai21_chat(): + model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( + model="jamba-1.5-large", + ) + assert custom_llm_provider == "ai21_chat" + assert model == "jamba-1.5-large" + assert api_base == "https://api.ai21.com/studio/v1" + + +def test_get_llm_provider_ai21_chat_test2(): + """ + if user prefix with ai21/ but calls jamba-1.5-large then it should be ai21_chat provider + """ + model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( + model="ai21/jamba-1.5-large", + ) + + print("model=", model) + print("custom_llm_provider=", custom_llm_provider) + print("api_base=", api_base) + assert custom_llm_provider == "ai21_chat" + assert model == "jamba-1.5-large" + assert api_base == "https://api.ai21.com/studio/v1" diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 1b8b4e0852..1b155bea85 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -586,6 +586,37 @@ async def test_completion_predibase_streaming(sync_mode): pytest.fail(f"Error occurred: {e}") +@pytest.mark.asyncio() +@pytest.mark.flaky(retries=3, delay=1) +async def test_completion_ai21_stream(): + litellm.set_verbose = True + 
diff --git a/litellm/utils.py b/litellm/utils.py
index 8dd18f450b..6cd873dd05 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2887,6 +2887,7 @@ def get_optional_params(
         and custom_llm_provider != "groq"
         and custom_llm_provider != "nvidia_nim"
         and custom_llm_provider != "cerebras"
+        and custom_llm_provider != "ai21_chat"
         and custom_llm_provider != "volcengine"
         and custom_llm_provider != "deepseek"
         and custom_llm_provider != "codestral"
@@ -3656,6 +3657,16 @@ optional_params=optional_params,
             model=model,
         )
+    elif custom_llm_provider == "ai21_chat":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.AI21ChatConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
     elif custom_llm_provider == "fireworks_ai":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
@@ -4283,6 +4294,8 @@
         return litellm.NvidiaNimConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "cerebras":
         return litellm.CerebrasConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "ai21_chat":
+        return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "volcengine":
         return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "groq":
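With the branches above wired in, the supported-params helper should advertise the Jamba chat parameters for the new provider. A small sketch (not part of the diff, assuming `get_supported_openai_params` is exposed on the `litellm` package as elsewhere in this file):

```python
# Illustrative sketch only - exercises the new ai21_chat branches above.
import litellm

params = litellm.get_supported_openai_params(
    model="jamba-1.5-large", custom_llm_provider="ai21_chat"
)
# Per AI21ChatConfig, tool calling, seeding, and response_format are advertised.
assert "tool_choice" in params and "seed" in params and "response_format" in params
```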
elif endpoint == "https://api.cerebras.ai/v1": custom_llm_provider = "cerebras" dynamic_api_key = get_secret("CEREBRAS_API_KEY") + elif endpoint == "https://api.ai21.com/studio/v1": + custom_llm_provider = "ai21_chat" + dynamic_api_key = get_secret("AI21_API_KEY") elif endpoint == "https://codestral.mistral.ai/v1": custom_llm_provider = "codestral" dynamic_api_key = get_secret("CODESTRAL_API_KEY") @@ -4953,6 +4980,14 @@ def get_llm_provider( ## ai21 elif model in litellm.ai21_models: custom_llm_provider = "ai21" + elif model in litellm.ai21_chat_models: + custom_llm_provider = "ai21_chat" + api_base = ( + api_base + or get_secret("AI21_API_BASE") + or "https://api.ai21.com/studio/v1" + ) # type: ignore + dynamic_api_key = api_key or get_secret("AI21_API_KEY") ## aleph_alpha elif model in litellm.aleph_alpha_models: custom_llm_provider = "aleph_alpha" @@ -5800,6 +5835,11 @@ def validate_environment( keys_in_environment = True else: missing_keys.append("CEREBRAS_API_KEY") + elif custom_llm_provider == "ai21_chat": + if "AI21_API_KEY" in os.environ: + keys_in_environment = True + else: + missing_keys.append("AI21_API_KEY") elif custom_llm_provider == "volcengine": if "VOLCENGINE_API_KEY" in os.environ: keys_in_environment = True @@ -6211,7 +6251,10 @@ def convert_to_model_response_object( if "model" in response_object: if model_response_object.model is None: model_response_object.model = response_object["model"] - elif "/" in model_response_object.model: + elif ( + "/" in model_response_object.model + and response_object["model"] is not None + ): openai_compatible_provider = model_response_object.model.split("/")[ 0 ]