mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
caching with model names
This commit is contained in:
parent
694a8ad90c
commit
d0ba3ba2e5
4 changed files with 44 additions and 9 deletions
|
@ -20,6 +20,7 @@ vertex_location: Optional[str] = None
|
||||||
hugging_api_token: Optional[str] = None
|
hugging_api_token: Optional[str] = None
|
||||||
togetherai_api_key: Optional[str] = None
|
togetherai_api_key: Optional[str] = None
|
||||||
caching = False
|
caching = False
|
||||||
|
caching_with_models = False # if you want the caching key to be model + prompt
|
||||||
model_cost = {
|
model_cost = {
|
||||||
"gpt-3.5-turbo": {
|
"gpt-3.5-turbo": {
|
||||||
"max_tokens": 4000,
|
"max_tokens": 4000,
|
||||||
|
|
|
@ -12,13 +12,13 @@ import pytest
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion
|
from litellm import embedding, completion
|
||||||
|
|
||||||
litellm.caching = True
|
|
||||||
messages = [{"role": "user", "content": "who is ishaan Github? "}]
|
messages = [{"role": "user", "content": "who is ishaan Github? "}]
|
||||||
|
|
||||||
|
|
||||||
# test if response cached
|
# test if response cached
|
||||||
def test_caching():
|
def test_caching():
|
||||||
try:
|
try:
|
||||||
|
litellm.caching = True
|
||||||
response1 = completion(model="gpt-3.5-turbo", messages=messages)
|
response1 = completion(model="gpt-3.5-turbo", messages=messages)
|
||||||
response2 = completion(model="gpt-3.5-turbo", messages=messages)
|
response2 = completion(model="gpt-3.5-turbo", messages=messages)
|
||||||
print(f"response1: {response1}")
|
print(f"response1: {response1}")
|
||||||
|
@ -32,3 +32,21 @@ def test_caching():
|
||||||
litellm.caching = False
|
litellm.caching = False
|
||||||
print(f"error occurred: {traceback.format_exc()}")
|
print(f"error occurred: {traceback.format_exc()}")
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_caching_with_models():
|
||||||
|
litellm.caching_with_models = True
|
||||||
|
response2 = completion(model="gpt-3.5-turbo", messages=messages)
|
||||||
|
response3 = completion(model="command-nightly", messages=messages)
|
||||||
|
print(f"response2: {response2}")
|
||||||
|
print(f"response3: {response3}")
|
||||||
|
litellm.caching_with_models = False
|
||||||
|
if response3 == response2:
|
||||||
|
# if models are different, it should not return cached response
|
||||||
|
print(f"response2: {response2}")
|
||||||
|
print(f"response3: {response3}")
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -263,6 +263,15 @@ def client(original_function):
|
||||||
if (
|
if (
|
||||||
prompt != None and prompt in local_cache
|
prompt != None and prompt in local_cache
|
||||||
): # check if messages / prompt exists
|
): # check if messages / prompt exists
|
||||||
|
if litellm.caching_with_models:
|
||||||
|
# if caching with model names is enabled, key is prompt + model name
|
||||||
|
if (
|
||||||
|
"model" in kwargs
|
||||||
|
and kwargs["model"] in local_cache[prompt]["models"]
|
||||||
|
):
|
||||||
|
cache_key = prompt + kwargs["model"]
|
||||||
|
return local_cache[cache_key]
|
||||||
|
else: # caching only with prompts
|
||||||
result = local_cache[prompt]
|
result = local_cache[prompt]
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
|
@ -273,6 +282,14 @@ def client(original_function):
|
||||||
def add_cache(result, *args, **kwargs):
|
def add_cache(result, *args, **kwargs):
|
||||||
try: # never block execution
|
try: # never block execution
|
||||||
prompt = get_prompt(*args, **kwargs)
|
prompt = get_prompt(*args, **kwargs)
|
||||||
|
if litellm.caching_with_models: # caching with model + prompt
|
||||||
|
if (
|
||||||
|
"model" in kwargs
|
||||||
|
and kwargs["model"] in local_cache[prompt]["models"]
|
||||||
|
):
|
||||||
|
cache_key = prompt + kwargs["model"]
|
||||||
|
local_cache[cache_key] = result
|
||||||
|
else: # caching based only on prompts
|
||||||
local_cache[prompt] = result
|
local_cache[prompt] = result
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
@ -284,10 +301,9 @@ def client(original_function):
|
||||||
function_setup(*args, **kwargs)
|
function_setup(*args, **kwargs)
|
||||||
## MODEL CALL
|
## MODEL CALL
|
||||||
start_time = datetime.datetime.now()
|
start_time = datetime.datetime.now()
|
||||||
if (
|
if (litellm.caching or litellm.caching_with_models) and (
|
||||||
litellm.caching
|
cached_result := check_cache(*args, **kwargs)
|
||||||
and (cached_result := check_cache(*args, **kwargs)) is not None
|
) is not None:
|
||||||
):
|
|
||||||
result = cached_result
|
result = cached_result
|
||||||
else:
|
else:
|
||||||
result = original_function(*args, **kwargs)
|
result = original_function(*args, **kwargs)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue