fix(utils.py): fix conditional check

This commit is contained in:
Krrish Dholakia 2024-02-03 18:58:58 -08:00
parent 9ab59045a3
commit c2f674ebe0

View file

@@ -2512,12 +2512,19 @@ def client(original_function):
) )
) )
cached_result = await asyncio.gather(*tasks) cached_result = await asyncio.gather(*tasks)
## check if cached result is None ##
if cached_result is not None and isinstance(
cached_result, list
):
if len(cached_result) == 1 and cached_result[0] is None:
cached_result = None
else: else:
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs[ kwargs[
"preset_cache_key" "preset_cache_key"
] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key ] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key
cached_result = litellm.cache.get_cache(*args, **kwargs) cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result is not None and not isinstance( if cached_result is not None and not isinstance(
cached_result, list cached_result, list
): ):
@@ -2611,7 +2618,6 @@ def client(original_function):
non_null_list.append((idx, cr)) non_null_list.append((idx, cr))
original_kwargs_input = kwargs["input"] original_kwargs_input = kwargs["input"]
kwargs["input"] = remaining_list kwargs["input"] = remaining_list
if len(non_null_list) > 0: if len(non_null_list) > 0:
print_verbose( print_verbose(
f"EMBEDDING CACHE HIT! - {len(non_null_list)}" f"EMBEDDING CACHE HIT! - {len(non_null_list)}"
@@ -2628,7 +2634,6 @@ def client(original_function):
idx, cr = val # (idx, cr) tuple idx, cr = val # (idx, cr) tuple
if cr is not None: if cr is not None:
final_embedding_cached_response.data[idx] = cr final_embedding_cached_response.data[idx] = cr
if len(remaining_list) == 0: if len(remaining_list) == 0:
# LOG SUCCESS # LOG SUCCESS
cache_hit = True cache_hit = True
@@ -2769,7 +2774,8 @@ def client(original_function):
result._response_ms = ( result._response_ms = (
end_time - start_time end_time - start_time
).total_seconds() * 1000 # return response latency in ms like openai ).total_seconds() * 1000 # return response latency in ms like openai
elif (
if (
isinstance(result, EmbeddingResponse) isinstance(result, EmbeddingResponse)
and final_embedding_cached_response is not None and final_embedding_cached_response is not None
): ):
@@ -2783,6 +2789,10 @@ def client(original_function):
final_data_list.append(item) final_data_list.append(item)
final_embedding_cached_response.data = final_data_list final_embedding_cached_response.data = final_data_list
final_embedding_cached_response._hidden_params["cache_hit"] = True
final_embedding_cached_response._response_ms = (
end_time - start_time
).total_seconds() * 1000
return final_embedding_cached_response return final_embedding_cached_response
return result return result