diff --git a/litellm/caching.py b/litellm/caching.py
index fa10095da2..c23c1641b0 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -1883,16 +1883,15 @@ class Cache:
                         caching_group or model_group or kwargs[param]
                     )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
                 elif param == "file":
-                    metadata_file_name = kwargs.get("metadata", {}).get(
-                        "file_name", None
+                    file = kwargs.get("file")
+                    metadata = kwargs.get("metadata", {})
+                    litellm_params = kwargs.get("litellm_params", {})
+
+                    param_value = (
+                        getattr(file, "name", None)
+                        or metadata.get("file_name")
+                        or litellm_params.get("file_name")
                     )
-                    litellm_params_file_name = kwargs.get("litellm_params", {}).get(
-                        "file_name", None
-                    )
-                    if metadata_file_name is not None:
-                        param_value = metadata_file_name
-                    elif litellm_params_file_name is not None:
-                        param_value = litellm_params_file_name
                 else:
                     if kwargs[param] is None:
                         continue  # ignore None params
diff --git a/litellm/tests/eagle.wav b/litellm/tests/eagle.wav
new file mode 100644
index 0000000000..1c23657859
Binary files /dev/null and b/litellm/tests/eagle.wav differ
diff --git a/litellm/tests/test_whisper.py b/litellm/tests/test_whisper.py
index 9d26d2d4e6..89ce0ba006 100644
--- a/litellm/tests/test_whisper.py
+++ b/litellm/tests/test_whisper.py
@@ -26,6 +26,10 @@
 file_path = os.path.join(pwd, "gettysburg.wav")
 audio_file = open(file_path, "rb")
 
+
+file2_path = os.path.join(pwd, "eagle.wav")
+audio_file2 = open(file2_path, "rb")
+
 load_dotenv()
 
 sys.path.insert(
@@ -148,3 +152,46 @@ async def test_transcription_on_router():
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio()
+async def test_transcription_caching():
+    import litellm
+    from litellm.caching import Cache
+
+    litellm.set_verbose = True
+    litellm.cache = Cache()
+
+    # make raw llm api call
+
+    response_1 = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file,
+    )
+
+    await asyncio.sleep(5)
+
+    # cache hit
+
+    response_2 = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file,
+    )
+
+    print("response_1", response_1)
+    print("response_2", response_2)
+    print("response2 hidden params", response_2._hidden_params)
+    assert response_2._hidden_params["cache_hit"] is True
+
+    # cache miss
+
+    response_3 = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file2,
+    )
+    print("response_3", response_3)
+    print("response3 hidden params", response_3._hidden_params)
+    assert response_3._hidden_params.get("cache_hit") is not True
+    assert response_3.text != response_2.text
+
+    litellm.cache = None
diff --git a/litellm/utils.py b/litellm/utils.py
index 5768f8541c..916d31cfb4 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -553,7 +553,8 @@ def function_setup(
         or call_type == CallTypes.transcription.value
     ):
         _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
-        messages = "audio_file"
+        file_name = getattr(_file_name, "name", "audio_file")
+        messages = file_name
     elif (
         call_type == CallTypes.aspeech.value or call_type == CallTypes.speech.value
     ):
@@ -1213,6 +1214,7 @@ def client(original_function):
                         hidden_params = {
                             "model": "whisper-1",
                             "custom_llm_provider": custom_llm_provider,
+                            "cache_hit": True,
                         }
                         cached_result = convert_to_model_response_object(
                             response_object=cached_result,
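
For context on the `caching.py` change above, here is a minimal, self-contained sketch of the fallback chain that `Cache.get_cache_key` now applies to the `file` param: the key component is derived from the file handle's own `.name` attribute when present, falling back to `metadata["file_name"]` and then `litellm_params["file_name"]`. The `resolve_file_param` helper is hypothetical, written only to illustrate the ordering:

```python
import io
import tempfile


def resolve_file_param(kwargs: dict):
    # Mirrors the new "file" branch in Cache.get_cache_key: prefer the
    # handle's .name, then metadata, then litellm_params.
    file = kwargs.get("file")
    metadata = kwargs.get("metadata", {})
    litellm_params = kwargs.get("litellm_params", {})
    return (
        getattr(file, "name", None)
        or metadata.get("file_name")
        or litellm_params.get("file_name")
    )


# A real file handle carries its path in .name, so repeated calls with
# the same file resolve to the same cache-key component:
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
    assert resolve_file_param({"file": f}) == f.name

# An in-memory stream has no .name, so the metadata fallback applies:
buf = io.BytesIO(b"fake-wav-bytes")
assert (
    resolve_file_param({"file": buf, "metadata": {"file_name": "eagle.wav"}})
    == "eagle.wav"
)
```

This ordering is also why `test_transcription_caching` uses two distinct `.wav` files: `audio_file` and `audio_file2` resolve to different names, so the third `atranscription` call is a deliberate cache miss.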