diff --git a/docs/my-website/docs/providers/xai.md b/docs/my-website/docs/providers/xai.md
new file mode 100644
index 000000000..131c02b3d
--- /dev/null
+++ b/docs/my-website/docs/providers/xai.md
@@ -0,0 +1,146 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# XAI
+
+https://docs.x.ai/docs
+
+:::tip
+
+**We support ALL XAI models. Just set `model=xai/` as a prefix when sending LiteLLM requests.**
+
+:::
+
+## API Key
+```python
+# env variable
+os.environ['XAI_API_KEY']
+```
+
+## Sample Usage
+```python
+from litellm import completion
+import os
+
+os.environ['XAI_API_KEY'] = ""
+response = completion(
+    model="xai/grok-beta",
+    messages=[
+        {
+            "role": "user",
+            "content": "What's the weather like in Boston today in Fahrenheit?",
+        }
+    ],
+    max_tokens=10,
+    response_format={ "type": "json_object" },
+    seed=123,
+    stop=["\n\n"],
+    temperature=0.2,
+    top_p=0.9,
+    tool_choice="auto",
+    tools=[],
+    user="user",
+)
+print(response)
+```
+
+## Sample Usage - Streaming
+```python
+from litellm import completion
+import os
+
+os.environ['XAI_API_KEY'] = ""
+response = completion(
+    model="xai/grok-beta",
+    messages=[
+        {
+            "role": "user",
+            "content": "What's the weather like in Boston today in Fahrenheit?",
+        }
+    ],
+    stream=True,
+    max_tokens=10,
+    response_format={ "type": "json_object" },
+    seed=123,
+    stop=["\n\n"],
+    temperature=0.2,
+    top_p=0.9,
+    tool_choice="auto",
+    tools=[],
+    user="user",
+)
+
+for chunk in response:
+    print(chunk)
+```
+
+
+## Usage with LiteLLM Proxy Server
+
+Here's how to call an XAI model with the LiteLLM Proxy Server
+
+1. Modify the config.yaml
+
+  ```yaml
+  model_list:
+    - model_name: my-model
+      litellm_params:
+        model: xai/<your-model-name>  # add xai/ prefix to route as XAI provider
+        api_key: api-key              # your XAI API key
+  ```
+
+
+2. Start the proxy
+
+  ```bash
+  $ litellm --config /path/to/config.yaml
+  ```
+
+3. Send Request to LiteLLM Proxy Server
+
+  <Tabs>
+
+  <TabItem value="openai" label="OpenAI">
+
+  ```python
+  import openai
+  client = openai.OpenAI(
+      api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
+      base_url="http://0.0.0.0:4000" # litellm-proxy-base url
+  )
+
+  response = client.chat.completions.create(
+      model="my-model",
+      messages = [
+          {
+              "role": "user",
+              "content": "what llm are you"
+          }
+      ],
+  )
+
+  print(response)
+  ```
+  </TabItem>
+  <TabItem value="curl" label="curl">
+
+  ```shell
+  curl --location 'http://0.0.0.0:4000/chat/completions' \
+  --header 'Authorization: Bearer sk-1234' \
+  --header 'Content-Type: application/json' \
+  --data '{
+      "model": "my-model",
+      "messages": [
+          {
+              "role": "user",
+              "content": "what llm are you"
+          }
+      ]
+  }'
+  ```
+
+  </TabItem>
+
+  </Tabs>
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index f96103142..2df917f56 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -155,6 +155,7 @@ const sidebars = {
         "providers/watsonx",
         "providers/predibase",
         "providers/nvidia_nim",
+        "providers/xai",
         "providers/cerebras",
         "providers/volcano",
         "providers/triton-inference-server",
diff --git a/litellm/__init__.py b/litellm/__init__.py
index b7a0bad8e..2e9aabe63 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -490,6 +490,7 @@ openai_compatible_endpoints: List = [
     "app.empower.dev/api/v1",
     "inference.friendli.ai/v1",
     "api.sambanova.ai/v1",
+    "api.x.ai/v1",
 ]
 
 # this is maintained for Exception Mapping
@@ -507,6 +508,7 @@ openai_compatible_providers: List = [
     "deepinfra",
     "perplexity",
     "xinference",
+    "xai",
     "together_ai",
     "fireworks_ai",
     "empower",
@@ -717,6 +719,7 @@ class LlmProviders(str, Enum):
     OPENAI = "openai"
     OPENAI_LIKE = "openai_like"  # embedding only
     JINA_AI = "jina_ai"
+    XAI = "xai"
     CUSTOM_OPENAI = "custom_openai"
     TEXT_COMPLETION_OPENAI = "text-completion-openai"
     COHERE = "cohere"
@@ -1021,6 +1024,7 @@ from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
     FireworksAIEmbeddingConfig,
 )
 from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
+from .llms.xai.chat.xai_transformation import XAIChatConfig
 from .llms.volcengine import VolcEngineConfig
 from .llms.text_completion_codestral import MistralTextCompletionConfig
 from .llms.AzureOpenAI.azure import (
diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py
index 4b64fb828..4ad6c1651 100644
--- a/litellm/litellm_core_utils/get_llm_provider_logic.py
+++ b/litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -480,6 +480,13 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
         ) = litellm.JinaAIEmbeddingConfig()._get_openai_compatible_provider_info(
             api_base, api_key
         )
+    elif custom_llm_provider == "xai":
+        (
+            api_base,
+            dynamic_api_key,
+        ) = litellm.XAIChatConfig()._get_openai_compatible_provider_info(
+            api_base, api_key
+        )
     elif custom_llm_provider == "voyage":
         # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
         api_base = (
diff --git a/litellm/llms/xai/chat/xai_transformation.py b/litellm/llms/xai/chat/xai_transformation.py
new file mode 100644
index 000000000..3bd41ed90
--- /dev/null
+++ b/litellm/llms/xai/chat/xai_transformation.py
@@ -0,0 +1,56 @@
+import types
+from typing import Literal, Optional, Tuple, Union
+
+from litellm.secret_managers.main import get_secret_str
+
+from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
+
+XAI_API_BASE = "https://api.x.ai/v1"
+
+
+class XAIChatConfig(OpenAIGPTConfig):
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = api_base or get_secret_str("XAI_API_BASE") or XAI_API_BASE  # type: ignore
+        dynamic_api_key = api_key or get_secret_str("XAI_API_KEY")
+        return api_base, dynamic_api_key
+
+    def get_supported_openai_params(self, model: str) -> list:
+        return [
+            "frequency_penalty",
+            "logit_bias",
+            "logprobs",
+            "max_tokens",
+            "messages",
+            "model",
+            "n",
+            "presence_penalty",
+            "response_format",
+            "seed",
+            "stop",
+            "stream",
+            "stream_options",
+            "temperature",
+            "tool_choice",
+            "tools",
+            "top_logprobs",
+            "top_p",
+            "user",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool = False,
+    ) -> dict:
+        supported_openai_params = self.get_supported_openai_params(model=model)
+        for param, value in non_default_params.items():
+            if param == "max_completion_tokens":
+                optional_params["max_tokens"] = value
+            elif param in supported_openai_params:
+                if value is not None:
+                    optional_params[param] = value
+        return optional_params
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 9578ed9ea..6bc873fc9 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -1502,6 +1502,17 @@
         "mode": "completion",
         "source": "https://docs.mistral.ai/capabilities/code_generation/"
     },
+    "xai/grok-beta": {
+        "max_tokens": 131072,
+        "max_input_tokens": 131072,
+        "max_output_tokens": 131072,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000015,
+        "litellm_provider": "xai",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_vision": true
+    },
     "deepseek-coder": {
         "max_tokens": 4096,
         "max_input_tokens": 128000,
diff --git a/litellm/utils.py b/litellm/utils.py
index cdf77f1a5..11613e24d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2680,6 +2680,7 @@ def get_optional_params(  # noqa: PLR0915
         and custom_llm_provider != "groq"
         and custom_llm_provider != "nvidia_nim"
         and custom_llm_provider != "cerebras"
+        and custom_llm_provider != "xai"
         and custom_llm_provider != "ai21_chat"
         and custom_llm_provider != "volcengine"
         and custom_llm_provider != "deepseek"
@@ -3456,6 +3457,16 @@
             optional_params=optional_params,
             model=model,
         )
+    elif custom_llm_provider == "xai":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.XAIChatConfig().map_openai_params(
+            model=model,
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+        )
     elif custom_llm_provider == "ai21_chat":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
@@ -4184,6 +4195,8 @@ def get_supported_openai_params(  # noqa: PLR0915
         return litellm.nvidiaNimEmbeddingConfig.get_supported_openai_params()
     elif custom_llm_provider == "cerebras":
         return litellm.CerebrasConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "xai":
+        return litellm.XAIChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "ai21_chat":
         return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "volcengine":
@@ -5344,6 +5357,11 @@ def validate_environment(  # noqa: PLR0915
             keys_in_environment = True
         else:
             missing_keys.append("CEREBRAS_API_KEY")
+    elif custom_llm_provider == "xai":
+        if "XAI_API_KEY" in os.environ:
+            keys_in_environment = True
+        else:
+            missing_keys.append("XAI_API_KEY")
     elif custom_llm_provider == "ai21_chat":
         if "AI21_API_KEY" in os.environ:
             keys_in_environment = True
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 9578ed9ea..6bc873fc9 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -1502,6 +1502,17 @@
         "mode": "completion",
         "source": "https://docs.mistral.ai/capabilities/code_generation/"
     },
+    "xai/grok-beta": {
+        "max_tokens": 131072,
+        "max_input_tokens": 131072,
+        "max_output_tokens": 131072,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000015,
+        "litellm_provider": "xai",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_vision": true
+    },
     "deepseek-coder": {
         "max_tokens": 4096,
         "max_input_tokens": 128000,
diff --git a/tests/llm_translation/test_xai.py b/tests/llm_translation/test_xai.py
new file mode 100644
index 000000000..3701d39ce
--- /dev/null
+++ b/tests/llm_translation/test_xai.py
@@ -0,0 +1,146 @@
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+
+import httpx
+import pytest
+from respx import MockRouter
+
+import litellm
+from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage
+from litellm import completion
+from unittest.mock import patch
+from litellm.llms.xai.chat.xai_transformation import XAIChatConfig, XAI_API_BASE
+
+
+def test_xai_chat_config_get_openai_compatible_provider_info():
+    config = XAIChatConfig()
+
+    # Test with default values
+    api_base, api_key = config._get_openai_compatible_provider_info(
+        api_base=None, api_key=None
+    )
+    assert api_base == XAI_API_BASE
+    assert api_key == os.environ.get("XAI_API_KEY")
+
+    # Test with custom API key
+    custom_api_key = "test_api_key"
+    api_base, api_key = config._get_openai_compatible_provider_info(
+        api_base=None, api_key=custom_api_key
+    )
+    assert api_base == XAI_API_BASE
+    assert api_key == custom_api_key
+
+    # Test with custom environment variables for api_base and api_key
+    with patch.dict(
+        "os.environ",
+        {"XAI_API_BASE": "https://env.x.ai/v1", "XAI_API_KEY": "env_api_key"},
+    ):
+        api_base, api_key = config._get_openai_compatible_provider_info(None, None)
+        assert api_base == "https://env.x.ai/v1"
+        assert api_key == "env_api_key"
+
+
+def test_xai_chat_config_map_openai_params():
+    """
+    XAI is OpenAI compatible*
+
+    Does not support all OpenAI parameters:
+    - max_completion_tokens -> max_tokens
+
+    """
+    config = XAIChatConfig()
+
+    # Test mapping of parameters
+    non_default_params = {
+        "max_completion_tokens": 100,
+        "frequency_penalty": 0.5,
+        "logit_bias": {"50256": -100},
+        "logprobs": 5,
+        "messages": [{"role": "user", "content": "Hello"}],
+        "model": "xai/grok-beta",
+        "n": 2,
+        "presence_penalty": 0.2,
+        "response_format": {"type": "json_object"},
+        "seed": 42,
+        "stop": ["END"],
+        "stream": True,
+        "stream_options": {},
+        "temperature": 0.7,
+        "tool_choice": "auto",
+        "tools": [{"type": "function", "function": {"name": "get_weather"}}],
+        "top_logprobs": 3,
+        "top_p": 0.9,
+        "user": "test_user",
+        "unsupported_param": "value",
+    }
+    optional_params = {}
+    model = "xai/grok-beta"
+
+    result = config.map_openai_params(non_default_params, optional_params, model)
+
+    # Assert all supported parameters are present in the result
+    assert result["max_tokens"] == 100  # max_completion_tokens -> max_tokens
+    assert result["frequency_penalty"] == 0.5
+    assert result["logit_bias"] == {"50256": -100}
+    assert result["logprobs"] == 5
+    assert result["messages"] == [{"role": "user", "content": "Hello"}]
+    assert result["model"] == "xai/grok-beta"
+    assert result["n"] == 2
+    assert result["presence_penalty"] == 0.2
+    assert result["response_format"] == {"type": "json_object"}
+    assert result["seed"] == 42
+    assert result["stop"] == ["END"]
+    assert result["stream"] is True
+    assert result["stream_options"] == {}
+    assert result["temperature"] == 0.7
+    assert result["tool_choice"] == "auto"
+    assert result["tools"] == [
+        {"type": "function", "function": {"name": "get_weather"}}
+    ]
+    assert result["top_logprobs"] == 3
+    assert result["top_p"] == 0.9
+    assert result["user"] == "test_user"
+
+    # Assert unsupported parameter is not in the result
+    assert "unsupported_param" not in result
+
+
+@pytest.mark.parametrize("stream", [False, True])
+def test_completion_xai(stream):
+    try:
+        litellm.set_verbose = True
+        messages = [
+            {"role": "system", "content": "You're a good bot"},
+            {
+                "role": "user",
+                "content": "Hey",
+            },
+        ]
+        response = completion(
+            model="xai/grok-beta",
+            messages=messages,
+            stream=stream,
+        )
+        print(response)
+
+        if stream is True:
+            for chunk in response:
+                print(chunk)
+                assert chunk is not None
+                assert isinstance(chunk, litellm.ModelResponse)
+                assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices)
+
+        else:
+            assert response is not None
+            assert isinstance(response, litellm.ModelResponse)
+            assert response.choices[0].message.content is not None
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
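
For quick local verification, here is a minimal sketch (not part of the diff) of the parameter mapping the new `XAIChatConfig` introduces, mirroring what `tests/llm_translation/test_xai.py` asserts above: `max_completion_tokens` is rewritten to `max_tokens`, supported OpenAI params pass through unchanged, and unknown keys are dropped. It assumes this branch is installed so `litellm.llms.xai.chat.xai_transformation` is importable.

```python
from litellm.llms.xai.chat.xai_transformation import XAIChatConfig

config = XAIChatConfig()

# map_openai_params copies supported OpenAI params into optional_params,
# rewriting max_completion_tokens -> max_tokens and skipping unknown keys.
mapped = config.map_openai_params(
    non_default_params={
        "max_completion_tokens": 256,    # rewritten to max_tokens
        "temperature": 0.2,              # supported, passed through
        "unsupported_param": "ignored",  # not in get_supported_openai_params, dropped
    },
    optional_params={},
    model="xai/grok-beta",
)

print(mapped)  # {'max_tokens': 256, 'temperature': 0.2}
```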