[Bug Fix] caching does not account for thinking or reasoning_effort config (#10140)

* _get_litellm_supported_chat_completion_kwargs

* test caching with thinking
Authored by Ishaan Jaff on 2025-04-21 22:39:40 -07:00; committed via GitHub.
parent 104e4cb1bc
commit 7cb95bcc96
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 75 additions and 3 deletions

View file

@@ -75,6 +75,10 @@ class ModelParamHelper:
combined_kwargs = combined_kwargs.difference(exclude_kwargs)
return combined_kwargs
@staticmethod
def get_litellm_provider_specific_params_for_chat_params() -> Set[str]:
return set(["thinking"])
@staticmethod
def _get_litellm_supported_chat_completion_kwargs() -> Set[str]:
    """Return every chat-completion kwarg litellm recognizes for cache-key purposes.

    This is the union of:
      - the OpenAI-spec non-streaming create params,
      - the OpenAI-spec streaming create params,
      - litellm provider-specific params (e.g. ``thinking``) — included so that
        calls differing only in those params get distinct cache keys.

    Returns:
        Set[str]: all supported chat-completion kwarg names.
    """
    # NOTE(review): the scraped diff interleaved removed/added lines here; this body
    # reconstructs the post-commit version of the method.
    non_streaming_params: Set[str] = set(
        getattr(CompletionCreateParamsNonStreaming, "__annotations__", {}).keys()
    )
    streaming_params: Set[str] = set(
        getattr(CompletionCreateParamsStreaming, "__annotations__", {}).keys()
    )
    litellm_provider_specific_params: Set[str] = (
        ModelParamHelper.get_litellm_provider_specific_params_for_chat_params()
    )
    all_chat_completion_kwargs: Set[str] = non_streaming_params.union(
        streaming_params
    ).union(litellm_provider_specific_params)
    return all_chat_completion_kwargs
@staticmethod

View file

@@ -2608,3 +2608,64 @@ def test_caching_with_reasoning_content():
print(f"response 2: {response_2.model_dump_json(indent=4)}")
assert response_2._hidden_params["cache_hit"] == True
assert response_2.choices[0].message.reasoning_content is not None
def test_caching_reasoning_args_miss():  # test in memory cache
    """Passing reasoning_effort on only one call must yield a cache MISS (distinct ids)."""
    try:
        # litellm._turn_on_debug()
        litellm.set_verbose = True
        litellm.cache = Cache()
        common = {
            "model": "claude-3-7-sonnet-latest",
            "messages": messages,
            "caching": True,
            "mock_response": "My response",
        }
        response1 = completion(reasoning_effort="low", **common)
        response2 = completion(**common)
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        # Different ids => the second call was NOT served from the cache.
        assert response1.id != response2.id
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")
def test_caching_reasoning_args_hit():  # test in memory cache
    """Two calls with the same reasoning_effort must be a cache HIT (identical ids)."""
    try:
        # litellm._turn_on_debug()
        litellm.set_verbose = True
        litellm.cache = Cache()
        common = {
            "model": "claude-3-7-sonnet-latest",
            "messages": messages,
            "caching": True,
            "reasoning_effort": "low",
            "mock_response": "My response",
        }
        response1 = completion(**common)
        response2 = completion(**common)
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        # Same id => the second call was served from the cache.
        assert response1.id == response2.id
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")
def test_caching_thinking_args_miss():  # test in memory cache
    """Passing thinking on only one call must yield a cache MISS (distinct ids)."""
    try:
        # litellm._turn_on_debug()
        litellm.set_verbose = True
        litellm.cache = Cache()
        common = {
            "model": "claude-3-7-sonnet-latest",
            "messages": messages,
            "caching": True,
            "mock_response": "My response",
        }
        response1 = completion(thinking={"type": "enabled", "budget_tokens": 1024}, **common)
        response2 = completion(**common)
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        # Different ids => the second call was NOT served from the cache.
        assert response1.id != response2.id
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")
def test_caching_thinking_args_hit():  # test in memory cache
    """Two calls with the same thinking config must be a cache HIT (identical ids)."""
    try:
        # litellm._turn_on_debug()
        litellm.set_verbose = True
        litellm.cache = Cache()
        common = {
            "model": "claude-3-7-sonnet-latest",
            "messages": messages,
            "caching": True,
            "thinking": {"type": "enabled", "budget_tokens": 1024},
            "mock_response": "My response",
        }
        response1 = completion(**common)
        response2 = completion(**common)
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        # Same id => the second call was served from the cache.
        assert response1.id == response2.id
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")