From 7cb95bcc965beece5eb8bec1fb9d7cef248d992e Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 21 Apr 2025 22:39:40 -0700
Subject: [PATCH] [Bug Fix] caching does not account for thinking or
 reasoning_effort config (#10140)

* _get_litellm_supported_chat_completion_kwargs

* test caching with thinking
---
 .../litellm_core_utils/model_param_helper.py | 17 ++++++++++++++---
 tests/local_testing/test_caching.py          | 61 +++++++++++++++++++
 2 files changed, 75 insertions(+), 3 deletions(-)

diff --git a/litellm/litellm_core_utils/model_param_helper.py b/litellm/litellm_core_utils/model_param_helper.py
index c96b4a3f5b..91f2f1341c 100644
--- a/litellm/litellm_core_utils/model_param_helper.py
+++ b/litellm/litellm_core_utils/model_param_helper.py
@@ -75,6 +75,10 @@ class ModelParamHelper:
         combined_kwargs = combined_kwargs.difference(exclude_kwargs)
         return combined_kwargs
 
+    @staticmethod
+    def get_litellm_provider_specific_params_for_chat_params() -> Set[str]:
+        return set(["thinking"])
+
     @staticmethod
     def _get_litellm_supported_chat_completion_kwargs() -> Set[str]:
         """
@@ -82,11 +86,18 @@
 
         This follows the OpenAI API Spec
         """
-        all_chat_completion_kwargs = set(
+        non_streaming_params: Set[str] = set(
             getattr(CompletionCreateParamsNonStreaming, "__annotations__", {}).keys()
-        ).union(
-            set(getattr(CompletionCreateParamsStreaming, "__annotations__", {}).keys())
         )
+        streaming_params: Set[str] = set(
+            getattr(CompletionCreateParamsStreaming, "__annotations__", {}).keys()
+        )
+        litellm_provider_specific_params: Set[str] = (
+            ModelParamHelper.get_litellm_provider_specific_params_for_chat_params()
+        )
+        all_chat_completion_kwargs: Set[str] = non_streaming_params.union(
+            streaming_params
+        ).union(litellm_provider_specific_params)
         return all_chat_completion_kwargs
 
     @staticmethod
diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py
index df0b625d7d..8c12f3fd9b 100644
--- a/tests/local_testing/test_caching.py
+++ b/tests/local_testing/test_caching.py
@@ -2608,3 +2608,64 @@ def test_caching_with_reasoning_content():
     print(f"response 2: {response_2.model_dump_json(indent=4)}")
     assert response_2._hidden_params["cache_hit"] == True
     assert response_2.choices[0].message.reasoning_content is not None
+
+
+def test_caching_reasoning_args_miss(): # test in memory cache
+    try:
+        #litellm._turn_on_debug()
+        litellm.set_verbose = True
+        litellm.cache = Cache(
+        )
+        response1 = completion(model="claude-3-7-sonnet-latest", messages=messages, caching=True, reasoning_effort="low", mock_response="My response")
+        response2 = completion(model="claude-3-7-sonnet-latest", messages=messages, caching=True, mock_response="My response")
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        assert response1.id != response2.id
+    except Exception as e:
+        print(f"error occurred: {traceback.format_exc()}")
+        pytest.fail(f"Error occurred: {e}")
+
+def test_caching_reasoning_args_hit(): # test in memory cache
+    try:
+        #litellm._turn_on_debug()
+        litellm.set_verbose = True
+        litellm.cache = Cache(
+        )
+        response1 = completion(model="claude-3-7-sonnet-latest", messages=messages, caching=True, reasoning_effort="low", mock_response="My response")
+        response2 = completion(model="claude-3-7-sonnet-latest", messages=messages, caching=True, reasoning_effort="low", mock_response="My response")
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        assert response1.id == response2.id
+    except Exception as e:
+        print(f"error occurred: {traceback.format_exc()}")
+        pytest.fail(f"Error occurred: {e}")
+
+def test_caching_thinking_args_miss(): # test in memory cache
+    try:
+        #litellm._turn_on_debug()
+        litellm.set_verbose = True
+        litellm.cache = Cache(
+        )
+        response1 = completion(model="claude-3-7-sonnet-latest", messages=messages, caching=True, thinking={"type": "enabled", "budget_tokens": 1024}, mock_response="My response")
+        response2 = completion(model="claude-3-7-sonnet-latest", messages=messages, caching=True, mock_response="My response")
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        assert response1.id != response2.id
+    except Exception as e:
+        print(f"error occurred: {traceback.format_exc()}")
+        pytest.fail(f"Error occurred: {e}")
+
+def test_caching_thinking_args_hit(): # test in memory cache
+    try:
+        #litellm._turn_on_debug()
+        litellm.set_verbose = True
+        litellm.cache = Cache(
+        )
+        response1 = completion(model="claude-3-7-sonnet-latest", messages=messages, caching=True, thinking={"type": "enabled", "budget_tokens": 1024}, mock_response="My response")
+        response2 = completion(model="claude-3-7-sonnet-latest", messages=messages, caching=True, thinking={"type": "enabled", "budget_tokens": 1024}, mock_response="My response")
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        assert response1.id == response2.id
+    except Exception as e:
+        print(f"error occurred: {traceback.format_exc()}")
+        pytest.fail(f"Error occurred: {e}")
\ No newline at end of file