Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00

[Feat] Add reasoning_effort support for xai/grok-3-mini-beta model family (#9932)

* add BaseReasoningEffortTests
* BaseReasoningLLMTests
* fix test rename
* docs update thinking / reasoning content docs

This commit is contained in:
parent f9ce754817
commit 57bc03b30b

4 changed files with 99 additions and 22 deletions
@@ -15,6 +15,7 @@ Supported Providers:
 - Bedrock (Anthropic + Deepseek) (`bedrock/`)
 - Vertex AI (Anthropic) (`vertexai/`)
 - OpenRouter (`openrouter/`)
+- XAI (`xai/`)
 
 LiteLLM will standardize the `reasoning_content` in the response and `thinking_blocks` in the assistant message.
 
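For orientation, a minimal sketch of what this looks like from the caller's side once XAI is supported (illustrative only; the model name is taken from the commit title, and an `XAI_API_KEY` environment variable is assumed):

```python
# Minimal sketch, assuming XAI_API_KEY is set in the environment.
import litellm

response = litellm.completion(
    model="xai/grok-3-mini-beta",
    messages=[{"role": "user", "content": "What is 17 * 24?"}],
    reasoning_effort="low",  # newly supported for this model family
)

message = response.choices[0].message
print(message.reasoning_content)  # standardized reasoning trace
print(response.usage.completion_tokens_details.reasoning_tokens)
```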

@@ -1,5 +1,7 @@
 from typing import List, Optional, Tuple
 
+import litellm
+from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.prompt_templates.common_utils import (
     strip_name_from_messages,
 )
@@ -12,6 +14,10 @@ XAI_API_BASE = "https://api.x.ai/v1"
 
 
 class XAIChatConfig(OpenAIGPTConfig):
+    @property
+    def custom_llm_provider(self) -> Optional[str]:
+        return "xai"
+
     def _get_openai_compatible_provider_info(
         self, api_base: Optional[str], api_key: Optional[str]
     ) -> Tuple[Optional[str], Optional[str]]:
@@ -20,7 +26,7 @@ class XAIChatConfig(OpenAIGPTConfig):
         return api_base, dynamic_api_key
 
     def get_supported_openai_params(self, model: str) -> list:
-        return [
+        base_openai_params = [
             "frequency_penalty",
             "logit_bias",
             "logprobs",
@@ -39,6 +45,15 @@ class XAIChatConfig(OpenAIGPTConfig):
             "top_p",
             "user",
         ]
+        try:
+            if litellm.supports_reasoning(
+                model=model, custom_llm_provider=self.custom_llm_provider
+            ):
+                base_openai_params.append("reasoning_effort")
+        except Exception as e:
+            verbose_logger.debug(f"Error checking if model supports reasoning: {e}")
+
+        return base_openai_params
 
     def map_openai_params(
         self,
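With this change, `reasoning_effort` is only advertised for models that LiteLLM's registry marks as reasoning-capable. A rough way to sanity-check the behaviour (illustrative sketch; it assumes grok-3-mini-beta is flagged as reasoning-capable in the model registry, which is what this commit relies on):

```python
# Illustrative sketch: exercising the provider config directly.
import litellm
from litellm.llms.xai.chat.transformation import XAIChatConfig

# True only if the model's registry entry marks it as supporting reasoning.
print(litellm.supports_reasoning(model="xai/grok-3-mini-beta", custom_llm_provider="xai"))

# With this change, the supported-params list includes "reasoning_effort"
# for reasoning-capable xai models (and silently omits it otherwise).
params = XAIChatConfig().get_supported_openai_params(model="xai/grok-3-mini-beta")
print("reasoning_effort" in params)
```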
@@ -23,7 +23,7 @@ from litellm.utils import (
 )
 from litellm.main import stream_chunk_builder
 from typing import Union
-
+from litellm.types.utils import Usage, ModelResponse
 # test_example.py
 from abc import ABC, abstractmethod
 from openai import OpenAI
@@ -1398,4 +1398,77 @@ class BaseAnthropicChatTest(ABC):
         )
         assert optional_params["thinking"] == {"type": "enabled", "budget_tokens": 4096}
 
         assert "reasoning_effort" not in optional_params
+
+
+class BaseReasoningLLMTests(ABC):
+    """
+    Base class for testing reasoning LLMs.
+
+    - test that the responses contain reasoning_content
+    - test that the usage contains reasoning_tokens
+    """
+
+    @abstractmethod
+    def get_base_completion_call_args(self) -> dict:
+        """Must return the base completion call args"""
+        pass
+
+    @property
+    def completion_function(self):
+        return litellm.completion
+
+    def test_non_streaming_reasoning_effort(self):
+        """
+        Base test for non-streaming reasoning effort.
+
+        - Assert that `reasoning_content` is not None in the response message
+        - Assert that `reasoning_tokens` is greater than 0 in the usage
+        """
+        litellm._turn_on_debug()
+        base_completion_call_args = self.get_base_completion_call_args()
+        response: ModelResponse = self.completion_function(**base_completion_call_args, reasoning_effort="low")
+
+        # user gets `reasoning_content` in the response message
+        assert response.choices[0].message.reasoning_content is not None
+        assert isinstance(response.choices[0].message.reasoning_content, str)
+
+        # user gets `reasoning_tokens` in the usage
+        assert response.usage.completion_tokens_details.reasoning_tokens > 0
+
+    def test_streaming_reasoning_effort(self):
+        """
+        Base test for streaming reasoning effort.
+
+        - Assert that `reasoning_content` is not None in the streamed response
+        - Assert that `reasoning_tokens` is greater than 0 in the usage
+        """
+        # litellm._turn_on_debug()
+        base_completion_call_args = self.get_base_completion_call_args()
+        response: CustomStreamWrapper = self.completion_function(
+            **base_completion_call_args,
+            reasoning_effort="low",
+            stream=True,
+            stream_options={
+                "include_usage": True
+            }
+        )
+
+        reasoning_content: str = ""
+        usage: Usage = None
+        for chunk in response:
+            print(chunk)
+            if hasattr(chunk.choices[0].delta, "reasoning_content"):
+                reasoning_content += chunk.choices[0].delta.reasoning_content
+            if hasattr(chunk, "usage"):
+                usage = chunk.usage
+
+        assert reasoning_content is not None
+        assert len(reasoning_content) > 0
+
+        print(f"usage: {usage}")
+        assert usage.completion_tokens_details.reasoning_tokens > 0
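The streaming test above mirrors what end-user code would do; a hedged sketch of the same pattern outside the test harness (model name assumed from this commit):

```python
# Illustrative sketch: stream a reasoning response and collect the trace.
import litellm

stream = litellm.completion(
    model="xai/grok-3-mini-beta",
    messages=[{"role": "user", "content": "Briefly: why is the sky blue?"}],
    reasoning_effort="low",
    stream=True,
    stream_options={"include_usage": True},  # final chunk then carries usage
)

reasoning, answer = "", ""
for chunk in stream:
    delta = chunk.choices[0].delta if chunk.choices else None
    if delta is not None and getattr(delta, "reasoning_content", None):
        reasoning += delta.reasoning_content
    if delta is not None and getattr(delta, "content", None):
        answer += delta.content

print(reasoning)
print(answer)
```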
@@ -18,6 +18,7 @@ from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage
 from litellm import completion
 from unittest.mock import patch
 from litellm.llms.xai.chat.transformation import XAIChatConfig, XAI_API_BASE
+from base_llm_unit_tests import BaseReasoningLLMTests
 
 
 def test_xai_chat_config_get_openai_compatible_provider_info():
|
@ -162,22 +163,9 @@ def test_xai_message_name_filtering():
|
||||||
assert response.choices[0].message.content is not None
|
assert response.choices[0].message.content is not None
|
||||||
|
|
||||||
|
|
||||||
def test_xai_reasoning_effort():
|
class TestXAIReasoningEffort(BaseReasoningLLMTests):
|
||||||
litellm._turn_on_debug()
|
def get_base_completion_call_args(self):
|
||||||
messages = [
|
return {
|
||||||
{
|
"model": "xai/grok-3-mini-beta",
|
||||||
"role": "system",
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
"content": "*I press the green button*",
|
}
|
||||||
"name": "example_user"
|
|
||||||
},
|
|
||||||
{"role": "user", "content": "Hello", "name": "John"},
|
|
||||||
{"role": "assistant", "content": "Hello", "name": "Jane"},
|
|
||||||
]
|
|
||||||
response = completion(
|
|
||||||
model="xai/grok-3",
|
|
||||||
messages=messages,
|
|
||||||
reasoning_effort="high",
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
for chunk in response:
|
|
||||||
print(chunk)
|
|
||||||
|
|
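The concrete class supplies only the base call args; both reasoning tests are inherited from `BaseReasoningLLMTests`. Other reasoning-capable providers could plug in the same way; a hypothetical sketch (the Anthropic model name is an illustrative assumption, not part of this commit):

```python
# Hypothetical example of reusing the shared reasoning tests for another provider.
from base_llm_unit_tests import BaseReasoningLLMTests


class TestAnthropicReasoningEffort(BaseReasoningLLMTests):
    def get_base_completion_call_args(self):
        return {
            # illustrative model name; any reasoning-capable model would do
            "model": "anthropic/claude-3-7-sonnet-20250219",
            "messages": [{"role": "user", "content": "Hello"}],
        }
```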