Merge pull request #5026 from BerriAI/litellm_fix_whisper_caching

[Fix] Whisper Caching - use the correct cache keys when checking whether a request is cached
Ishaan Jaff, 2024-08-02 17:26:28 -07:00, committed by GitHub
commit 7ec1f241fc
4 changed files with 58 additions and 10 deletions
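At a high level, the change makes LiteLLM derive the file portion of a transcription cache key from the uploaded file itself. A minimal sketch of the intended end-to-end behavior, assuming an in-memory Cache() and a local gettysburg.wav (both names come from the tests in this diff, not from any library requirement); the 5-second sleep mirrors the test below, giving the async cache write time to complete:

# Minimal sketch of the behavior this PR fixes; assumes litellm is installed,
# an in-memory cache, and a local "gettysburg.wav".
import asyncio

import litellm
from litellm.caching import Cache


async def main():
    litellm.cache = Cache()  # in-memory cache

    with open("gettysburg.wav", "rb") as f:
        r1 = await litellm.atranscription(model="whisper-1", file=f)  # provider call

    await asyncio.sleep(5)  # allow the async cache write to land

    with open("gettysburg.wav", "rb") as f:
        r2 = await litellm.atranscription(model="whisper-1", file=f)  # cache hit

    print(r2._hidden_params.get("cache_hit"))  # expected: True


asyncio.run(main())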

litellm/caching.py

@@ -1883,16 +1883,15 @@ class Cache:
                     caching_group or model_group or kwargs[param]
                 )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
             elif param == "file":
-                metadata_file_name = kwargs.get("metadata", {}).get(
-                    "file_name", None
-                )
-                litellm_params_file_name = kwargs.get("litellm_params", {}).get(
-                    "file_name", None
-                )
-                if metadata_file_name is not None:
-                    param_value = metadata_file_name
-                elif litellm_params_file_name is not None:
-                    param_value = litellm_params_file_name
+                file = kwargs.get("file")
+                metadata = kwargs.get("metadata", {})
+                litellm_params = kwargs.get("litellm_params", {})
+
+                param_value = (
+                    getattr(file, "name", None)
+                    or metadata.get("file_name")
+                    or litellm_params.get("file_name")
+                )
             else:
                 if kwargs[param] is None:
                     continue  # ignore None params
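The rewritten file branch resolves the cache-key component by precedence: the file object's .name attribute first, then metadata["file_name"], then litellm_params["file_name"]. A standalone sketch of that precedence (the kwargs shapes and FakeFile class are illustrative only, not the real call sites):

# Sketch of the new precedence: the file handle's .name wins, then
# metadata["file_name"], then litellm_params["file_name"].
def resolve_file_param(kwargs: dict):
    file = kwargs.get("file")
    metadata = kwargs.get("metadata", {})
    litellm_params = kwargs.get("litellm_params", {})
    return (
        getattr(file, "name", None)
        or metadata.get("file_name")
        or litellm_params.get("file_name")
    )


class FakeFile:
    name = "eagle.wav"


assert resolve_file_param({"file": FakeFile()}) == "eagle.wav"
assert resolve_file_param({"file": object(), "metadata": {"file_name": "a.wav"}}) == "a.wav"
assert resolve_file_param({"file": None, "litellm_params": {"file_name": "b.wav"}}) == "b.wav"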

BIN litellm/tests/eagle.wav (new file, binary not shown)

litellm/tests/test_whisper.py

@@ -26,6 +26,10 @@ file_path = os.path.join(pwd, "gettysburg.wav")
 audio_file = open(file_path, "rb")
 
+file2_path = os.path.join(pwd, "eagle.wav")
+audio_file2 = open(file2_path, "rb")
+
+
 load_dotenv()
 
 sys.path.insert(
@@ -148,3 +152,46 @@ async def test_transcription_on_router():
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio()
+async def test_transcription_caching():
+    import litellm
+    from litellm.caching import Cache
+
+    litellm.set_verbose = True
+    litellm.cache = Cache()
+
+    # make raw llm api call
+
+    response_1 = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file,
+    )
+
+    await asyncio.sleep(5)
+
+    # cache hit
+
+    response_2 = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file,
+    )
+
+    print("response_1", response_1)
+    print("response_2", response_2)
+    print("response2 hidden params", response_2._hidden_params)
+    assert response_2._hidden_params["cache_hit"] is True
+
+    # cache miss
+
+    response_3 = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file2,
+    )
+    print("response_3", response_3)
+    print("response3 hidden params", response_3._hidden_params)
+    assert response_3._hidden_params.get("cache_hit") is not True
+    assert response_3.text != response_2.text
+
+    litellm.cache = None
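The new eagle.wav fixture exists to prove that cache keys no longer collide across different files. Under the old code, a plain file handle with no file_name metadata contributed nothing to the key, so two different files could resolve to the same cached response. A simplified sketch of that collision, using fake file objects and a toy key builder rather than the real Cache.get_cache_key:

# Toy key builder contrasting old vs new derivation; FakeFile stands in for
# a handle returned by open(), which carries a .name attribute.
import hashlib


class FakeFile:
    def __init__(self, name):
        self.name = name


def old_key(kwargs):
    # old behavior: only metadata / litellm_params were consulted
    name = kwargs.get("metadata", {}).get("file_name") or kwargs.get(
        "litellm_params", {}
    ).get("file_name")
    return hashlib.sha256(f"whisper-1:{name}".encode()).hexdigest()


def new_key(kwargs):
    # new behavior: the file object's .name is consulted first
    name = (
        getattr(kwargs.get("file"), "name", None)
        or kwargs.get("metadata", {}).get("file_name")
        or kwargs.get("litellm_params", {}).get("file_name")
    )
    return hashlib.sha256(f"whisper-1:{name}".encode()).hexdigest()


f1, f2 = FakeFile("gettysburg.wav"), FakeFile("eagle.wav")
assert old_key({"file": f1}) == old_key({"file": f2})  # collision: both keys use None
assert new_key({"file": f1}) != new_key({"file": f2})  # keys now differ per file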

litellm/utils.py

@@ -553,7 +553,8 @@ def function_setup(
             or call_type == CallTypes.transcription.value
         ):
             _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
-            messages = "audio_file"
+            file_name = getattr(_file_name, "name", "audio_file")
+            messages = file_name
         elif (
             call_type == CallTypes.aspeech.value or call_type == CallTypes.speech.value
         ):
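The function_setup change logs the real file name instead of the hard-coded "audio_file" string. The getattr fallback matters because handles from open() carry a .name attribute while in-memory streams do not; a quick runnable sketch:

# getattr fallback: open() handles have .name, BytesIO does not.
import io

with open(__file__, "rb") as fh:
    print(getattr(fh, "name", "audio_file"))  # prints the actual path

buf = io.BytesIO(b"fake audio bytes")
print(getattr(buf, "name", "audio_file"))  # prints the fallback: audio_file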
@@ -1213,6 +1214,7 @@ def client(original_function):
                     hidden_params = {
                         "model": "whisper-1",
                         "custom_llm_provider": custom_llm_provider,
+                        "cache_hit": True,
                     }
                     cached_result = convert_to_model_response_object(
                         response_object=cached_result,
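With cache_hit: True now stamped into the cached response's hidden params, callers (and the new test) can tell cached transcriptions apart from fresh ones. A hedged usage sketch; the helper name is ours, and it assumes litellm.cache has already been configured as in the sketch at the top:

# Hypothetical helper showing how a caller can inspect the new flag.
import litellm


async def transcribe_with_cache_info(audio_file):
    response = await litellm.atranscription(model="whisper-1", file=audio_file)
    # cache_hit is only present / True when the response came from cache
    if response._hidden_params.get("cache_hit") is True:
        print("served from cache")
    else:
        print("fresh provider call")
    return response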