Merge pull request #3250 from BerriAI/litellm_caching_no_cache_fix

fix(utils.py): fix 'no-cache': true when caching is turned on
Krish Dholakia 2024-04-23 19:57:07 -07:00 committed by GitHub
commit 4acdde988f
5 changed files with 122 additions and 55 deletions
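For context before the per-file diffs: with a global cache configured, a caller should be able to opt out per request via cache={"no-cache": True}, which is the behavior the commit message describes. The sketch below is not part of the diff; it mirrors the new cache-control tests further down, and the model name and default in-memory cache are illustrative assumptions.

import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache()  # global caching on; defaults to an in-memory cache

messages = [{"role": "user", "content": "hello who are you"}]

# With caching on, an identical follow-up call would normally be served from cache.
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)

# The per-request override this PR fixes: skip the cache even though it is enabled.
response2 = completion(
    model="gpt-3.5-turbo",
    messages=messages,
    caching=True,
    cache={"no-cache": True},
)

assert response1.id != response2.id  # the overridden call was not a cache hit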


@@ -129,8 +129,6 @@ class PrometheusServicesLogger:
         if self.mock_testing:
             self.mock_testing_success_calls += 1
-        print(f"payload call type: {payload.call_type}")
         if payload.service.value in self.payload_to_prometheus_map:
             prom_objects = self.payload_to_prometheus_map[payload.service.value]
             for obj in prom_objects:
@@ -151,8 +149,6 @@ class PrometheusServicesLogger:
         if self.mock_testing:
             self.mock_testing_failure_calls += 1
-        print(f"payload call type: {payload.call_type}")
         if payload.service.value in self.payload_to_prometheus_map:
             prom_objects = self.payload_to_prometheus_map[payload.service.value]
             for obj in prom_objects:
@@ -170,8 +166,6 @@ class PrometheusServicesLogger:
         if self.mock_testing:
             self.mock_testing_success_calls += 1
-        print(f"payload call type: {payload.call_type}")
         if payload.service.value in self.payload_to_prometheus_map:
             prom_objects = self.payload_to_prometheus_map[payload.service.value]
             for obj in prom_objects:
@@ -193,8 +187,6 @@ class PrometheusServicesLogger:
         if self.mock_testing:
             self.mock_testing_failure_calls += 1
-        print(f"payload call type: {payload.call_type}")
         if payload.service.value in self.payload_to_prometheus_map:
             prom_objects = self.payload_to_prometheus_map[payload.service.value]
             for obj in prom_objects:


@@ -21,10 +21,10 @@ model_list:
       api_version: "2023-07-01-preview"
       stream_timeout: 0.001
     model_name: azure-gpt-3.5
-  # - model_name: text-embedding-ada-002
-  #   litellm_params:
-  #     model: text-embedding-ada-002
-  #     api_key: os.environ/OPENAI_API_KEY
+  - model_name: text-embedding-ada-002
+    litellm_params:
+      model: text-embedding-ada-002
+      api_key: os.environ/OPENAI_API_KEY
   - model_name: gpt-instruct
     litellm_params:
       model: text-completion-openai/gpt-3.5-turbo-instruct
@@ -45,6 +45,9 @@ litellm_settings:
   success_callback: ["prometheus"]
   failure_callback: ["prometheus"]
   service_callback: ["prometheus_system"]
+  cache: True
+  cache_params:
+    type: "redis"
 general_settings:
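The config hunk above turns on Redis-backed caching for the proxy (cache: True with cache_params.type: "redis"). For reference, a hedged sketch of the equivalent SDK-level setup, mirroring how the new sync cache-control test in this PR constructs the cache; the REDIS_HOST/REDIS_PORT/REDIS_PASSWORD environment variables are assumptions about the deployment.

import os

import litellm
from litellm.caching import Cache

# Same cache the new test_sync_cache_control_overrides() builds, applied globally.
litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)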


@@ -578,8 +578,10 @@ def test_gemini_pro_function_calling():
         model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
     )
     print(f"completion: {completion}")
-    assert completion.choices[0].message.content is None
-    assert len(completion.choices[0].message.tool_calls) == 1
+    if hasattr(completion.choices[0].message, "tool_calls") and isinstance(
+        completion.choices[0].message.tool_calls, list
+    ):
+        assert len(completion.choices[0].message.tool_calls) == 1
     try:
         load_vertex_ai_credentials()
         tools = [


@@ -178,32 +178,61 @@ def test_caching_with_default_ttl():
         pytest.fail(f"Error occurred: {e}")


-def test_caching_with_cache_controls():
+@pytest.mark.parametrize(
+    "sync_flag",
+    [True, False],
+)
+@pytest.mark.asyncio
+async def test_caching_with_cache_controls(sync_flag):
     try:
         litellm.set_verbose = True
         litellm.cache = Cache()
         message = [{"role": "user", "content": f"Hey, how's it going? {uuid.uuid4()}"}]
-        ## TTL = 0
-        response1 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
-        )
-        response2 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
-        )
-        print(f"response1: {response1}")
-        print(f"response2: {response2}")
-        assert response2["id"] != response1["id"]
+        if sync_flag:
+            ## TTL = 0
+            response1 = completion(
+                model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
+            )
+            response2 = completion(
+                model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
+            )
+            assert response2["id"] != response1["id"]
+        else:
+            ## TTL = 0
+            response1 = await litellm.acompletion(
+                model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
+            )
+            await asyncio.sleep(10)
+            response2 = await litellm.acompletion(
+                model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
+            )
+            assert response2["id"] != response1["id"]

         message = [{"role": "user", "content": f"Hey, how's it going? {uuid.uuid4()}"}]
         ## TTL = 5
-        response1 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
-        )
-        response2 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 5}
-        )
-        print(f"response1: {response1}")
-        print(f"response2: {response2}")
-        assert response2["id"] == response1["id"]
+        if sync_flag:
+            response1 = completion(
+                model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
+            )
+            response2 = completion(
+                model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 5}
+            )
+            print(f"response1: {response1}")
+            print(f"response2: {response2}")
+            assert response2["id"] == response1["id"]
+        else:
+            response1 = await litellm.acompletion(
+                model="gpt-3.5-turbo", messages=messages, cache={"ttl": 25}
+            )
+            await asyncio.sleep(10)
+            response2 = await litellm.acompletion(
+                model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 25}
+            )
+            print(f"response1: {response1}")
+            print(f"response2: {response2}")
+            assert response2["id"] == response1["id"]
     except Exception as e:
         print(f"error occurred: {traceback.format_exc()}")
         pytest.fail(f"Error occurred: {e}")
@@ -1111,6 +1140,7 @@ async def test_cache_control_overrides():
                 "content": "hello who are you" + unique_num,
             }
         ],
+        caching=True,
     )
     print(response1)
@@ -1125,6 +1155,55 @@ async def test_cache_control_overrides():
                 "content": "hello who are you" + unique_num,
             }
         ],
+        caching=True,
+        cache={"no-cache": True},
+    )
+
+    print(response2)
+
+    assert response1.id != response2.id
+
+
+def test_sync_cache_control_overrides():
+    # we use the cache controls to ensure there is no cache hit on this test
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    print("Testing cache override")
+    litellm.set_verbose = True
+    import uuid
+
+    unique_num = str(uuid.uuid4())
+
+    start_time = time.time()
+
+    response1 = litellm.completion(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "user",
+                "content": "hello who are you" + unique_num,
+            }
+        ],
+        caching=True,
+    )
+    print(response1)
+
+    time.sleep(2)
+
+    response2 = litellm.completion(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "user",
+                "content": "hello who are you" + unique_num,
+            }
+        ],
+        caching=True,
         cache={"no-cache": True},
     )

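The tests above exercise litellm's per-request cache controls passed through the cache dict: ttl bounds how long a response is stored, s-maxage bounds how old a cached response may be to count as a hit, and no-cache skips the lookup entirely. A minimal sketch of the ttl/s-maxage interplay, assuming an in-memory cache and an OPENAI_API_KEY in the environment:

import uuid

import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache()  # in-memory cache for illustration

messages = [{"role": "user", "content": f"Hey, how's it going? {uuid.uuid4()}"}]

# ttl=0: store the response for 0 seconds, so it can never be served again.
r1 = completion(model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0})

# s-maxage=10: only accept a cached response that is at most 10 seconds old;
# with ttl=0 above there is nothing usable, so this is a fresh call.
r2 = completion(model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10})

assert r1["id"] != r2["id"]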

@@ -2716,23 +2716,22 @@ def client(original_function):
             # [OPTIONAL] CHECK CACHE
             print_verbose(
-                f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
+                f"SYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache')['no-cache']: {kwargs.get('cache', {}).get('no-cache', False)}"
             )
             # if caching is false or cache["no-cache"]==True, don't run this
             if (
                 (
                     (
-                        kwargs.get("caching", None) is None
-                        and kwargs.get("cache", None) is None
-                        and litellm.cache is not None
-                    )
-                    or kwargs.get("caching", False) == True
-                    or (
-                        kwargs.get("cache", None) is not None
-                        and kwargs.get("cache", {}).get("no-cache", False) != True
-                    )
+                        (
+                            kwargs.get("caching", None) is None
+                            and litellm.cache is not None
+                        )
+                        or kwargs.get("caching", False) == True
+                    )
+                    and kwargs.get("cache", {}).get("no-cache", False) != True
                 )
                 and kwargs.get("aembedding", False) != True
+                and kwargs.get("atext_completion", False) != True
                 and kwargs.get("acompletion", False) != True
                 and kwargs.get("aimg_generation", False) != True
                 and kwargs.get("atranscription", False) != True
@@ -3011,24 +3010,17 @@ def client(original_function):
             )
             # [OPTIONAL] CHECK CACHE
-            print_verbose(f"litellm.cache: {litellm.cache}")
             print_verbose(
-                f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
+                f"ASYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache'): {kwargs.get('cache', None)}"
             )
             # if caching is false, don't run this
             final_embedding_cached_response = None
             if (
-                (
-                    kwargs.get("caching", None) is None
-                    and kwargs.get("cache", None) is None
-                    and litellm.cache is not None
-                )
+                (kwargs.get("caching", None) is None and litellm.cache is not None)
                 or kwargs.get("caching", False) == True
-                or (
-                    kwargs.get("cache", None) is not None
-                    and kwargs.get("cache").get("no-cache", False) != True
-                )
+            ) and (
+                kwargs.get("cache", {}).get("no-cache", False) != True
             ):  # allow users to control returning cached responses from the completion function
                 # checking cache
                 print_verbose("INSIDE CHECKING CACHE")
@@ -3074,7 +3066,6 @@ def client(original_function):
                             preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
                         )
                         cached_result = litellm.cache.get_cache(*args, **kwargs)
                         if cached_result is not None and not isinstance(
                             cached_result, list
                         ):
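The utils.py hunks restructure the cache-check condition in both the sync and async wrappers so that "no-cache": True is ANDed against the rest of the check instead of sitting inside the OR chain, where caching=True could override it. A simplified sketch of the resulting gating logic (not the actual utils.py code):

from typing import Any, Optional


def should_check_cache(kwargs: dict, global_cache: Optional[Any]) -> bool:
    # caching is on if the caller didn't say otherwise and a global cache exists,
    # or the caller explicitly passed caching=True
    caching_enabled = (
        kwargs.get("caching", None) is None and global_cache is not None
    ) or kwargs.get("caching", False) == True
    # per-request override: cache={"no-cache": True} always skips the lookup
    no_cache_requested = kwargs.get("cache", {}).get("no-cache", False) == True
    return caching_enabled and not no_cache_requested


# Before this change, caching=True satisfied the old OR-chain even when the caller
# also sent cache={"no-cache": True}; after it, the override wins.
assert should_check_cache({"caching": True, "cache": {"no-cache": True}}, object()) is False
assert should_check_cache({"caching": True}, None) is True
assert should_check_cache({}, object()) is True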