forked from phoenix/litellm-mirror
Merge pull request #5026 from BerriAI/litellm_fix_whisper_caching
[Fix] Whisper Caching - Use correct cache keys for checking request in cache
This commit is contained in:
commit 7ec1f241fc
4 changed files with 58 additions and 10 deletions
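Taken together, the four changed files make Whisper transcription caching work end to end: the Cache key builder now derives the file component of the cache key from the file object itself, with metadata and litellm_params fallbacks; function_setup records the real audio file name instead of a fixed placeholder; the client wrapper flags cached transcription responses with cache_hit; and a new litellm/tests/eagle.wav fixture plus a test_transcription_caching test cover both the cache-hit and cache-miss paths.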
@@ -1883,16 +1883,15 @@ class Cache:
                     caching_group or model_group or kwargs[param]
                 )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
             elif param == "file":
-                metadata_file_name = kwargs.get("metadata", {}).get(
-                    "file_name", None
-                )
-                litellm_params_file_name = kwargs.get("litellm_params", {}).get(
-                    "file_name", None
-                )
-                if metadata_file_name is not None:
-                    param_value = metadata_file_name
-                elif litellm_params_file_name is not None:
-                    param_value = litellm_params_file_name
+                file = kwargs.get("file")
+                metadata = kwargs.get("metadata", {})
+                litellm_params = kwargs.get("litellm_params", {})
+                param_value = (
+                    getattr(file, "name", None)
+                    or metadata.get("file_name")
+                    or litellm_params.get("file_name")
+                )
             else:
                 if kwargs[param] is None:
                     continue  # ignore None params
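The core of the fix is the fallback chain above: the file component of the cache key is resolved from the open file object's own name attribute first, then from caller-supplied metadata, then from litellm_params. Previously only the metadata and litellm_params lookups were consulted and the file object's name never was, so a request without explicit file_name metadata had no stable file component in its key. A minimal, self-contained sketch of the new resolution order (resolve_file_param is a hypothetical name for illustration, not part of litellm's API):

import io
from typing import Any, Optional


def resolve_file_param(kwargs: dict) -> Optional[Any]:
    # Mirrors the fallback chain introduced in the hunk above.
    file = kwargs.get("file")
    metadata = kwargs.get("metadata", {})
    litellm_params = kwargs.get("litellm_params", {})
    return (
        getattr(file, "name", None)         # 1) the file object's own name
        or metadata.get("file_name")        # 2) caller-supplied metadata
        or litellm_params.get("file_name")  # 3) router-injected litellm_params
    )


# Stand-in for an uploaded audio file; BytesIO accepts a name attribute.
buf = io.BytesIO(b"RIFF....WAVE")
buf.name = "eagle.wav"

assert resolve_file_param({"file": buf}) == "eagle.wav"
# With no usable name on the object, the metadata fallback still applies:
assert resolve_file_param({"file": None, "metadata": {"file_name": "eagle.wav"}}) == "eagle.wav"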
BIN litellm/tests/eagle.wav (new file)
Binary file not shown.
@@ -26,6 +26,10 @@ file_path = os.path.join(pwd, "gettysburg.wav")

 audio_file = open(file_path, "rb")

+file2_path = os.path.join(pwd, "eagle.wav")
+audio_file2 = open(file2_path, "rb")
+
 load_dotenv()

 sys.path.insert(
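The second fixture exists so the new caching test can exercise a genuine cache miss: transcribing eagle.wav must not be served from the cache entry created for gettysburg.wav.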
@@ -148,3 +152,46 @@ async def test_transcription_on_router():
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio()
+async def test_transcription_caching():
+    import litellm
+    from litellm.caching import Cache
+
+    litellm.set_verbose = True
+    litellm.cache = Cache()
+
+    # make raw llm api call
+
+    response_1 = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file,
+    )
+
+    await asyncio.sleep(5)
+
+    # cache hit
+
+    response_2 = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file,
+    )
+
+    print("response_1", response_1)
+    print("response_2", response_2)
+    print("response2 hidden params", response_2._hidden_params)
+    assert response_2._hidden_params["cache_hit"] is True
+
+    # cache miss
+
+    response_3 = await litellm.atranscription(
+        model="whisper-1",
+        file=audio_file2,
+    )
+    print("response_3", response_3)
+    print("response3 hidden params", response_3._hidden_params)
+    assert response_3._hidden_params.get("cache_hit") is not True
+    assert response_3.text != response_2.text
+
+    litellm.cache = None
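The 5-second sleep presumably gives the asynchronous cache write from the first call time to land before the cache-hit request is issued. The same signal the test asserts on is available to any caller; a minimal, hypothetical usage sketch (assumes an eagle.wav on disk and a configured OpenAI key):

import asyncio

import litellm
from litellm.caching import Cache


async def main() -> None:
    litellm.cache = Cache()  # in-memory cache by default

    audio = open("eagle.wav", "rb")
    await litellm.atranscription(model="whisper-1", file=audio)
    await asyncio.sleep(5)  # let the async cache write land, as the test does
    second = await litellm.atranscription(model="whisper-1", file=audio)

    # The hidden params populated by this PR flag the cached response:
    print(second._hidden_params.get("cache_hit"))  # expected: True


asyncio.run(main())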
@@ -553,7 +553,8 @@ def function_setup(
             or call_type == CallTypes.transcription.value
         ):
             _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
-            messages = "audio_file"
+            file_name = getattr(_file_name, "name", "audio_file")
+            messages = file_name
         elif (
             call_type == CallTypes.aspeech.value or call_type == CallTypes.speech.value
         ):
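This is the other half of the fix: function_setup previously recorded the literal string "audio_file" for every transcription call, so downstream logging and bookkeeping could not tell one audio file from another; it now records the file's actual name when one is available, falling back to the old placeholder otherwise.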
@@ -1213,6 +1214,7 @@ def client(original_function):
                         hidden_params = {
                             "model": "whisper-1",
                             "custom_llm_provider": custom_llm_provider,
+                            "cache_hit": True,
                         }
                         cached_result = convert_to_model_response_object(
                             response_object=cached_result,
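With cache_hit now included in the hidden params of cached transcription responses, callers can distinguish a cached result from a fresh one, which is exactly what the assertions in test_transcription_caching rely on.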