diff --git a/docs/my-website/docs/reasoning_content.md b/docs/my-website/docs/reasoning_content.md
index b384eb92ac..86ef58bd68 100644
--- a/docs/my-website/docs/reasoning_content.md
+++ b/docs/my-website/docs/reasoning_content.md
@@ -15,6 +15,7 @@ Supported Providers:
 - Bedrock (Anthropic + Deepseek) (`bedrock/`)
 - Vertex AI (Anthropic) (`vertexai/`)
 - OpenRouter (`openrouter/`)
+- XAI (`xai/`)
 
 LiteLLM will standardize the `reasoning_content` in the response and `thinking_blocks` in the assistant message.
 
diff --git a/litellm/llms/xai/chat/transformation.py b/litellm/llms/xai/chat/transformation.py
index 614509020e..804abe30f0 100644
--- a/litellm/llms/xai/chat/transformation.py
+++ b/litellm/llms/xai/chat/transformation.py
@@ -1,5 +1,7 @@
 from typing import List, Optional, Tuple
 
+import litellm
+from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.prompt_templates.common_utils import (
     strip_name_from_messages,
 )
@@ -12,6 +14,10 @@ XAI_API_BASE = "https://api.x.ai/v1"
 
 
 class XAIChatConfig(OpenAIGPTConfig):
+    @property
+    def custom_llm_provider(self) -> Optional[str]:
+        return "xai"
+
     def _get_openai_compatible_provider_info(
         self, api_base: Optional[str], api_key: Optional[str]
     ) -> Tuple[Optional[str], Optional[str]]:
@@ -20,7 +26,7 @@ class XAIChatConfig(OpenAIGPTConfig):
         return api_base, dynamic_api_key
 
     def get_supported_openai_params(self, model: str) -> list:
-        return [
+        base_openai_params = [
             "frequency_penalty",
             "logit_bias",
             "logprobs",
@@ -39,6 +45,15 @@ class XAIChatConfig(OpenAIGPTConfig):
             "top_p",
             "user",
         ]
+        try:
+            if litellm.supports_reasoning(
+                model=model, custom_llm_provider=self.custom_llm_provider
+            ):
+                base_openai_params.append("reasoning_effort")
+        except Exception as e:
+            verbose_logger.debug(f"Error checking if model supports reasoning: {e}")
+
+        return base_openai_params
 
     def map_openai_params(
         self,
diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py
index 9e21260410..bd3627f7d4 100644
--- a/tests/llm_translation/base_llm_unit_tests.py
+++ b/tests/llm_translation/base_llm_unit_tests.py
@@ -23,7 +23,7 @@ from litellm.utils import (
 )
 from litellm.main import stream_chunk_builder
 from typing import Union
-
+from litellm.types.utils import Usage, ModelResponse
 # test_example.py
 from abc import ABC, abstractmethod
 from openai import OpenAI
@@ -1398,4 +1398,77 @@ class BaseAnthropicChatTest(ABC):
         )
 
         assert optional_params["thinking"] == {"type": "enabled", "budget_tokens": 4096}
-        assert "reasoning_effort" not in optional_params
\ No newline at end of file
+        assert "reasoning_effort" not in optional_params
+
+
+class BaseReasoningLLMTests(ABC):
+    """
+    Base class for testing reasoning LLMs
+
+    - test that the responses contain reasoning_content
+    - test that the usage contains reasoning_tokens
+    """
+    @abstractmethod
+    def get_base_completion_call_args(self) -> dict:
+        """Must return the base completion call args"""
+        pass
+
+    @property
+    def completion_function(self):
+        return litellm.completion
+
+
+    def test_non_streaming_reasoning_effort(self):
+        """
+        Base test for non-streaming reasoning effort
+
+        - Assert that `reasoning_content` is not None in the response message
+        - Assert that `reasoning_tokens` is greater than 0 in the usage
+        """
+        litellm._turn_on_debug()
+        base_completion_call_args = self.get_base_completion_call_args()
+        response: ModelResponse = self.completion_function(**base_completion_call_args, reasoning_effort="low")
+
+        # user gets `reasoning_content` in the response message
+        assert response.choices[0].message.reasoning_content is not None
+        assert isinstance(response.choices[0].message.reasoning_content, str)
+
+        # user gets `reasoning_tokens` in the usage
+        assert response.usage.completion_tokens_details.reasoning_tokens > 0
+
+
+    def test_streaming_reasoning_effort(self):
+        """
+        Base test for streaming reasoning effort
+
+        - Assert that `reasoning_content` is not None in the streaming response
+        - Assert that `reasoning_tokens` is greater than 0 in the usage
+        """
+        # litellm._turn_on_debug()
+        base_completion_call_args = self.get_base_completion_call_args()
+        response: CustomStreamWrapper = self.completion_function(
+            **base_completion_call_args,
+            reasoning_effort="low",
+            stream=True,
+            stream_options={
+                "include_usage": True
+            }
+        )
+
+        reasoning_content: str = ""
+        usage: Usage = None
+        for chunk in response:
+            print(chunk)
+            if hasattr(chunk.choices[0].delta, "reasoning_content"):
+                reasoning_content += chunk.choices[0].delta.reasoning_content
+            if hasattr(chunk, "usage"):
+                usage = chunk.usage
+
+        assert reasoning_content is not None
+        assert len(reasoning_content) > 0
+
+        print(f"usage: {usage}")
+        assert usage.completion_tokens_details.reasoning_tokens > 0
+
+
+    
\ No newline at end of file
diff --git a/tests/llm_translation/test_xai.py b/tests/llm_translation/test_xai.py
index afe4a3c0b9..419ace5686 100644
--- a/tests/llm_translation/test_xai.py
+++ b/tests/llm_translation/test_xai.py
@@ -18,6 +18,7 @@ from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage
 from litellm import completion
 from unittest.mock import patch
 from litellm.llms.xai.chat.transformation import XAIChatConfig, XAI_API_BASE
+from base_llm_unit_tests import BaseReasoningLLMTests
 
 
 def test_xai_chat_config_get_openai_compatible_provider_info():
@@ -162,22 +163,9 @@ def test_xai_message_name_filtering():
     assert response.choices[0].message.content is not None
 
 
-def test_xai_reasoning_effort():
-    litellm._turn_on_debug()
-    messages = [
-        {
-            "role": "system",
-            "content": "*I press the green button*",
-            "name": "example_user"
-        },
-        {"role": "user", "content": "Hello", "name": "John"},
-        {"role": "assistant", "content": "Hello", "name": "Jane"},
-    ]
-    response = completion(
-        model="xai/grok-3",
-        messages=messages,
-        reasoning_effort="high",
-        stream=True,
-    )
-    for chunk in response:
-        print(chunk)
+class TestXAIReasoningEffort(BaseReasoningLLMTests):
+    def get_base_completion_call_args(self):
+        return {
+            "model": "xai/grok-3-mini-beta",
+            "messages": [{"role": "user", "content": "Hello"}],
+        }
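
For reference, here is a minimal sketch of how the behavior added in this diff would be exercised from user code. It mirrors the assertions in the new `BaseReasoningLLMTests` class; it assumes a valid `XAI_API_KEY` is set in the environment and that `xai/grok-3-mini-beta` is registered as reasoning-capable in LiteLLM's model map, so `supports_reasoning()` allows `reasoning_effort` through.

```python
import litellm

# Non-streaming call: `reasoning_effort` is forwarded because the model supports
# reasoning, and the standardized reasoning fields come back on the response.
response = litellm.completion(
    model="xai/grok-3-mini-beta",
    messages=[{"role": "user", "content": "Hello"}],
    reasoning_effort="low",
)

print(response.choices[0].message.reasoning_content)               # standardized reasoning text
print(response.usage.completion_tokens_details.reasoning_tokens)   # expected to be > 0
```

The streaming path works the same way: `test_streaming_reasoning_effort` accumulates `reasoning_content` from each chunk's delta and reads `reasoning_tokens` from the usage chunk emitted when `stream_options={"include_usage": True}` is passed.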