forked from phoenix/litellm-mirror
add caching with chromaDB - not a dependency
This commit is contained in:
parent 09fcd88799
commit d80f847fde

6 changed files with 113 additions and 2 deletions
.circleci/config.yml
@@ -12,10 +12,11 @@ jobs:
          name: Install Dependencies
          command: |
            python -m pip install --upgrade pip
            python -m pip install -r requirements.txt
            python -m pip install -r .circleci/requirements.txt
            pip install infisical
            pip install pytest
            pip install openai[datalib]
            pip install chromadb

      # Run pytest and generate JUnit XML report
      - run:
.circleci/requirements.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
# used by CI/CD testing
openai
python-dotenv
openai
tiktoken
litellm/__init__.py
@@ -104,6 +104,13 @@ model_list = open_ai_chat_completion_models + open_ai_text_completion_models + c
open_ai_embedding_models = [
    'text-embedding-ada-002'
]

####### Caching #####################
cache = False # don't cache by default
cache_collection = None
cache_similarity_threshold = 1.0 # don't cache by default


from .timeout import timeout
from .utils import client, logging, exception_type, get_optional_params, modify_integration, token_counter, cost_per_token, completion_cost
from .main import * # Import all the symbols from main.py
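The three module-level globals above are the public switch for the new cache. For context, enabling it from application code would look roughly like this; a minimal sketch, not part of the diff, and the model name and threshold value are illustrative:

    import litellm
    from litellm import completion

    litellm.cache = True                      # turn the chromaDB-backed cache on
    litellm.cache_similarity_threshold = 0.9  # accept hits with similarity >= 0.9; the default 1.0 effectively requires an exact match

    messages = [{"role": "user", "content": "What's the capital of France?"}]
    first = completion(model="gpt-3.5-turbo", messages=messages)   # calls the provider, then writes to the cache
    second = completion(model="gpt-3.5-turbo", messages=messages)  # similar enough -> served from the cache

Note that on a hit, get_cache() returns the stored chromaDB metadata dict ({"model_response": "..."}) rather than a regular completion object, so callers that depend on the usual response shape need to account for that.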
litellm/main.py
@@ -7,7 +7,7 @@ import litellm
from litellm import client, logging, exception_type, timeout, get_optional_params
import tiktoken
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, add_cache, get_cache
####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
new_response = {
@@ -48,6 +48,10 @@ def completion(
):
    try:
        global new_response
        if litellm.cache:
            cache_result = get_cache(messages)
            if cache_result != None:
                return cache_result
        model_response = deepcopy(new_response) # deep copy the default response format so we can mutate it and it's thread-safe.
        # check if user passed in any of the OpenAI optional params
        optional_params = get_optional_params(
@@ -405,6 +409,8 @@ def completion(
            logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
            args = locals()
            raise ValueError(f"No valid completion model args passed in - {args}")
        if litellm.cache:
            add_cache(messages, response)
        return response
    except Exception as e:
        ## LOGGING
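Taken together, the two completion() hunks form a read-through cache: get_cache() is consulted before any provider call, and add_cache() records the fresh response just before it is returned. A condensed sketch of that control flow, with cached_completion and call_provider as illustrative stand-ins rather than names from the commit:

    import litellm
    from litellm.utils import add_cache, get_cache

    def cached_completion(messages, call_provider):
        # Read path: short-circuit if a sufficiently similar prompt was seen before.
        if litellm.cache:
            cached = get_cache(messages)
            if cached is not None:
                return cached
        # Miss: do the real work, then write back so the next similar prompt can hit.
        response = call_provider(messages)
        if litellm.cache:
            add_cache(messages, response)
        return response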
litellm/tests/test_cache.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion

# set cache to True
litellm.cache = True
litellm.cache_similarity_threshold = 0.5

user_message = "Hello, whats the weather in San Francisco??"
messages = [{ "content": user_message,"role": "user"}]

def test_completion_gpt():
    try:
        # in this test make the same call twice, measure the response time
        # the 2nd response time should be less than half of the first, ensuring that the cache is working
        import time
        start = time.time()
        response = completion(model="gpt-4", messages=messages)
        end = time.time()
        first_call_time = end-start
        print(f"first call: {first_call_time}")

        start = time.time()
        response = completion(model="gpt-4", messages=messages)
        end = time.time()
        second_call_time = end-start
        print(f"second call: {second_call_time}")

        if second_call_time > first_call_time/2:
            # the 2nd call should be less than half of the first call
            pytest.fail(f"Cache is not working")
        # Add any assertions here to check the response
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
litellm/utils.py
@@ -702,3 +702,52 @@ class CustomStreamWrapper:
            completion_obj["content"] = chunk.text
        # return this for all models
        return {"choices": [{"delta": completion_obj}]}


############# Caching Implementation v0 using chromaDB ############################
cache_collection = None
def make_collection():
    global cache_collection
    import chromadb
    client = chromadb.Client()
    cache_collection = client.create_collection("llm_responses")

def message_to_user_question(messages):
    user_question = ""
    for message in messages:
        if message['role'] == 'user':
            user_question += message["content"]
    return user_question


def add_cache(messages, model_response):
    global cache_collection
    user_question = message_to_user_question(messages)
    cache_collection.add(
        documents=[user_question],
        metadatas=[{"model_response": str(model_response)}],
        ids = [ str(uuid.uuid4())]
    )
    return

def get_cache(messages):
    try:
        global cache_collection
        if cache_collection == None:
            make_collection()
        user_question = message_to_user_question(messages)
        results = cache_collection.query(
            query_texts=[user_question],
            n_results=1
        )
        distance = results['distances'][0][0]
        sim = (1 - distance)
        if sim >= litellm.cache_similarity_threshold:
            # return response
            print("got cache hit!")
            return dict(results['metadatas'][0][0])
        else:
            # no hit
            return None
    except:
        return None
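get_cache() treats chromaDB's raw query distance as a dissimilarity and converts it with sim = (1 - distance) before comparing against litellm.cache_similarity_threshold; also note that add_cache() calls uuid.uuid4(), and the corresponding uuid import is not visible in this hunk. A standalone round trip with chromadb that mirrors the add/query logic above; a sketch only, where the collection name, stored response text, and 0.5 threshold are illustrative, and chromadb must be installed (pip install chromadb):

    import uuid
    import chromadb

    client = chromadb.Client()                       # in-memory client, nothing persisted
    collection = client.create_collection("demo_llm_responses")

    # Write path: store the user question as the document, the response as metadata.
    question = "Hello, whats the weather in San Francisco??"
    collection.add(
        documents=[question],
        metadatas=[{"model_response": "It's foggy, around 58F."}],
        ids=[str(uuid.uuid4())],
    )

    # Read path: query with a paraphrase and convert the distance to a similarity,
    # the same heuristic the commit uses.
    results = collection.query(query_texts=["weather in San Francisco?"], n_results=1)
    distance = results["distances"][0][0]
    similarity = 1 - distance
    if similarity >= 0.5:                            # analogous to litellm.cache_similarity_threshold
        print("cache hit:", results["metadatas"][0][0]["model_response"])
    else:
        print("cache miss")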