diff --git a/litellm/__init__.py b/litellm/__init__.py
index 6dc678b3e..22255eb34 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -146,6 +146,9 @@ return_response_headers: bool = (
 )
 ##################
 logging: bool = True
+enable_caching_on_optional_params: bool = (
+    False  # feature-flag for caching on optional params - e.g. 'top_k'
+)
 caching: bool = (
     False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
 )
diff --git a/litellm/caching.py b/litellm/caching.py
index c23c1641b..ab62c3440 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -23,6 +23,7 @@ import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
+from litellm.types.utils import all_litellm_params
 
 
 def print_verbose(print_statement):
@@ -1838,6 +1839,7 @@ class Cache:
             "seed",
             "tools",
             "tool_choice",
+            "stream",
         ]
         embedding_only_kwargs = [
             "input",
@@ -1851,9 +1853,9 @@ class Cache:
         combined_kwargs = (
             completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
         )
-        for param in combined_kwargs:
-            # ignore litellm params here
-            if param in kwargs:
+        litellm_param_kwargs = all_litellm_params
+        for param in kwargs:
+            if param in combined_kwargs:
                 # check if param == model and model_group is passed in, then override model with model_group
                 if param == "model":
                     model_group = None
@@ -1897,6 +1899,17 @@ class Cache:
                         continue  # ignore None params
                     param_value = kwargs[param]
                     cache_key += f"{str(param)}: {str(param_value)}"
+            elif (
+                param not in litellm_param_kwargs
+            ):  # check if user passed in optional param - e.g. top_k
+                if (
+                    litellm.enable_caching_on_optional_params is True
+                ):  # feature flagged for now
+                    if kwargs[param] is None:
+                        continue  # ignore None params
+                    param_value = kwargs[param]
+                    cache_key += f"{str(param)}: {str(param_value)}"
+
         print_verbose(f"\nCreated cache key: {cache_key}")
         # Use hashlib to create a sha256 hash of the cache key
         hash_object = hashlib.sha256(cache_key.encode())
@@ -2101,9 +2114,7 @@ class Cache:
         try:
             cache_list = []
             for idx, i in enumerate(kwargs["input"]):
-                preset_cache_key = litellm.cache.get_cache_key(
-                    *args, **{**kwargs, "input": i}
-                )
+                preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
                 kwargs["cache_key"] = preset_cache_key
                 embedding_response = result.data[idx]
                 cache_key, cached_data, kwargs = self._add_cache_logic(
diff --git a/litellm/main.py b/litellm/main.py
index f0eb00ecd..fd1adc15b 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -125,7 +125,11 @@ from .llms.vertex_ai_partner import VertexAIPartnerModels
 from .llms.vertex_httpx import VertexLLM
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
-from .types.utils import AdapterCompletionStreamWrapper, ChatCompletionMessageToolCall
+from .types.utils import (
+    AdapterCompletionStreamWrapper,
+    ChatCompletionMessageToolCall,
+    all_litellm_params,
+)
 
 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
@@ -744,64 +748,9 @@ def completion(
         "top_logprobs",
         "extra_headers",
     ]
-    litellm_params = [
-        "metadata",
-        "tags",
-        "acompletion",
-        "atext_completion",
-        "text_completion",
-        "caching",
-        "mock_response",
-        "api_key",
-        "api_version",
-        "api_base",
-        "force_timeout",
-        "logger_fn",
-        "verbose",
-        "custom_llm_provider",
-        "litellm_logging_obj",
-        "litellm_call_id",
-        "use_client",
-        "id",
-        "fallbacks",
-        "azure",
-        "headers",
-        "model_list",
-        "num_retries",
-        "context_window_fallback_dict",
-        "retry_policy",
-        "roles",
-        "final_prompt_value",
-        "bos_token",
-        "eos_token",
-        "request_timeout",
-        "complete_response",
-        "self",
-        "client",
-        "rpm",
-        "tpm",
-        "max_parallel_requests",
-        "input_cost_per_token",
-        "output_cost_per_token",
-        "input_cost_per_second",
-        "output_cost_per_second",
-        "hf_model_name",
-        "model_info",
-        "proxy_server_request",
-        "preset_cache_key",
-        "caching_groups",
-        "ttl",
-        "cache",
-        "no-log",
-        "base_model",
-        "stream_timeout",
-        "supports_system_message",
-        "region_name",
-        "allowed_model_region",
-        "model_config",
-        "fastest_response",
-        "cooldown_time",
-    ]
+    litellm_params = (
+        all_litellm_params  # use the external var., used in creating cache key as well.
+    )
 
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -5205,7 +5154,7 @@ def stream_chunk_builder(
             response["choices"][0]["message"]["function_call"][
                 "arguments"
             ] = combined_arguments
-
+
     content_chunks = [
         chunk
         for chunk in chunks
diff --git a/litellm/tests/.litellm_cache/cache.db b/litellm/tests/.litellm_cache/cache.db
new file mode 100644
index 000000000..409957649
Binary files /dev/null and b/litellm/tests/.litellm_cache/cache.db differ
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index a4a70a535..b08f0039c 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -207,11 +207,17 @@ async def test_caching_with_cache_controls(sync_flag):
     else:
         ## TTL = 0
         response1 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 0},
+            mock_response="Hello world",
         )
         await asyncio.sleep(10)
         response2 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 10},
+            mock_response="Hello world",
         )
 
         assert response2["id"] != response1["id"]
@@ -220,21 +226,33 @@
     ## TTL = 5
     if sync_flag:
         response1 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 5},
+            mock_response="Hello world",
         )
         response2 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 5}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 5},
+            mock_response="Hello world",
         )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
         assert response2["id"] == response1["id"]
     else:
         response1 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 25}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 25},
+            mock_response="Hello world",
         )
         await asyncio.sleep(10)
         response2 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 25}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 25},
+            mock_response="Hello world",
         )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
@@ -282,6 +300,61 @@ def test_caching_with_models_v2():
 
 
 # test_caching_with_models_v2()
+
+
+def test_caching_with_optional_params():
+    litellm.enable_caching_on_optional_params = True
+    messages = [
+        {"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}
+    ]
+    litellm.cache = Cache()
+    print("test2 for caching")
+    litellm.set_verbose = True
+
+    response1 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=10,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    response2 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=10,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    response3 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=9,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    print(f"response1: {response1}")
+    print(f"response2: {response2}")
+    print(f"response3: {response3}")
+    litellm.cache = None
+    litellm.success_callback = []
+    litellm._async_success_callback = []
+    if (
+        response3["choices"][0]["message"]["content"]
+        == response2["choices"][0]["message"]["content"]
+    ):
+        # if models are different, it should not return cached response
+        print(f"response2: {response2}")
+        print(f"response3: {response3}")
+        pytest.fail(f"Error occurred:")
+    if (
+        response1["choices"][0]["message"]["content"]
+        != response2["choices"][0]["message"]["content"]
+    ):
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        pytest.fail(f"Error occurred:")
+    litellm.enable_caching_on_optional_params = False
+
+
 embedding_large_text = (
     """
 small text
@@ -1347,7 +1420,7 @@ def test_get_cache_key():
             "litellm_logging_obj": {},
         }
     )
-    cache_key_str = "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
+    cache_key_str = "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]max_tokens: 40temperature: 0.2stream: True"
     hash_object = hashlib.sha256(cache_key_str.encode())
     # Hexadecimal representation of the hash
     hash_hex = hash_object.hexdigest()
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 481f762ee..7f734482c 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1052,6 +1052,68 @@ class ResponseFormatChunk(TypedDict, total=False):
     response_schema: dict
 
 
+all_litellm_params = [
+    "metadata",
+    "tags",
+    "acompletion",
+    "atext_completion",
+    "text_completion",
+    "caching",
+    "mock_response",
+    "api_key",
+    "api_version",
+    "api_base",
+    "force_timeout",
+    "logger_fn",
+    "verbose",
+    "custom_llm_provider",
+    "litellm_logging_obj",
+    "litellm_call_id",
+    "use_client",
+    "id",
+    "fallbacks",
+    "azure",
+    "headers",
+    "model_list",
+    "num_retries",
+    "context_window_fallback_dict",
+    "retry_policy",
+    "roles",
+    "final_prompt_value",
+    "bos_token",
+    "eos_token",
+    "request_timeout",
+    "complete_response",
+    "self",
+    "client",
+    "rpm",
+    "tpm",
+    "max_parallel_requests",
+    "input_cost_per_token",
+    "output_cost_per_token",
+    "input_cost_per_second",
+    "output_cost_per_second",
+    "hf_model_name",
+    "model_info",
+    "proxy_server_request",
+    "preset_cache_key",
+    "caching_groups",
+    "ttl",
+    "cache",
+    "no-log",
+    "base_model",
+    "stream_timeout",
+    "supports_system_message",
+    "region_name",
+    "allowed_model_region",
+    "model_config",
+    "fastest_response",
+    "cooldown_time",
+    "cache_key",
+    "max_retries",
+]
+
+
 class LoggedLiteLLMParams(TypedDict, total=False):
     force_timeout: Optional[float]
     custom_llm_provider: Optional[str]
diff --git a/litellm/utils.py b/litellm/utils.py
index 825caf326..5948543bd 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1084,7 +1084,7 @@ def client(original_function):
                 and str(original_function.__name__)
                 in litellm.cache.supported_call_types
             ):
-                print_verbose(f"Checking Cache")
+                print_verbose("Checking Cache")
                 if call_type == CallTypes.aembedding.value and isinstance(
                     kwargs["input"], list
                 ):
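
A minimal usage sketch of the new behavior, mirroring the `test_caching_with_optional_params` test added in this patch (it runs against mocked responses, so no API key is needed): with `litellm.enable_caching_on_optional_params` enabled, optional provider params such as `top_k` become part of the cache key, so calls that differ only in those params no longer share a cached response. The model name and messages here are illustrative, not required.

```python
import uuid

import litellm
from litellm import completion
from litellm.caching import Cache

litellm.enable_caching_on_optional_params = True  # new feature flag; defaults to False
litellm.cache = Cache()  # in-memory cache

messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}]

# Each mock_response carries a unique marker, so a repeated marker proves a cache hit.
r1 = completion(model="gpt-3.5-turbo", messages=messages, top_k=10, caching=True,
                mock_response="Hello: {}".format(uuid.uuid4()))
r2 = completion(model="gpt-3.5-turbo", messages=messages, top_k=10, caching=True,
                mock_response="Hello: {}".format(uuid.uuid4()))
r3 = completion(model="gpt-3.5-turbo", messages=messages, top_k=9, caching=True,
                mock_response="Hello: {}".format(uuid.uuid4()))

# Same top_k -> same cache key -> r2 is served from cache and repeats r1's content.
assert r1["choices"][0]["message"]["content"] == r2["choices"][0]["message"]["content"]
# Different top_k -> different cache key -> r3 is a fresh (mocked) completion.
assert r2["choices"][0]["message"]["content"] != r3["choices"][0]["message"]["content"]

litellm.enable_caching_on_optional_params = False
litellm.cache = None
```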