diff --git a/litellm/utils.py b/litellm/utils.py index efd48e8ab6..0c990ddea0 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2869,6 +2869,7 @@ def get_optional_params( and custom_llm_provider != "groq" and custom_llm_provider != "nvidia_nim" and custom_llm_provider != "cerebras" + and custom_llm_provider != "ai21_chat" and custom_llm_provider != "volcengine" and custom_llm_provider != "deepseek" and custom_llm_provider != "codestral" @@ -3638,6 +3639,16 @@ def get_optional_params( optional_params=optional_params, model=model, ) + elif custom_llm_provider == "ai21_chat": + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.AI21ChatConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + ) elif custom_llm_provider == "fireworks_ai": supported_params = get_supported_openai_params( model=model, custom_llm_provider=custom_llm_provider @@ -4265,6 +4276,8 @@ def get_supported_openai_params( return litellm.NvidiaNimConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "cerebras": return litellm.CerebrasConfig().get_supported_openai_params(model=model) + elif custom_llm_provider == "ai21_chat": + return litellm.AI21ChatConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "volcengine": return litellm.VolcEngineConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "groq": @@ -4699,6 +4712,13 @@ def get_llm_provider( or "https://api.cerebras.ai/v1" ) # type: ignore dynamic_api_key = api_key or get_secret("CEREBRAS_API_KEY") + elif custom_llm_provider == "ai21_chat": + api_base = ( + api_base + or get_secret("AI21_API_BASE") + or "https://api.ai21.com/studio/v1" + ) # type: ignore + dynamic_api_key = api_key or get_secret("AI21_API_KEY") elif custom_llm_provider == "volcengine": # volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = ( @@ -4852,6 +4872,9 @@ def get_llm_provider( elif endpoint == "https://api.cerebras.ai/v1": custom_llm_provider = "cerebras" dynamic_api_key = get_secret("CEREBRAS_API_KEY") + elif endpoint == "https://api.ai21.com/studio/v1": + custom_llm_provider = "ai21_chat" + dynamic_api_key = get_secret("AI21_API_KEY") elif endpoint == "https://codestral.mistral.ai/v1": custom_llm_provider = "codestral" dynamic_api_key = get_secret("CODESTRAL_API_KEY") @@ -5782,6 +5805,11 @@ def validate_environment( keys_in_environment = True else: missing_keys.append("CEREBRAS_API_KEY") + elif custom_llm_provider == "ai21_chat": + if "AI21_API_KEY" in os.environ: + keys_in_environment = True + else: + missing_keys.append("AI21_API_KEY") elif custom_llm_provider == "volcengine": if "VOLCENGINE_API_KEY" in os.environ: keys_in_environment = True @@ -6193,7 +6221,10 @@ def convert_to_model_response_object( if "model" in response_object: if model_response_object.model is None: model_response_object.model = response_object["model"] - elif "/" in model_response_object.model: + elif ( + "/" in model_response_object.model + and response_object["model"] is not None + ): openai_compatible_provider = model_response_object.model.split("/")[ 0 ]