feat(caching.py): enable caching on provider-specific optional params

Closes https://github.com/BerriAI/litellm/issues/5049
Krrish Dholakia 2024-08-05 11:18:59 -07:00
parent cd94c3adc1
commit 3c4c78a71f
7 changed files with 172 additions and 74 deletions

View file

@@ -146,6 +146,9 @@ return_response_headers: bool = (
 )
 ##################
 logging: bool = True
+enable_caching_on_optional_params: bool = (
+    False  # feature-flag for caching on optional params - e.g. 'top_k'
+)
 caching: bool = (
     False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
 )
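For orientation, a minimal usage sketch of the new flag, modeled on the test_caching_with_optional_params test added later in this commit (mock_response keeps the calls offline; a local in-memory Cache() is assumed):

import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache()  # local in-memory cache for the sketch
litellm.enable_caching_on_optional_params = True  # off by default; opt in explicitly

messages = [{"role": "user", "content": "hello"}]

# identical optional params (top_k=10) -> identical cache key -> second call can be served from cache
r1 = completion(model="gpt-3.5-turbo", messages=messages, top_k=10, caching=True, mock_response="hi")
r2 = completion(model="gpt-3.5-turbo", messages=messages, top_k=10, caching=True, mock_response="hi")

# a different top_k changes the cache key -> fresh (non-cached) response
r3 = completion(model="gpt-3.5-turbo", messages=messages, top_k=9, caching=True, mock_response="hi")

litellm.enable_caching_on_optional_params = False  # reset the global flag when done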

View file

@@ -23,6 +23,7 @@ import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
+from litellm.types.utils import all_litellm_params


 def print_verbose(print_statement):
@@ -1838,6 +1839,7 @@ class Cache:
             "seed",
             "tools",
             "tool_choice",
+            "stream",
         ]
         embedding_only_kwargs = [
             "input",
@@ -1851,9 +1853,9 @@ class Cache:
         combined_kwargs = (
             completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
         )
-        for param in combined_kwargs:
-            # ignore litellm params here
-            if param in kwargs:
+        litellm_param_kwargs = all_litellm_params
+        for param in kwargs:
+            if param in combined_kwargs:
                 # check if param == model and model_group is passed in, then override model with model_group
                 if param == "model":
                     model_group = None
@@ -1897,6 +1899,17 @@ class Cache:
                         continue  # ignore None params
                     param_value = kwargs[param]
                     cache_key += f"{str(param)}: {str(param_value)}"
+                elif (
+                    param not in litellm_param_kwargs
+                ):  # check if user passed in optional param - e.g. top_k
+                    if (
+                        litellm.enable_caching_on_optional_params is True
+                    ):  # feature flagged for now
+                        if kwargs[param] is None:
+                            continue  # ignore None params
+                        param_value = kwargs[param]
+                        cache_key += f"{str(param)}: {str(param_value)}"
+
         print_verbose(f"\nCreated cache key: {cache_key}")
         # Use hashlib to create a sha256 hash of the cache key
         hash_object = hashlib.sha256(cache_key.encode())
@@ -2101,9 +2114,7 @@ class Cache:
         try:
            cache_list = []
            for idx, i in enumerate(kwargs["input"]):
-                preset_cache_key = litellm.cache.get_cache_key(
-                    *args, **{**kwargs, "input": i}
-                )
+                preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
                 kwargs["cache_key"] = preset_cache_key
                 embedding_response = result.data[idx]
                 cache_key, cached_data, kwargs = self._add_cache_logic(
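In short, get_cache_key now iterates over the caller's kwargs instead of a fixed allow-list: known completion/embedding/transcription kwargs are always keyed, litellm-internal kwargs (anything in all_litellm_params) are always skipped, and everything else is treated as a provider-specific optional param that is keyed only when the feature flag is on. A simplified standalone sketch of that per-kwarg decision (not the library's actual method, which also handles model_group, files, etc.):

import litellm

def sketch_cache_key(kwargs: dict, known_kwargs: list, litellm_internal: list) -> str:
    cache_key = ""
    for param in kwargs:
        if param in known_kwargs:
            if kwargs[param] is None:
                continue  # ignore None params
            cache_key += f"{param}: {kwargs[param]}"
        elif param not in litellm_internal:
            # provider-specific optional param, e.g. top_k
            if litellm.enable_caching_on_optional_params is True and kwargs[param] is not None:
                cache_key += f"{param}: {kwargs[param]}"
        # params in litellm_internal (metadata, api_key, ...) never influence the key
    return cache_key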

View file

@@ -125,7 +125,11 @@ from .llms.vertex_ai_partner import VertexAIPartnerModels
 from .llms.vertex_httpx import VertexLLM
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
-from .types.utils import AdapterCompletionStreamWrapper, ChatCompletionMessageToolCall
+from .types.utils import (
+    AdapterCompletionStreamWrapper,
+    ChatCompletionMessageToolCall,
+    all_litellm_params,
+)

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
@@ -744,64 +748,9 @@ def completion(
         "top_logprobs",
         "extra_headers",
     ]
-    litellm_params = [
-        "metadata",
-        "tags",
-        "acompletion",
-        "atext_completion",
-        "text_completion",
-        "caching",
-        "mock_response",
-        "api_key",
-        "api_version",
-        "api_base",
-        "force_timeout",
-        "logger_fn",
-        "verbose",
-        "custom_llm_provider",
-        "litellm_logging_obj",
-        "litellm_call_id",
-        "use_client",
-        "id",
-        "fallbacks",
-        "azure",
-        "headers",
-        "model_list",
-        "num_retries",
-        "context_window_fallback_dict",
-        "retry_policy",
-        "roles",
-        "final_prompt_value",
-        "bos_token",
-        "eos_token",
-        "request_timeout",
-        "complete_response",
-        "self",
-        "client",
-        "rpm",
-        "tpm",
-        "max_parallel_requests",
-        "input_cost_per_token",
-        "output_cost_per_token",
-        "input_cost_per_second",
-        "output_cost_per_second",
-        "hf_model_name",
-        "model_info",
-        "proxy_server_request",
-        "preset_cache_key",
-        "caching_groups",
-        "ttl",
-        "cache",
-        "no-log",
-        "base_model",
-        "stream_timeout",
-        "supports_system_message",
-        "region_name",
-        "allowed_model_region",
-        "model_config",
-        "fastest_response",
-        "cooldown_time",
-    ]
+    litellm_params = (
+        all_litellm_params  # use the external var., used in creating cache key as well.
+    )

     default_params = openai_params + litellm_params
     non_default_params = {
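The practical effect in completion(): default_params is now openai_params plus the shared all_litellm_params list, and whatever the caller passes outside that set ends up in non_default_params as a provider-specific optional param, which (with the new flag) also feeds the cache key. A rough illustration, using an abridged, made-up openai_params list:

from litellm.types.utils import all_litellm_params

openai_params = ["model", "messages", "temperature", "max_tokens"]  # abridged for illustration
default_params = openai_params + all_litellm_params

passed_kwargs = {"model": "gpt-3.5-turbo", "temperature": 0.2, "top_k": 10, "metadata": {"run": 1}}
non_default_params = {k: v for k, v in passed_kwargs.items() if k not in default_params}
print(non_default_params)  # {'top_k': 10} - only the provider-specific param remains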

Binary file not shown.

View file

@@ -207,11 +207,17 @@ async def test_caching_with_cache_controls(sync_flag):
     else:
         ## TTL = 0
         response1 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 0},
+            mock_response="Hello world",
         )
         await asyncio.sleep(10)
         response2 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 10},
+            mock_response="Hello world",
         )

         assert response2["id"] != response1["id"]
@@ -220,21 +226,33 @@ async def test_caching_with_cache_controls(sync_flag):
     ## TTL = 5
     if sync_flag:
         response1 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 5},
+            mock_response="Hello world",
         )
         response2 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 5}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 5},
+            mock_response="Hello world",
         )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
         assert response2["id"] == response1["id"]
     else:
         response1 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 25}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 25},
+            mock_response="Hello world",
         )
         await asyncio.sleep(10)
         response2 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 25}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 25},
+            mock_response="Hello world",
         )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
@@ -282,6 +300,61 @@ def test_caching_with_models_v2():
 # test_caching_with_models_v2()


+def test_caching_with_optional_params():
+    litellm.enable_caching_on_optional_params = True
+    messages = [
+        {"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}
+    ]
+    litellm.cache = Cache()
+    print("test2 for caching")
+    litellm.set_verbose = True
+
+    response1 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=10,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    response2 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=10,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    response3 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=9,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    print(f"response1: {response1}")
+    print(f"response2: {response2}")
+    print(f"response3: {response3}")
+    litellm.cache = None
+    litellm.success_callback = []
+    litellm._async_success_callback = []
+
+    if (
+        response3["choices"][0]["message"]["content"]
+        == response2["choices"][0]["message"]["content"]
+    ):
+        # if models are different, it should not return cached response
+        print(f"response2: {response2}")
+        print(f"response3: {response3}")
+        pytest.fail(f"Error occurred:")
+    if (
+        response1["choices"][0]["message"]["content"]
+        != response2["choices"][0]["message"]["content"]
+    ):
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        pytest.fail(f"Error occurred:")
+    litellm.enable_caching_on_optional_params = False
+
+
 embedding_large_text = (
     """
 small text
@@ -1347,7 +1420,7 @@ def test_get_cache_key():
             "litellm_logging_obj": {},
         }
     )
-    cache_key_str = "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
+    cache_key_str = "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]max_tokens: 40temperature: 0.2stream: True"
     hash_object = hashlib.sha256(cache_key_str.encode())
     # Hexadecimal representation of the hash
     hash_hex = hash_object.hexdigest()
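For context on the updated expectation: the cache key returned by Cache.get_cache_key is the sha256 hex digest of the concatenated "param: value" string, so the test's expected value can be reproduced directly (note the new string includes stream: True and the reordered max_tokens/temperature):

import hashlib

cache_key_str = (
    "model: gpt-3.5-turbo"
    "messages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]"
    "max_tokens: 40temperature: 0.2stream: True"
)
hash_hex = hashlib.sha256(cache_key_str.encode()).hexdigest()
print(hash_hex)  # the value the test compares against get_cache_key()'s output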

View file

@@ -1052,6 +1052,68 @@ class ResponseFormatChunk(TypedDict, total=False):
     response_schema: dict


+all_litellm_params = [
+    "metadata",
+    "tags",
+    "acompletion",
+    "atext_completion",
+    "text_completion",
+    "caching",
+    "mock_response",
+    "api_key",
+    "api_version",
+    "api_base",
+    "force_timeout",
+    "logger_fn",
+    "verbose",
+    "custom_llm_provider",
+    "litellm_logging_obj",
+    "litellm_call_id",
+    "use_client",
+    "id",
+    "fallbacks",
+    "azure",
+    "headers",
+    "model_list",
+    "num_retries",
+    "context_window_fallback_dict",
+    "retry_policy",
+    "roles",
+    "final_prompt_value",
+    "bos_token",
+    "eos_token",
+    "request_timeout",
+    "complete_response",
+    "self",
+    "client",
+    "rpm",
+    "tpm",
+    "max_parallel_requests",
+    "input_cost_per_token",
+    "output_cost_per_token",
+    "input_cost_per_second",
+    "output_cost_per_second",
+    "hf_model_name",
+    "model_info",
+    "proxy_server_request",
+    "preset_cache_key",
+    "caching_groups",
+    "ttl",
+    "cache",
+    "no-log",
+    "base_model",
+    "stream_timeout",
+    "supports_system_message",
+    "region_name",
+    "allowed_model_region",
+    "model_config",
+    "fastest_response",
+    "cooldown_time",
+    "cache_key",
+    "max_retries",
+]
+
+
 class LoggedLiteLLMParams(TypedDict, total=False):
     force_timeout: Optional[float]
     custom_llm_provider: Optional[str]
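The all_litellm_params list is now defined once in litellm.types.utils and imported by both the cache-key builder and the completion() entry point (per the "used in creating cache key as well" comment above). A quick sketch of how the list is meant to be read:

from litellm.types.utils import all_litellm_params

print("metadata" in all_litellm_params)  # True  - litellm-internal kwarg, excluded from cache keys
print("top_k" in all_litellm_params)     # False - provider-specific param, cacheable behind the new flag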

View file

@@ -1084,7 +1084,7 @@ def client(original_function):
                     and str(original_function.__name__)
                     in litellm.cache.supported_call_types
                 ):
-                    print_verbose(f"Checking Cache")
+                    print_verbose("Checking Cache")
                     if call_type == CallTypes.aembedding.value and isinstance(
                         kwargs["input"], list
                     ):