From d80f847fdeb949b59a698a6c1a9d4edf00c8e979 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Thu, 10 Aug 2023 13:19:48 -0700
Subject: [PATCH] add caching with chromaDB - not a dependency

---
 .circleci/config.yml | 3 ++-
 .circleci/requirements.txt | 5 ++++
 litellm/__init__.py | 7 ++++++
 litellm/main.py | 8 +++++-
 litellm/tests/test_cache.py | 43 ++++++++++++++++++++++++++++++++
 litellm/utils.py | 49 +++++++++++++++++++++++++++++++++++++
 6 files changed, 113 insertions(+), 2 deletions(-)
 create mode 100644 .circleci/requirements.txt
 create mode 100644 litellm/tests/test_cache.py

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 397031de7..74c4b4893 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -12,10 +12,11 @@ jobs:
           name: Install Dependencies
           command: |
             python -m pip install --upgrade pip
-            python -m pip install -r requirements.txt
+            python -m pip install -r .circleci/requirements.txt
            pip install infisical
            pip install pytest
            pip install openai[datalib]
+           pip install chromadb

      # Run pytest and generate JUnit XML report
      - run:
diff --git a/.circleci/requirements.txt b/.circleci/requirements.txt
new file mode 100644
index 000000000..56f796b35
--- /dev/null
+++ b/.circleci/requirements.txt
@@ -0,0 +1,5 @@
+# used by CI/CD testing
+openai
+python-dotenv
+openai
+tiktoken
\ No newline at end of file
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 7f0299162..9a4d86a6c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -104,6 +104,13 @@ model_list = open_ai_chat_completion_models + open_ai_text_completion_models + c
 open_ai_embedding_models = [
     'text-embedding-ada-002'
 ]
+
+####### Caching #####################
+cache = False # don't cache by default
+cache_collection = None
+cache_similarity_threshold = 1.0 # with the default threshold of 1.0, only exact matches count as cache hits
+
+
 from .timeout import timeout
 from .utils import client, logging, exception_type, get_optional_params, modify_integration, token_counter, cost_per_token, completion_cost
 from .main import * # Import all the symbols from main.py
diff --git a/litellm/main.py b/litellm/main.py
index 56613c8c8..af88aa411 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -7,7 +7,7 @@ import litellm
 from litellm import client, logging, exception_type, timeout, get_optional_params
 import tiktoken
 encoding = tiktoken.get_encoding("cl100k_base")
-from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
+from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, add_cache, get_cache
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
 new_response = {
@@ -48,6 +48,10 @@ def completion(
   ):
   try:
     global new_response
+    if litellm.cache:
+      cache_result = get_cache(messages)
+      if cache_result is not None:
+        return cache_result
     model_response = deepcopy(new_response) # deep copy the default response format so we can mutate it and it's thread-safe
     # check if user passed in any of the OpenAI optional params
     optional_params = get_optional_params(
@@ -405,6 +409,8 @@ def completion(
       logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
       args = locals()
       raise ValueError(f"No valid completion model args passed in - {args}")
+    if litellm.cache:
+      add_cache(messages, response)
     return response
   except Exception as e:
     ## LOGGING
diff --git a/litellm/tests/test_cache.py b/litellm/tests/test_cache.py
new file mode 100644
index 000000000..db3aad623
--- /dev/null
+++ b/litellm/tests/test_cache.py
@@ -0,0 +1,43 @@
+import sys, os
+import traceback
+from dotenv import load_dotenv
+load_dotenv()
+import os
+sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+import pytest
+import litellm
+from litellm import embedding, completion
+
+# turn caching on and allow moderately similar prompts to count as hits
+litellm.cache = True
+litellm.cache_similarity_threshold = 0.5
+
+user_message = "Hello, what's the weather in San Francisco??"
+messages = [{"content": user_message, "role": "user"}]
+
+def test_completion_gpt():
+    try:
+        # make the same call twice and measure the response time of each;
+        # the 2nd call should take less than half as long as the 1st if the cache is working
+        import time
+        start = time.time()
+        response = completion(model="gpt-4", messages=messages)
+        end = time.time()
+        first_call_time = end - start
+        print(f"first call: {first_call_time}")
+
+        start = time.time()
+        response = completion(model="gpt-4", messages=messages)
+        end = time.time()
+        second_call_time = end - start
+        print(f"second call: {second_call_time}")
+
+        if second_call_time > first_call_time / 2:
+            # the 2nd call should have been served from the cache, i.e. much faster
+            pytest.fail("Cache is not working")
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
diff --git a/litellm/utils.py b/litellm/utils.py
index 85be206c5..58f5347b4 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -702,3 +702,52 @@ class CustomStreamWrapper:
       completion_obj["content"] = chunk.text
     # return this for all models
     return {"choices": [{"delta": completion_obj}]}
+
+
+############# Caching Implementation v0 using chromaDB ############################
+cache_collection = None
+def make_collection():
+    global cache_collection
+    import chromadb
+    client = chromadb.Client()
+    cache_collection = client.create_collection("llm_responses")
+
+def message_to_user_question(messages):
+    user_question = ""
+    for message in messages:
+        if message['role'] == 'user':
+            user_question += message["content"]
+    return user_question
+
+
+def add_cache(messages, model_response):
+    global cache_collection
+    user_question = message_to_user_question(messages)
+    cache_collection.add(
+        documents=[user_question],
+        metadatas=[{"model_response": str(model_response)}],
+        ids=[str(uuid.uuid4())]
+    )
+    return
+
+def get_cache(messages):
+    try:
+        global cache_collection
+        if cache_collection is None:
+            make_collection()
+        user_question = message_to_user_question(messages)
+        results = cache_collection.query(
+            query_texts=[user_question],
+            n_results=1
+        )
+        distance = results['distances'][0][0]
+        sim = (1 - distance)
+        if sim >= litellm.cache_similarity_threshold:
+            # close enough to a cached prompt - return the stored response
+            print("got cache hit!")
+            return dict(results['metadatas'][0][0])
+        else:
+            # no sufficiently similar entry
+            return None
+    except Exception:
+        return None
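
Usage sketch (not part of the patch): a minimal example of how the cache added here is meant to be exercised, mirroring litellm/tests/test_cache.py. It assumes chromadb is installed and OPENAI_API_KEY is set in the environment; the timing printout is illustrative only.

    import time
    import litellm
    from litellm import completion

    litellm.cache = True                      # opt in to the chromaDB-backed response cache
    litellm.cache_similarity_threshold = 0.5  # accept sufficiently similar prompts as cache hits

    messages = [{"role": "user", "content": "Hello, what's the weather in San Francisco?"}]

    start = time.time()
    completion(model="gpt-4", messages=messages)   # first call goes to the API and populates the cache
    print(f"first call: {time.time() - start:.2f}s")

    start = time.time()
    completion(model="gpt-4", messages=messages)   # second call should be answered from the cache
    print(f"second call: {time.time() - start:.2f}s")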