(feat) add XAI ChatCompletion Support (#6373)

* init commit for XAI

* add full logic for xai chat completion

* test_completion_xai

* docs xAI

* add xai/grok-beta

* test_xai_chat_config_get_openai_compatible_provider_info

* test_xai_chat_config_map_openai_params

* add xai streaming test
Ishaan Jaff 2024-11-01 20:37:09 +05:30 committed by GitHub
parent 9545b0e5cd
commit 5652c375b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 400 additions and 0 deletions


@@ -0,0 +1,146 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# xAI
https://docs.x.ai/docs
:::tip
**We support ALL xAI models. Just set `model=xai/<any-model-on-xai>` when sending litellm requests; the `xai/` prefix routes the request to xAI.**
:::
## API Key
```python
# env variable
os.environ['XAI_API_KEY']
```
## Sample Usage
```python
from litellm import completion
import os

os.environ['XAI_API_KEY'] = ""

response = completion(
    model="xai/grok-beta",
    messages=[
        {
            "role": "user",
            "content": "What's the weather like in Boston today in Fahrenheit?",
        }
    ],
    max_tokens=10,
    response_format={"type": "json_object"},
    seed=123,
    stop=["\n\n"],
    temperature=0.2,
    top_p=0.9,
    tool_choice="auto",
    tools=[],
    user="user",
)
print(response)
```
## Sample Usage - Streaming
```python
from litellm import completion
import os

os.environ['XAI_API_KEY'] = ""

response = completion(
    model="xai/grok-beta",
    messages=[
        {
            "role": "user",
            "content": "What's the weather like in Boston today in Fahrenheit?",
        }
    ],
    stream=True,
    max_tokens=10,
    response_format={"type": "json_object"},
    seed=123,
    stop=["\n\n"],
    temperature=0.2,
    top_p=0.9,
    tool_choice="auto",
    tools=[],
    user="user",
)

for chunk in response:
    print(chunk)
```
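## Sample Usage - Function Calling

`grok-beta` is registered in this commit with `supports_function_calling: true`, and `tools` / `tool_choice` are in the supported parameter list, so tool calling follows the standard OpenAI format. A minimal sketch (not part of the original commit; the `get_current_weather` schema below is a hypothetical example):

```python
from litellm import completion
import os

os.environ['XAI_API_KEY'] = ""

# Hypothetical tool definition, for illustration only
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. Boston, MA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]

response = completion(
    model="xai/grok-beta",
    messages=[{"role": "user", "content": "What's the weather like in Boston today in Fahrenheit?"}],
    tools=tools,
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)
```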
## Usage with LiteLLM Proxy Server
Here's how to call an xAI model with the LiteLLM Proxy Server
1. Modify the config.yaml
```yaml
model_list:
  - model_name: my-model
    litellm_params:
      model: xai/<your-model-name>  # add xai/ prefix to route as xAI provider
      api_key: api-key              # api key to send to your model
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai

client = openai.OpenAI(
    api_key="sk-1234",              # pass litellm proxy key, if you're using virtual keys
    base_url="http://0.0.0.0:4000"  # litellm-proxy-base url
)

response = client.chat.completions.create(
    model="my-model",
    messages=[
        {
            "role": "user",
            "content": "what llm are you"
        }
    ],
)

print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "my-model",
    "messages": [
        {
            "role": "user",
            "content": "what llm are you"
        }
    ]
}'
```
</TabItem>
</Tabs>


@@ -155,6 +155,7 @@ const sidebars = {
    "providers/watsonx",
    "providers/predibase",
    "providers/nvidia_nim",
    "providers/xai",
    "providers/cerebras",
    "providers/volcano",
    "providers/triton-inference-server",


@@ -490,6 +490,7 @@ openai_compatible_endpoints: List = [
    "app.empower.dev/api/v1",
    "inference.friendli.ai/v1",
    "api.sambanova.ai/v1",
    "api.x.ai/v1",
]
# this is maintained for Exception Mapping
@@ -507,6 +508,7 @@ openai_compatible_providers: List = [
    "deepinfra",
    "perplexity",
    "xinference",
    "xai",
    "together_ai",
    "fireworks_ai",
    "empower",
@@ -717,6 +719,7 @@ class LlmProviders(str, Enum):
    OPENAI = "openai"
    OPENAI_LIKE = "openai_like"  # embedding only
    JINA_AI = "jina_ai"
    XAI = "xai"
    CUSTOM_OPENAI = "custom_openai"
    TEXT_COMPLETION_OPENAI = "text-completion-openai"
    COHERE = "cohere"
@@ -1021,6 +1024,7 @@ from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
    FireworksAIEmbeddingConfig,
)
from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
from .llms.xai.chat.xai_transformation import XAIChatConfig
from .llms.volcengine import VolcEngineConfig
from .llms.text_completion_codestral import MistralTextCompletionConfig
from .llms.AzureOpenAI.azure import (
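
Taken together, these `__init__.py` hunks register xAI as an OpenAI-compatible provider: the endpoint is whitelisted, `"xai"` joins the provider list and the `LlmProviders` enum, and `XAIChatConfig` is re-exported at the package root. A minimal sanity-check sketch (illustrative only, assuming the attributes added above are importable from the top-level `litellm` package):

```python
import litellm

# "xai" is now in the provider enum and the OpenAI-compatible provider list
assert litellm.LlmProviders.XAI.value == "xai"
assert "xai" in litellm.openai_compatible_providers
assert "api.x.ai/v1" in litellm.openai_compatible_endpoints

# The chat transformation config is re-exported at the package root
print(litellm.XAIChatConfig().get_supported_openai_params(model="xai/grok-beta"))
```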


@@ -480,6 +480,13 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
        ) = litellm.JinaAIEmbeddingConfig()._get_openai_compatible_provider_info(
            api_base, api_key
        )
    elif custom_llm_provider == "xai":
        (
            api_base,
            dynamic_api_key,
        ) = litellm.XAIChatConfig()._get_openai_compatible_provider_info(
            api_base, api_key
        )
    elif custom_llm_provider == "voyage":
        # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
        api_base = (
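
With this branch in place, provider resolution for `xai/` models picks up the xAI key and base URL. A rough usage sketch (illustrative; the 4-tuple return shape is `get_llm_provider`'s existing convention, and the printed values assume no `XAI_API_BASE` override):

```python
import litellm

model, provider, api_key, api_base = litellm.get_llm_provider(model="xai/grok-beta")
print(provider)  # "xai"
print(model)     # "grok-beta" (provider prefix stripped)
print(api_base)  # "https://api.x.ai/v1" unless XAI_API_BASE is set
```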


@@ -0,0 +1,56 @@
import types
from typing import Literal, Optional, Tuple, Union

from litellm.secret_managers.main import get_secret_str

from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig

XAI_API_BASE = "https://api.x.ai/v1"


class XAIChatConfig(OpenAIGPTConfig):
    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        api_base = api_base or get_secret_str("XAI_API_BASE") or XAI_API_BASE  # type: ignore
        dynamic_api_key = api_key or get_secret_str("XAI_API_KEY")
        return api_base, dynamic_api_key

    def get_supported_openai_params(self, model: str) -> list:
        return [
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "max_tokens",
            "messages",
            "model",
            "n",
            "presence_penalty",
            "response_format",
            "seed",
            "stop",
            "stream",
            "stream_options",
            "temperature",
            "tool_choice",
            "tools",
            "top_logprobs",
            "top_p",
            "user",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool = False,
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
            if param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            elif param in supported_openai_params:
                if value is not None:
                    optional_params[param] = value
        return optional_params
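
The only non-passthrough mapping in `map_openai_params` is `max_completion_tokens` -> `max_tokens`; every other parameter is forwarded only if it appears in the supported list. A quick illustrative sketch:

```python
from litellm.llms.xai.chat.xai_transformation import XAIChatConfig

mapped = XAIChatConfig().map_openai_params(
    non_default_params={
        "max_completion_tokens": 100,  # remapped to max_tokens
        "temperature": 0.2,            # passed through (supported)
        "unsupported_param": "x",      # silently dropped
    },
    optional_params={},
    model="grok-beta",
)
print(mapped)  # {'max_tokens': 100, 'temperature': 0.2}
```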


@@ -1502,6 +1502,17 @@
        "mode": "completion",
        "source": "https://docs.mistral.ai/capabilities/code_generation/"
    },
    "xai/grok-beta": {
        "max_tokens": 131072,
        "max_input_tokens": 131072,
        "max_output_tokens": 131072,
        "input_cost_per_token": 0.000005,
        "output_cost_per_token": 0.000015,
        "litellm_provider": "xai",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_vision": true
    },
    "deepseek-coder": {
        "max_tokens": 4096,
        "max_input_tokens": 128000,


@@ -2680,6 +2680,7 @@ def get_optional_params( # noqa: PLR0915
        and custom_llm_provider != "groq"
        and custom_llm_provider != "nvidia_nim"
        and custom_llm_provider != "cerebras"
        and custom_llm_provider != "xai"
        and custom_llm_provider != "ai21_chat"
        and custom_llm_provider != "volcengine"
        and custom_llm_provider != "deepseek"
@@ -3456,6 +3457,16 @@ def get_optional_params( # noqa: PLR0915
            optional_params=optional_params,
            model=model,
        )
    elif custom_llm_provider == "xai":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.XAIChatConfig().map_openai_params(
            model=model,
            non_default_params=non_default_params,
            optional_params=optional_params,
        )
    elif custom_llm_provider == "ai21_chat":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
@@ -4184,6 +4195,8 @@ def get_supported_openai_params( # noqa: PLR0915
        return litellm.nvidiaNimEmbeddingConfig.get_supported_openai_params()
    elif custom_llm_provider == "cerebras":
        return litellm.CerebrasConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "xai":
        return litellm.XAIChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "ai21_chat":
        return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "volcengine":
@@ -5344,6 +5357,11 @@ def validate_environment( # noqa: PLR0915
            keys_in_environment = True
        else:
            missing_keys.append("CEREBRAS_API_KEY")
    elif custom_llm_provider == "xai":
        if "XAI_API_KEY" in os.environ:
            keys_in_environment = True
        else:
            missing_keys.append("XAI_API_KEY")
    elif custom_llm_provider == "ai21_chat":
        if "AI21_API_KEY" in os.environ:
            keys_in_environment = True
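
These `utils.py` hooks are what the public helpers route through for the new provider. A short illustrative sketch (both helpers already exist in litellm; the return shapes follow from the diff above):

```python
import litellm

# Resolves to XAIChatConfig().get_supported_openai_params(...)
params = litellm.get_supported_openai_params(model="grok-beta", custom_llm_provider="xai")
print("tools" in params)  # True

# Reports whether XAI_API_KEY is present for xai/ models
print(litellm.validate_environment(model="xai/grok-beta"))
# e.g. {'keys_in_environment': False, 'missing_keys': ['XAI_API_KEY'], ...}
```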


@@ -1502,6 +1502,17 @@
        "mode": "completion",
        "source": "https://docs.mistral.ai/capabilities/code_generation/"
    },
    "xai/grok-beta": {
        "max_tokens": 131072,
        "max_input_tokens": 131072,
        "max_output_tokens": 131072,
        "input_cost_per_token": 0.000005,
        "output_cost_per_token": 0.000015,
        "litellm_provider": "xai",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_vision": true
    },
    "deepseek-coder": {
        "max_tokens": 4096,
        "max_input_tokens": 128000,


@@ -0,0 +1,146 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, patch

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import httpx
import pytest
from respx import MockRouter

import litellm
from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage
from litellm import completion
from litellm.llms.xai.chat.xai_transformation import XAIChatConfig, XAI_API_BASE
def test_xai_chat_config_get_openai_compatible_provider_info():
    config = XAIChatConfig()

    # Test with default values
    api_base, api_key = config._get_openai_compatible_provider_info(
        api_base=None, api_key=None
    )
    assert api_base == XAI_API_BASE
    assert api_key == os.environ.get("XAI_API_KEY")

    # Test with custom API key
    custom_api_key = "test_api_key"
    api_base, api_key = config._get_openai_compatible_provider_info(
        api_base=None, api_key=custom_api_key
    )
    assert api_base == XAI_API_BASE
    assert api_key == custom_api_key

    # Test with custom environment variables for api_base and api_key
    with patch.dict(
        "os.environ",
        {"XAI_API_BASE": "https://env.x.ai/v1", "XAI_API_KEY": "env_api_key"},
    ):
        api_base, api_key = config._get_openai_compatible_provider_info(None, None)
        assert api_base == "https://env.x.ai/v1"
        assert api_key == "env_api_key"
def test_xai_chat_config_map_openai_params():
    """
    XAI is OpenAI compatible*

    Does not support all OpenAI parameters:
    - max_completion_tokens -> max_tokens
    """
    config = XAIChatConfig()

    # Test mapping of parameters
    non_default_params = {
        "max_completion_tokens": 100,
        "frequency_penalty": 0.5,
        "logit_bias": {"50256": -100},
        "logprobs": 5,
        "messages": [{"role": "user", "content": "Hello"}],
        "model": "xai/grok-beta",
        "n": 2,
        "presence_penalty": 0.2,
        "response_format": {"type": "json_object"},
        "seed": 42,
        "stop": ["END"],
        "stream": True,
        "stream_options": {},
        "temperature": 0.7,
        "tool_choice": "auto",
        "tools": [{"type": "function", "function": {"name": "get_weather"}}],
        "top_logprobs": 3,
        "top_p": 0.9,
        "user": "test_user",
        "unsupported_param": "value",
    }
    optional_params = {}
    model = "xai/grok-beta"

    result = config.map_openai_params(non_default_params, optional_params, model)

    # Assert all supported parameters are present in the result
    assert result["max_tokens"] == 100  # max_completion_tokens -> max_tokens
    assert result["frequency_penalty"] == 0.5
    assert result["logit_bias"] == {"50256": -100}
    assert result["logprobs"] == 5
    assert result["messages"] == [{"role": "user", "content": "Hello"}]
    assert result["model"] == "xai/grok-beta"
    assert result["n"] == 2
    assert result["presence_penalty"] == 0.2
    assert result["response_format"] == {"type": "json_object"}
    assert result["seed"] == 42
    assert result["stop"] == ["END"]
    assert result["stream"] is True
    assert result["stream_options"] == {}
    assert result["temperature"] == 0.7
    assert result["tool_choice"] == "auto"
    assert result["tools"] == [
        {"type": "function", "function": {"name": "get_weather"}}
    ]
    assert result["top_logprobs"] == 3
    assert result["top_p"] == 0.9
    assert result["user"] == "test_user"

    # Assert unsupported parameter is not in the result
    assert "unsupported_param" not in result
@pytest.mark.parametrize("stream", [False, True])
def test_completion_xai(stream):
    try:
        litellm.set_verbose = True
        messages = [
            {"role": "system", "content": "You're a good bot"},
            {
                "role": "user",
                "content": "Hey",
            },
        ]
        response = completion(
            model="xai/grok-beta",
            messages=messages,
            stream=stream,
        )
        print(response)

        if stream is True:
            for chunk in response:
                print(chunk)
                assert chunk is not None
                assert isinstance(chunk, litellm.ModelResponse)
                assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices)
        else:
            assert response is not None
            assert isinstance(response, litellm.ModelResponse)
            assert response.choices[0].message.content is not None
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")