forked from phoenix/litellm-mirror
feat(caching.py): enable caching on provider-specific optional params
Closes https://github.com/BerriAI/litellm/issues/5049
commit 3c4c78a71f (parent cd94c3adc1)
7 changed files with 172 additions and 74 deletions
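In effect (a minimal usage sketch, adapted from the new test added in this diff; the model, messages, and mock responses are illustrative): once the new feature flag is enabled, provider-specific optional params such as top_k are folded into the cache key, so calls that differ only in top_k no longer share a cached response. With the flag left at its default of False, behavior is unchanged.

    import litellm
    from litellm import completion
    from litellm.caching import Cache

    litellm.enable_caching_on_optional_params = True  # new feature flag from this commit
    litellm.cache = Cache()  # default in-memory cache

    messages = [{"role": "user", "content": "hello"}]

    # Same optional params -> same cache key, so the second call is answered from cache
    # (mirrors the assertion in test_caching_with_optional_params below).
    r1 = completion(model="gpt-3.5-turbo", messages=messages, top_k=10,
                    caching=True, mock_response="first response")
    r2 = completion(model="gpt-3.5-turbo", messages=messages, top_k=10,
                    caching=True, mock_response="second response")

    # Different top_k -> different cache key, so this is a fresh (mocked) response.
    r3 = completion(model="gpt-3.5-turbo", messages=messages, top_k=9,
                    caching=True, mock_response="third response")

    litellm.enable_caching_on_optional_params = False  # reset the flag
    litellm.cache = None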
@@ -146,6 +146,9 @@ return_response_headers: bool = (
 )
 ##################
 logging: bool = True
+enable_caching_on_optional_params: bool = (
+    False  # feature-flag for caching on optional params - e.g. 'top_k'
+)
 caching: bool = (
     False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
 )
@@ -23,6 +23,7 @@ import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
+from litellm.types.utils import all_litellm_params


 def print_verbose(print_statement):
@@ -1838,6 +1839,7 @@ class Cache:
             "seed",
             "tools",
             "tool_choice",
+            "stream",
         ]
         embedding_only_kwargs = [
             "input",
@@ -1851,9 +1853,9 @@ class Cache:
         combined_kwargs = (
             completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
         )
-        for param in combined_kwargs:
-            # ignore litellm params here
-            if param in kwargs:
+        litellm_param_kwargs = all_litellm_params
+        for param in kwargs:
+            if param in combined_kwargs:
                 # check if param == model and model_group is passed in, then override model with model_group
                 if param == "model":
                     model_group = None
@@ -1897,6 +1899,17 @@ class Cache:
                         continue  # ignore None params
                     param_value = kwargs[param]
                     cache_key += f"{str(param)}: {str(param_value)}"
+            elif (
+                param not in litellm_param_kwargs
+            ):  # check if user passed in optional param - e.g. top_k
+                if (
+                    litellm.enable_caching_on_optional_params is True
+                ):  # feature flagged for now
+                    if kwargs[param] is None:
+                        continue  # ignore None params
+                    param_value = kwargs[param]
+                    cache_key += f"{str(param)}: {str(param_value)}"
+
         print_verbose(f"\nCreated cache key: {cache_key}")
         # Use hashlib to create a sha256 hash of the cache key
         hash_object = hashlib.sha256(cache_key.encode())
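For readability, a simplified, standalone sketch of the key-construction rule the hunk above introduces (the lists and names here are illustrative stand-ins, not litellm's real values): known completion/embedding params always contribute to the key, litellm-internal params never do, and any other optional param such as top_k contributes only when the feature flag is on.

    import hashlib

    # Illustrative stand-ins for the real module-level lists.
    known_call_params = ["model", "messages", "temperature", "max_tokens"]
    litellm_only_params = ["metadata", "caching", "mock_response"]  # subset of all_litellm_params
    enable_caching_on_optional_params = True  # the new feature flag

    def make_cache_key(**kwargs) -> str:
        cache_key = ""
        for param, value in kwargs.items():
            if value is None:
                continue  # ignore None params
            if param in known_call_params:
                cache_key += f"{param}: {value}"
            elif param not in litellm_only_params and enable_caching_on_optional_params:
                # provider-specific optional param, e.g. top_k
                cache_key += f"{param}: {value}"
        # hash the concatenated "param: value" string, as the real code does
        return hashlib.sha256(cache_key.encode()).hexdigest()

    key_1 = make_cache_key(model="gpt-3.5-turbo", messages="hi", top_k=10)
    key_2 = make_cache_key(model="gpt-3.5-turbo", messages="hi", top_k=9)
    assert key_1 != key_2  # top_k now changes the key
    # litellm-only params still do not affect the key
    assert make_cache_key(model="gpt-3.5-turbo", messages="hi", caching=True) == \
        make_cache_key(model="gpt-3.5-turbo", messages="hi")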
@@ -2101,9 +2114,7 @@ class Cache:
         try:
             cache_list = []
             for idx, i in enumerate(kwargs["input"]):
-                preset_cache_key = litellm.cache.get_cache_key(
-                    *args, **{**kwargs, "input": i}
-                )
+                preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
                 kwargs["cache_key"] = preset_cache_key
                 embedding_response = result.data[idx]
                 cache_key, cached_data, kwargs = self._add_cache_logic(
@@ -125,7 +125,11 @@ from .llms.vertex_ai_partner import VertexAIPartnerModels
 from .llms.vertex_httpx import VertexLLM
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
-from .types.utils import AdapterCompletionStreamWrapper, ChatCompletionMessageToolCall
+from .types.utils import (
+    AdapterCompletionStreamWrapper,
+    ChatCompletionMessageToolCall,
+    all_litellm_params,
+)

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
@@ -744,64 +748,9 @@ def completion(
         "top_logprobs",
         "extra_headers",
     ]
-    litellm_params = [
-        "metadata",
-        "tags",
-        "acompletion",
-        "atext_completion",
-        "text_completion",
-        "caching",
-        "mock_response",
-        "api_key",
-        "api_version",
-        "api_base",
-        "force_timeout",
-        "logger_fn",
-        "verbose",
-        "custom_llm_provider",
-        "litellm_logging_obj",
-        "litellm_call_id",
-        "use_client",
-        "id",
-        "fallbacks",
-        "azure",
-        "headers",
-        "model_list",
-        "num_retries",
-        "context_window_fallback_dict",
-        "retry_policy",
-        "roles",
-        "final_prompt_value",
-        "bos_token",
-        "eos_token",
-        "request_timeout",
-        "complete_response",
-        "self",
-        "client",
-        "rpm",
-        "tpm",
-        "max_parallel_requests",
-        "input_cost_per_token",
-        "output_cost_per_token",
-        "input_cost_per_second",
-        "output_cost_per_second",
-        "hf_model_name",
-        "model_info",
-        "proxy_server_request",
-        "preset_cache_key",
-        "caching_groups",
-        "ttl",
-        "cache",
-        "no-log",
-        "base_model",
-        "stream_timeout",
-        "supports_system_message",
-        "region_name",
-        "allowed_model_region",
-        "model_config",
-        "fastest_response",
-        "cooldown_time",
-    ]
+    litellm_params = (
+        all_litellm_params  # use the external var., used in creating cache key as well.
+    )

     default_params = openai_params + litellm_params
     non_default_params = {
@@ -5205,7 +5154,7 @@ def stream_chunk_builder(
             response["choices"][0]["message"]["function_call"][
                 "arguments"
             ] = combined_arguments

     content_chunks = [
         chunk
         for chunk in chunks
BIN  litellm/tests/.litellm_cache/cache.db (new file)
Binary file not shown.
@@ -207,11 +207,17 @@ async def test_caching_with_cache_controls(sync_flag):
     else:
         ## TTL = 0
         response1 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 0},
+            mock_response="Hello world",
         )
         await asyncio.sleep(10)
         response2 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 10},
+            mock_response="Hello world",
         )

         assert response2["id"] != response1["id"]
@@ -220,21 +226,33 @@ async def test_caching_with_cache_controls(sync_flag):
     ## TTL = 5
     if sync_flag:
         response1 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 5},
+            mock_response="Hello world",
         )
         response2 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 5}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 5},
+            mock_response="Hello world",
         )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
         assert response2["id"] == response1["id"]
     else:
         response1 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 25}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"ttl": 25},
+            mock_response="Hello world",
         )
         await asyncio.sleep(10)
         response2 = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 25}
+            model="gpt-3.5-turbo",
+            messages=messages,
+            cache={"s-maxage": 25},
+            mock_response="Hello world",
         )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
@@ -282,6 +300,61 @@ def test_caching_with_models_v2():
 # test_caching_with_models_v2()


+def test_caching_with_optional_params():
+    litellm.enable_caching_on_optional_params = True
+    messages = [
+        {"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}
+    ]
+    litellm.cache = Cache()
+    print("test2 for caching")
+    litellm.set_verbose = True
+
+    response1 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=10,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    response2 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=10,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    response3 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        top_k=9,
+        caching=True,
+        mock_response="Hello: {}".format(uuid.uuid4()),
+    )
+    print(f"response1: {response1}")
+    print(f"response2: {response2}")
+    print(f"response3: {response3}")
+    litellm.cache = None
+    litellm.success_callback = []
+    litellm._async_success_callback = []
+    if (
+        response3["choices"][0]["message"]["content"]
+        == response2["choices"][0]["message"]["content"]
+    ):
+        # if models are different, it should not return cached response
+        print(f"response2: {response2}")
+        print(f"response3: {response3}")
+        pytest.fail(f"Error occurred:")
+    if (
+        response1["choices"][0]["message"]["content"]
+        != response2["choices"][0]["message"]["content"]
+    ):
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        pytest.fail(f"Error occurred:")
+    litellm.enable_caching_on_optional_params = False
+
+
 embedding_large_text = (
     """
 small text
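The same behavior can also be checked at the cache-key level; a small sketch, assuming Cache.get_cache_key accepts the call kwargs as shown in the hunks above (names and params here are illustrative):

    import litellm
    from litellm.caching import Cache

    litellm.enable_caching_on_optional_params = True
    cache = Cache()

    messages = [{"role": "user", "content": "hello"}]
    key_a = cache.get_cache_key(model="gpt-3.5-turbo", messages=messages, top_k=10)
    key_b = cache.get_cache_key(model="gpt-3.5-turbo", messages=messages, top_k=9)
    key_c = cache.get_cache_key(model="gpt-3.5-turbo", messages=messages, top_k=10)

    assert key_a == key_c  # identical params hash to the same key
    assert key_a != key_b  # top_k is now part of the key, so it differs
    litellm.enable_caching_on_optional_params = False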
@@ -1347,7 +1420,7 @@ def test_get_cache_key():
                "litellm_logging_obj": {},
            }
        )
-    cache_key_str = "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
+    cache_key_str = "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]max_tokens: 40temperature: 0.2stream: True"
     hash_object = hashlib.sha256(cache_key_str.encode())
     # Hexadecimal representation of the hash
     hash_hex = hash_object.hexdigest()
@@ -1052,6 +1052,68 @@ class ResponseFormatChunk(TypedDict, total=False):
     response_schema: dict


+all_litellm_params = [
+    "metadata",
+    "tags",
+    "acompletion",
+    "atext_completion",
+    "text_completion",
+    "caching",
+    "mock_response",
+    "api_key",
+    "api_version",
+    "api_base",
+    "force_timeout",
+    "logger_fn",
+    "verbose",
+    "custom_llm_provider",
+    "litellm_logging_obj",
+    "litellm_call_id",
+    "use_client",
+    "id",
+    "fallbacks",
+    "azure",
+    "headers",
+    "model_list",
+    "num_retries",
+    "context_window_fallback_dict",
+    "retry_policy",
+    "roles",
+    "final_prompt_value",
+    "bos_token",
+    "eos_token",
+    "request_timeout",
+    "complete_response",
+    "self",
+    "client",
+    "rpm",
+    "tpm",
+    "max_parallel_requests",
+    "input_cost_per_token",
+    "output_cost_per_token",
+    "input_cost_per_second",
+    "output_cost_per_second",
+    "hf_model_name",
+    "model_info",
+    "proxy_server_request",
+    "preset_cache_key",
+    "caching_groups",
+    "ttl",
+    "cache",
+    "no-log",
+    "base_model",
+    "stream_timeout",
+    "supports_system_message",
+    "region_name",
+    "allowed_model_region",
+    "model_config",
+    "fastest_response",
+    "cooldown_time",
+    "cache_key",
+    "max_retries",
+]
+
+
 class LoggedLiteLLMParams(TypedDict, total=False):
     force_timeout: Optional[float]
     custom_llm_provider: Optional[str]
@@ -1084,7 +1084,7 @@ def client(original_function):
                    and str(original_function.__name__)
                    in litellm.cache.supported_call_types
                ):
-                    print_verbose(f"Checking Cache")
+                    print_verbose("Checking Cache")
                    if call_type == CallTypes.aembedding.value and isinstance(
                        kwargs["input"], list
                    ):