caching with model names

This commit is contained in:
ishaan-jaff 2023-08-18 13:48:32 -07:00
parent 694a8ad90c
commit d0ba3ba2e5
4 changed files with 44 additions and 9 deletions

View file

@ -20,6 +20,7 @@ vertex_location: Optional[str] = None
hugging_api_token: Optional[str] = None hugging_api_token: Optional[str] = None
togetherai_api_key: Optional[str] = None togetherai_api_key: Optional[str] = None
caching = False caching = False
caching_with_models = False # if you want the caching key to be model + prompt
model_cost = { model_cost = {
"gpt-3.5-turbo": { "gpt-3.5-turbo": {
"max_tokens": 4000, "max_tokens": 4000,

View file

@ -50,4 +50,4 @@ try:
except: except:
print(f"error occurred: {traceback.format_exc()}") print(f"error occurred: {traceback.format_exc()}")
pass pass
os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5 os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5

View file

@ -12,13 +12,13 @@ import pytest
import litellm import litellm
from litellm import embedding, completion from litellm import embedding, completion
litellm.caching = True
messages = [{"role": "user", "content": "who is ishaan Github? "}] messages = [{"role": "user", "content": "who is ishaan Github? "}]
# test if response cached # test if response cached
def test_caching(): def test_caching():
try: try:
litellm.caching = True
response1 = completion(model="gpt-3.5-turbo", messages=messages) response1 = completion(model="gpt-3.5-turbo", messages=messages)
response2 = completion(model="gpt-3.5-turbo", messages=messages) response2 = completion(model="gpt-3.5-turbo", messages=messages)
print(f"response1: {response1}") print(f"response1: {response1}")
@ -32,3 +32,21 @@ def test_caching():
litellm.caching = False litellm.caching = False
print(f"error occurred: {traceback.format_exc()}") print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_caching_with_models():
litellm.caching_with_models = True
response2 = completion(model="gpt-3.5-turbo", messages=messages)
response3 = completion(model="command-nightly", messages=messages)
print(f"response2: {response2}")
print(f"response3: {response3}")
litellm.caching_with_models = False
if response3 == response2:
# if models are different, it should not return cached response
print(f"response2: {response2}")
print(f"response3: {response3}")
pytest.fail(f"Error occurred: {e}")

View file

@ -263,8 +263,17 @@ def client(original_function):
if ( if (
prompt != None and prompt in local_cache prompt != None and prompt in local_cache
): # check if messages / prompt exists ): # check if messages / prompt exists
result = local_cache[prompt] if litellm.caching_with_models:
return result # if caching with model names is enabled, key is prompt + model name
if (
"model" in kwargs
and kwargs["model"] in local_cache[prompt]["models"]
):
cache_key = prompt + kwargs["model"]
return local_cache[cache_key]
else: # caching only with prompts
result = local_cache[prompt]
return result
else: else:
return None return None
except: except:
@ -273,7 +282,15 @@ def client(original_function):
def add_cache(result, *args, **kwargs): def add_cache(result, *args, **kwargs):
try: # never block execution try: # never block execution
prompt = get_prompt(*args, **kwargs) prompt = get_prompt(*args, **kwargs)
local_cache[prompt] = result if litellm.caching_with_models: # caching with model + prompt
if (
"model" in kwargs
and kwargs["model"] in local_cache[prompt]["models"]
):
cache_key = prompt + kwargs["model"]
local_cache[cache_key] = result
else: # caching based only on prompts
local_cache[prompt] = result
except: except:
pass pass
@ -284,10 +301,9 @@ def client(original_function):
function_setup(*args, **kwargs) function_setup(*args, **kwargs)
## MODEL CALL ## MODEL CALL
start_time = datetime.datetime.now() start_time = datetime.datetime.now()
if ( if (litellm.caching or litellm.caching_with_models) and (
litellm.caching cached_result := check_cache(*args, **kwargs)
and (cached_result := check_cache(*args, **kwargs)) is not None ) is not None:
):
result = cached_result result = cached_result
else: else:
result = original_function(*args, **kwargs) result = original_function(*args, **kwargs)