add caching with chromaDB - not a dependency

This commit is contained in:
ishaan-jaff 2023-08-10 13:19:48 -07:00
parent 09fcd88799
commit d80f847fde
6 changed files with 113 additions and 2 deletions
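
For orientation: this diff wires an optional semantic cache into completion(). chromadb is imported lazily inside the new cache helpers in utils.py, so it only has to be installed when caching is actually switched on, hence "not a dependency". A minimal sketch of how a caller might exercise the new flags, assuming an OpenAI key is configured (the threshold value here is illustrative):

import litellm
from litellm import completion

litellm.cache = True                      # opt in; chromadb is only imported once caching is used
litellm.cache_similarity_threshold = 0.9  # illustrative: accept near-duplicate prompts as cache hits

messages = [{"role": "user", "content": "What's the weather in San Francisco?"}]
first = completion(model="gpt-4", messages=messages)   # goes to the provider, then gets cached
second = completion(model="gpt-4", messages=messages)  # should be served from the chromadb cache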

View file

@@ -12,10 +12,11 @@ jobs:
          name: Install Dependencies
          command: |
            python -m pip install --upgrade pip
            python -m pip install -r requirements.txt
            python -m pip install -r .circleci/requirements.txt
            pip install infisical
            pip install pytest
            pip install openai[datalib]
            pip install chromadb
      # Run pytest and generate JUnit XML report
      - run:

View file

@@ -0,0 +1,5 @@
# used by CI/CD testing
openai
python-dotenv
openai
tiktoken

View file

@@ -104,6 +104,13 @@ model_list = open_ai_chat_completion_models + open_ai_text_completion_models + c
open_ai_embedding_models = [
    'text-embedding-ada-002'
]
####### Caching #####################
cache = False # don't cache by default
cache_collection = None
cache_similarity_threshold = 1.0 # by default, only serve a cached response for a (near-)exact prompt match
from .timeout import timeout
from .utils import client, logging, exception_type, get_optional_params, modify_integration, token_counter, cost_per_token, completion_cost
from .main import * # Import all the symbols from main.py

View file

@@ -7,7 +7,7 @@ import litellm
from litellm import client, logging, exception_type, timeout, get_optional_params
import tiktoken
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, add_cache, get_cache
####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
new_response = {
@@ -48,6 +48,10 @@ def completion(
):
    try:
        global new_response
        if litellm.cache:
            # if caching is enabled, return a cached response for a sufficiently similar prompt
            cache_result = get_cache(messages)
            if cache_result is not None:
                return cache_result
        model_response = deepcopy(new_response) # deep copy the default response format so we can mutate it while staying thread-safe
        # check if user passed in any of the OpenAI optional params
        optional_params = get_optional_params(
@@ -405,6 +409,8 @@ def completion(
            logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
            args = locals()
            raise ValueError(f"No valid completion model args passed in - {args}")
        if litellm.cache:
            # store the fresh response so that similar future prompts can be answered from the cache
            add_cache(messages, response)
        return response
    except Exception as e:
        ## LOGGING
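
One detail worth noting from the hunks above: both the lookup (get_cache) and the write (add_cache) key on the concatenated content of user-role messages, via the message_to_user_question helper added in utils.py further down. That means two requests differing only in their system prompt map to the same cache entry. A small illustration of that keying, with the helper restated for clarity:

def message_to_user_question(messages):
    # same keying logic as the helper added in utils.py below
    return "".join(m["content"] for m in messages if m["role"] == "user")

a = [{"role": "system", "content": "Answer tersely."},
     {"role": "user", "content": "What's the weather in San Francisco?"}]
b = [{"role": "system", "content": "Answer in great detail."},
     {"role": "user", "content": "What's the weather in San Francisco?"}]

assert message_to_user_question(a) == message_to_user_question(b)  # identical cache key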

View file

@@ -0,0 +1,43 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion
# set cache to True
litellm.cache = True
litellm.cache_similarity_threshold = 0.5
user_message = "Hello, whats the weather in San Francisco??"
messages = [{ "content": user_message,"role": "user"}]
def test_completion_gpt():
    try:
        # in this test make the same call twice, measure the response time
        # the 2nd response time should be less than half of the first, ensuring that the cache is working
        import time
        start = time.time()
        response = completion(model="gpt-4", messages=messages)
        end = time.time()
        first_call_time = end - start
        print(f"first call: {first_call_time}")
        start = time.time()
        response = completion(model="gpt-4", messages=messages)
        end = time.time()
        second_call_time = end - start
        print(f"second call: {second_call_time}")
        if second_call_time > first_call_time / 2:
            # the 2nd call should be less than half of the first call
            pytest.fail("Cache is not working")
        # Add any assertions here to check the response
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

View file

@@ -702,3 +702,52 @@ class CustomStreamWrapper:
            completion_obj["content"] = chunk.text
        # return this for all models
        return {"choices": [{"delta": completion_obj}]}
############# Caching Implementation v0 using chromaDB ############################
cache_collection = None
def make_collection():
    global cache_collection
    import chromadb  # imported lazily, so chromadb stays an optional dependency
    client = chromadb.Client()
    cache_collection = client.create_collection("llm_responses")
def message_to_user_question(messages):
    # concatenate the content of all user-role messages; this string is the cache key
    user_question = ""
    for message in messages:
        if message['role'] == 'user':
            user_question += message["content"]
    return user_question
def add_cache(messages, model_response):
    global cache_collection
    import uuid  # local import; uuid may not be imported at the top of this file (not shown in the diff)
    if cache_collection is None:
        make_collection()
    user_question = message_to_user_question(messages)
    cache_collection.add(
        documents=[user_question],
        metadatas=[{"model_response": str(model_response)}],
        ids=[str(uuid.uuid4())]
    )
    return
def get_cache(messages):
    try:
        global cache_collection
        if cache_collection is None:
            make_collection()
        user_question = message_to_user_question(messages)
        results = cache_collection.query(
            query_texts=[user_question],
            n_results=1
        )
        distance = results['distances'][0][0]
        sim = (1 - distance)
        if sim >= litellm.cache_similarity_threshold:
            # similar enough: return the cached response
            print("got cache hit!")
            return dict(results['metadatas'][0][0])
        else:
            # no hit
            return None
    except Exception:
        # any failure in the cache path falls back to a normal completion call
        return None
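
A caveat on the sim = (1 - distance) conversion above: as far as I know, Chroma's default HNSW space is L2, so the distances returned by query are not cosine distances and 1 - distance is not strictly a bounded similarity score. A sketch of a variant (collection name and threshold are illustrative) that creates the collection with a cosine space, so the threshold comparison is better defined:

import uuid
import chromadb

client = chromadb.Client()
# request cosine distance so that 1 - distance really is a cosine similarity
collection = client.create_collection("llm_responses", metadata={"hnsw:space": "cosine"})

collection.add(
    documents=["Hello, what's the weather in San Francisco?"],
    metadatas=[{"model_response": "<cached completion goes here>"}],
    ids=[str(uuid.uuid4())],
)

results = collection.query(query_texts=["weather in San Francisco?"], n_results=1)
similarity = 1 - results["distances"][0][0]
if similarity >= 0.5:  # illustrative threshold
    print(results["metadatas"][0][0]["model_response"])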