forked from phoenix/litellm-mirror
feat(caching.py): enable caching on provider-specific optional params
Closes https://github.com/BerriAI/litellm/issues/5049
This commit is contained in:
parent cd94c3adc1 · commit 3c4c78a71f
7 changed files with 172 additions and 74 deletions
@@ -146,6 +146,9 @@ return_response_headers: bool = (
 )
 ##################
 logging: bool = True
+enable_caching_on_optional_params: bool = (
+    False  # feature-flag for caching on optional params - e.g. 'top_k'
+)
 caching: bool = (
     False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
 )
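For context, a minimal usage sketch of the new flag, modeled on the test added later in this commit (the model name, messages, and `top_k` values are illustrative):

    import uuid

    import litellm
    from litellm import completion
    from litellm.caching import Cache

    # Opt in: include provider-specific optional params (e.g. `top_k`) in the cache key.
    litellm.enable_caching_on_optional_params = True
    litellm.cache = Cache()  # in-memory cache by default

    messages = [{"role": "user", "content": "hello"}]

    # Same messages + same top_k -> same cache key, so the second call is a cache hit.
    r1 = completion(model="gpt-3.5-turbo", messages=messages, top_k=10, caching=True,
                    mock_response="Hello: {}".format(uuid.uuid4()))
    r2 = completion(model="gpt-3.5-turbo", messages=messages, top_k=10, caching=True,
                    mock_response="Hello: {}".format(uuid.uuid4()))

    # A different top_k changes the cache key, so this call is a cache miss.
    r3 = completion(model="gpt-3.5-turbo", messages=messages, top_k=9, caching=True,
                    mock_response="Hello: {}".format(uuid.uuid4()))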
@@ -23,6 +23,7 @@ import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
+from litellm.types.utils import all_litellm_params
 
 
 def print_verbose(print_statement):
@@ -1838,6 +1839,7 @@ class Cache:
             "seed",
             "tools",
             "tool_choice",
+            "stream",
         ]
         embedding_only_kwargs = [
             "input",
@@ -1851,9 +1853,9 @@ class Cache:
         combined_kwargs = (
             completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
         )
-        for param in combined_kwargs:
-            # ignore litellm params here
-            if param in kwargs:
+        litellm_param_kwargs = all_litellm_params
+        for param in kwargs:
+            if param in combined_kwargs:
                 # check if param == model and model_group is passed in, then override model with model_group
                 if param == "model":
                     model_group = None
@@ -1897,6 +1899,17 @@ class Cache:
                         continue  # ignore None params
                     param_value = kwargs[param]
                     cache_key += f"{str(param)}: {str(param_value)}"
+                elif (
+                    param not in litellm_param_kwargs
+                ):  # check if user passed in optional param - e.g. top_k
+                    if (
+                        litellm.enable_caching_on_optional_params is True
+                    ):  # feature flagged for now
+                        if kwargs[param] is None:
+                            continue  # ignore None params
+                        param_value = kwargs[param]
+                        cache_key += f"{str(param)}: {str(param_value)}"
+
         print_verbose(f"\nCreated cache key: {cache_key}")
         # Use hashlib to create a sha256 hash of the cache key
         hash_object = hashlib.sha256(cache_key.encode())
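The key construction above boils down to concatenating `param: value` pairs for the allow-listed kwargs (plus, behind the flag, any kwarg that is not a litellm-internal param) and hashing the result. A standalone sketch of that idea, not the exact Cache method (`KNOWN_PARAMS` stands in for the combined completion/embedding/transcription lists):

    import hashlib

    from litellm.types.utils import all_litellm_params

    KNOWN_PARAMS = ["model", "messages", "temperature", "max_tokens", "stream"]  # abbreviated

    def sketch_cache_key(include_optional: bool = False, **kwargs) -> str:
        cache_key = ""
        for param, value in kwargs.items():
            if value is None:
                continue  # ignore None params
            if param in KNOWN_PARAMS:
                cache_key += f"{str(param)}: {str(value)}"
            elif include_optional and param not in all_litellm_params:
                # provider-specific optional param, e.g. top_k
                cache_key += f"{str(param)}: {str(value)}"
        return hashlib.sha256(cache_key.encode()).hexdigest()

Because the new code walks `kwargs` in caller order rather than the fixed allow-list, the expected key string in `test_get_cache_key` changes later in this diff from `temperature: 0.2max_tokens: 40` to `max_tokens: 40temperature: 0.2stream: True`.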
@@ -2101,9 +2114,7 @@ class Cache:
         try:
             cache_list = []
             for idx, i in enumerate(kwargs["input"]):
-                preset_cache_key = litellm.cache.get_cache_key(
-                    *args, **{**kwargs, "input": i}
-                )
+                preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
                 kwargs["cache_key"] = preset_cache_key
                 embedding_response = result.data[idx]
                 cache_key, cached_data, kwargs = self._add_cache_logic(
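For list inputs to embeddings, the change above derives the preset cache key via `self.get_cache_key` with the batch input swapped out for each individual item, so every element gets its own key. A hedged illustration of that per-item pattern (the model name and input strings are made up):

    import litellm
    from litellm.caching import Cache

    litellm.cache = Cache()

    batch_input = ["first document", "second document"]
    per_item_keys = [
        litellm.cache.get_cache_key(model="text-embedding-ada-002", input=item)
        for item in batch_input
    ]
    # Each item hashes independently, so a cached embedding can be reused
    # even when the same text shows up inside a different batch.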
@@ -125,7 +125,11 @@ from .llms.vertex_ai_partner import VertexAIPartnerModels
 from .llms.vertex_httpx import VertexLLM
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
-from .types.utils import AdapterCompletionStreamWrapper, ChatCompletionMessageToolCall
+from .types.utils import (
+    AdapterCompletionStreamWrapper,
+    ChatCompletionMessageToolCall,
+    all_litellm_params,
+)
 
 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
@@ -744,64 +748,9 @@ def completion(
         "top_logprobs",
         "extra_headers",
     ]
-    litellm_params = [
-        "metadata",
-        "tags",
-        "acompletion",
-        "atext_completion",
-        "text_completion",
-        "caching",
-        "mock_response",
-        "api_key",
-        "api_version",
-        "api_base",
-        "force_timeout",
-        "logger_fn",
-        "verbose",
-        "custom_llm_provider",
-        "litellm_logging_obj",
-        "litellm_call_id",
-        "use_client",
-        "id",
-        "fallbacks",
-        "azure",
-        "headers",
-        "model_list",
-        "num_retries",
-        "context_window_fallback_dict",
-        "retry_policy",
-        "roles",
-        "final_prompt_value",
-        "bos_token",
-        "eos_token",
-        "request_timeout",
-        "complete_response",
-        "self",
-        "client",
-        "rpm",
-        "tpm",
-        "max_parallel_requests",
-        "input_cost_per_token",
-        "output_cost_per_token",
-        "input_cost_per_second",
-        "output_cost_per_second",
-        "hf_model_name",
-        "model_info",
-        "proxy_server_request",
-        "preset_cache_key",
-        "caching_groups",
-        "ttl",
-        "cache",
-        "no-log",
-        "base_model",
-        "stream_timeout",
-        "supports_system_message",
-        "region_name",
-        "allowed_model_region",
-        "model_config",
-        "fastest_response",
-        "cooldown_time",
-    ]
+    litellm_params = (
+        all_litellm_params  # use the external var., used in creating cache key as well.
+    )
 
     default_params = openai_params + litellm_params
     non_default_params = {
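With `litellm_params` now pointing at the shared `all_litellm_params` list, `completion()` keeps treating everything outside `openai_params + litellm_params` as a provider-specific (non-default) param, and `Cache.get_cache_key` consults the same list, so both sides agree on what counts as an optional param. A rough sketch of that filtering step (the kwargs and abbreviated lists are illustrative, not the function's real signature):

    from litellm.types.utils import all_litellm_params

    openai_params = ["model", "messages", "temperature", "max_tokens", "top_p"]  # abbreviated
    default_params = openai_params + all_litellm_params

    kwargs = {"model": "gpt-3.5-turbo", "temperature": 0.2, "top_k": 10, "caching": True}
    non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
    # -> {"top_k": 10}: the provider-specific param that can now also feed the cache key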
@@ -5205,7 +5154,7 @@ def stream_chunk_builder(
             response["choices"][0]["message"]["function_call"][
                 "arguments"
             ] = combined_arguments
 
 
     content_chunks = [
         chunk
         for chunk in chunks
BIN litellm/tests/.litellm_cache/cache.db (new file)
Binary file not shown.
@@ -207,11 +207,17 @@ async def test_caching_with_cache_controls(sync_flag):
     else:
         ## TTL = 0
         response1 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 0},
+            mock_response="Hello world",
         )
         await asyncio.sleep(10)
         response2 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 10},
+            mock_response="Hello world",
         )
 
         assert response2["id"] != response1["id"]
@@ -220,21 +226,33 @@ async def test_caching_with_cache_controls(sync_flag):
     ## TTL = 5
     if sync_flag:
         response1 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 5},
+            mock_response="Hello world",
         )
         response2 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 5}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 5},
+            mock_response="Hello world",
        )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
         assert response2["id"] == response1["id"]
     else:
         response1 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 25}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 25},
+            mock_response="Hello world",
         )
         await asyncio.sleep(10)
         response2 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 25}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 25},
+            mock_response="Hello world",
         )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
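The cache-control arguments exercised here follow litellm's per-request cache controls: `ttl` bounds how long a response is stored, while `s-maxage` bounds how old a cached response may be and still be served (so `ttl: 0`, or sleeping past the `s-maxage` window, forces fresh response IDs). A minimal sketch of the same pattern outside the test harness, mirroring the sync branch above:

    import litellm
    from litellm import completion
    from litellm.caching import Cache

    litellm.cache = Cache()
    messages = [{"role": "user", "content": "hello"}]

    # Store the response, but only keep it for 5 seconds.
    first = completion(model="gpt-3.5-turbo", messages=messages,
                       cache={"ttl": 5}, mock_response="Hello world")

    # Accept a cached response only if it is at most 5 seconds old.
    second = completion(model="gpt-3.5-turbo", messages=messages,
                        cache={"s-maxage": 5}, mock_response="Hello world")

    assert second["id"] == first["id"]  # served from cache while inside the window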
@@ -282,6 +300,61 @@ def test_caching_with_models_v2():
 # test_caching_with_models_v2()
 
 
+def test_caching_with_optional_params():
+    litellm.enable_caching_on_optional_params = True
+    messages = [
+        {"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}
+    ]
+    litellm.cache = Cache()
+    print("test2 for caching")
+    litellm.set_verbose = True
+
+    response1 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=10,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    response2 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=10,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    response3 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=9,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    print(f"response1: {response1}")
+    print(f"response2: {response2}")
+    print(f"response3: {response3}")
+    litellm.cache = None
+    litellm.success_callback = []
+    litellm._async_success_callback = []
+    if (
+        response3["choices"][0]["message"]["content"]
+        == response2["choices"][0]["message"]["content"]
+    ):
+        # if models are different, it should not return cached response
+        print(f"response2: {response2}")
+        print(f"response3: {response3}")
+        pytest.fail(f"Error occurred:")
+    if (
+        response1["choices"][0]["message"]["content"]
+        != response2["choices"][0]["message"]["content"]
+    ):
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        pytest.fail(f"Error occurred:")
+    litellm.enable_caching_on_optional_params = False
+
+
 embedding_large_text = (
     """
 small text
@@ -1347,7 +1420,7 @@ def test_get_cache_key():
             "litellm_logging_obj": {},
         }
     )
-    cache_key_str = "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
+    cache_key_str = "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]max_tokens: 40temperature: 0.2stream: True"
     hash_object = hashlib.sha256(cache_key_str.encode())
     # Hexadecimal representation of the hash
     hash_hex = hash_object.hexdigest()
@@ -1052,6 +1052,68 @@ class ResponseFormatChunk(TypedDict, total=False):
     response_schema: dict
 
 
+all_litellm_params = [
+    "metadata",
+    "tags",
+    "acompletion",
+    "atext_completion",
+    "text_completion",
+    "caching",
+    "mock_response",
+    "api_key",
+    "api_version",
+    "api_base",
+    "force_timeout",
+    "logger_fn",
+    "verbose",
+    "custom_llm_provider",
+    "litellm_logging_obj",
+    "litellm_call_id",
+    "use_client",
+    "id",
+    "fallbacks",
+    "azure",
+    "headers",
+    "model_list",
+    "num_retries",
+    "context_window_fallback_dict",
+    "retry_policy",
+    "roles",
+    "final_prompt_value",
+    "bos_token",
+    "eos_token",
+    "request_timeout",
+    "complete_response",
+    "self",
+    "client",
+    "rpm",
+    "tpm",
+    "max_parallel_requests",
+    "input_cost_per_token",
+    "output_cost_per_token",
+    "input_cost_per_second",
+    "output_cost_per_second",
+    "hf_model_name",
+    "model_info",
+    "proxy_server_request",
+    "preset_cache_key",
+    "caching_groups",
+    "ttl",
+    "cache",
+    "no-log",
+    "base_model",
+    "stream_timeout",
+    "supports_system_message",
+    "region_name",
+    "allowed_model_region",
+    "model_config",
+    "fastest_response",
+    "cooldown_time",
+    "cache_key",
+    "max_retries",
+]
+
+
 class LoggedLiteLLMParams(TypedDict, total=False):
     force_timeout: Optional[float]
     custom_llm_provider: Optional[str]
@@ -1084,7 +1084,7 @@ def client(original_function):
                     and str(original_function.__name__)
                     in litellm.cache.supported_call_types
                 ):
-                    print_verbose(f"Checking Cache")
+                    print_verbose("Checking Cache")
                     if call_type == CallTypes.aembedding.value and isinstance(
                         kwargs["input"], list
                     ):