forked from phoenix/litellm-mirror
add caching with chromaDB - not a dependency
parent 09fcd88799 · commit d80f847fde
6 changed files with 113 additions and 2 deletions
@@ -12,10 +12,11 @@ jobs:
       name: Install Dependencies
       command: |
         python -m pip install --upgrade pip
-        python -m pip install -r requirements.txt
+        python -m pip install -r .circleci/requirements.txt
         pip install infisical
         pip install pytest
         pip install openai[datalib]
+        pip install chromadb

   # Run pytest and generate JUnit XML report
   - run:
.circleci/requirements.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
+# used by CI/CD testing
+openai
+python-dotenv
+openai
+tiktoken
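chromadb is installed only in the CI job above and is not listed in the new requirements file or in the package's own dependencies, consistent with the commit title ("not a dependency"): the caching helpers added in the final hunk below import it lazily inside make_collection(). A minimal sketch of that lazy-import pattern; the helper name and the explicit error message are illustrative and not part of the commit:

def _load_chromadb():
    # Import chromadb only when caching is actually used, so the package
    # itself never requires it at install time.
    try:
        import chromadb  # installed separately, e.g. via the CI step above
    except ImportError as exc:  # illustrative guard, not in the commit
        raise ImportError(
            "litellm caching needs chromadb; install it with pip install chromadb"
        ) from exc
    return chromadb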
@@ -104,6 +104,13 @@ model_list = open_ai_chat_completion_models + open_ai_text_completion_models + c
 open_ai_embedding_models = [
     'text-embedding-ada-002'
 ]
+
+####### Caching #####################
+cache = False # don't cache by default
+cache_collection = None
+cache_similarity_threshold = 1.0 # don't cache by default
+
+
 from .timeout import timeout
 from .utils import client, logging, exception_type, get_optional_params, modify_integration, token_counter, cost_per_token, completion_cost
 from .main import * # Import all the symbols from main.py
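These three module-level settings are the entire opt-in surface for the feature: completion() only consults the cache when litellm.cache is truthy, and get_cache() compares similarity against litellm.cache_similarity_threshold. A minimal usage sketch mirroring the new test below (model name, prompt, and threshold are taken from that test; a configured OpenAI key is assumed):

import litellm
from litellm import completion

litellm.cache = True                      # off by default
litellm.cache_similarity_threshold = 0.5  # how similar a prompt must be to count as a hit

messages = [{"role": "user", "content": "Hello, whats the weather in San Francisco??"}]

# The first call goes to the provider and is stored in the chromaDB collection;
# a sufficiently similar second call is answered from the cache instead.
first = completion(model="gpt-4", messages=messages)
second = completion(model="gpt-4", messages=messages)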
@@ -7,7 +7,7 @@ import litellm
 from litellm import client, logging, exception_type, timeout, get_optional_params
 import tiktoken
 encoding = tiktoken.get_encoding("cl100k_base")
-from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
+from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, add_cache, get_cache
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
 new_response = {
@@ -48,6 +48,10 @@ def completion(
 ):
     try:
         global new_response
+        if litellm.cache:
+            cache_result = get_cache(messages)
+            if cache_result != None:
+                return cache_result
         model_response = deepcopy(new_response) # deep copy the default response format so we can mutate it and it's thread-safe.
         # check if user passed in any of the OpenAI optional params
         optional_params = get_optional_params(
@@ -405,6 +409,8 @@ def completion(
             logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
             args = locals()
             raise ValueError(f"No valid completion model args passed in - {args}")
+        if litellm.cache:
+            add_cache(messages, response)
         return response
     except Exception as e:
         ## LOGGING
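Taken together, the completion() hunks wrap the provider call in a read-before / write-after pair. A condensed sketch of that flow under this commit, with the provider dispatch reduced to a caller-supplied placeholder (cached_call is a hypothetical name, not part of the commit; the commit itself writes cache_result != None):

import litellm
from litellm.utils import add_cache, get_cache  # helpers added by this commit

def cached_call(messages, provider_call):
    # Read side: short-circuit if a similar prompt is already stored.
    if litellm.cache:
        cache_result = get_cache(messages)
        if cache_result is not None:
            return cache_result
    # In the real completion() this is the per-provider dispatch.
    response = provider_call(messages)
    # Write side: store the fresh response before returning it.
    if litellm.cache:
        add_cache(messages, response)
    return response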
litellm/tests/test_cache.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+import sys, os
+import traceback
+from dotenv import load_dotenv
+load_dotenv()
+import os
+sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+import pytest
+import litellm
+from litellm import embedding, completion
+
+# set cache to True
+litellm.cache = True
+litellm.cache_similarity_threshold = 0.5
+
+user_message = "Hello, whats the weather in San Francisco??"
+messages = [{ "content": user_message,"role": "user"}]
+
+def test_completion_gpt():
+    try:
+        # in this test make the same call twice, measure the response time
+        # the 2nd response time should be less than half of the first, ensuring that the cache is working
+        import time
+        start = time.time()
+        response = completion(model="gpt-4", messages=messages)
+        end = time.time()
+        first_call_time = end-start
+        print(f"first call: {first_call_time}")
+
+        start = time.time()
+        response = completion(model="gpt-4", messages=messages)
+        end = time.time()
+        second_call_time = end-start
+        print(f"second call: {second_call_time}")
+
+        if second_call_time > first_call_time/2:
+            # the 2nd call should be less than half of the first call
+            pytest.fail(f"Cache is not working")
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
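One detail worth noting when reading the test: on a cache hit, completion() returns whatever get_cache() found, which (per the final hunk below) is the stored metadata dict containing the stringified original response rather than the usual response object. A sketch of the value the second call would receive on a hit; the payload is illustrative:

# Shape of a cache-hit return value (contents shortened for illustration):
cached = {
    "model_response": "{'choices': [{'message': {'role': 'assistant', 'content': '...'}}]}"
}
# add_cache() stores str(model_response), so the cached value is a string,
# not a parsed response object.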
@@ -702,3 +702,52 @@ class CustomStreamWrapper:
                 completion_obj["content"] = chunk.text
             # return this for all models
             return {"choices": [{"delta": completion_obj}]}
+
+
+############# Caching Implementation v0 using chromaDB ############################
+cache_collection = None
+def make_collection():
+    global cache_collection
+    import chromadb
+    client = chromadb.Client()
+    cache_collection = client.create_collection("llm_responses")
+
+def message_to_user_question(messages):
+    user_question = ""
+    for message in messages:
+        if message['role'] == 'user':
+            user_question += message["content"]
+    return user_question
+
+
+def add_cache(messages, model_response):
+    global cache_collection
+    user_question = message_to_user_question(messages)
+    cache_collection.add(
+        documents=[user_question],
+        metadatas=[{"model_response": str(model_response)}],
+        ids = [ str(uuid.uuid4())]
+    )
+    return
+
+def get_cache(messages):
+    try:
+        global cache_collection
+        if cache_collection == None:
+            make_collection()
+        user_question = message_to_user_question(messages)
+        results = cache_collection.query(
+            query_texts=[user_question],
+            n_results=1
+        )
+        distance = results['distances'][0][0]
+        sim = (1 - distance)
+        if sim >= litellm.cache_similarity_threshold:
+            # return response
+            print("got cache hit!")
+            return dict(results['metadatas'][0][0])
+        else:
+            # no hit
+            return None
+    except:
+        return None
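For reference, a minimal, self-contained sketch of the add/query round trip these helpers perform, runnable on its own once chromadb is installed. The collection name and payload are illustrative; chromadb's default embedding function and distance metric are used, as in the commit, and sim = 1 - distance is the same similarity score get_cache() computes:

import uuid
import chromadb

client = chromadb.Client()                                    # in-memory client, as in make_collection()
collection = client.create_collection("llm_responses_demo")   # illustrative name

# Store a prompt with its (stringified) response, mirroring add_cache().
collection.add(
    documents=["Hello, whats the weather in San Francisco??"],
    metadatas=[{"model_response": "it's sunny"}],              # illustrative payload
    ids=[str(uuid.uuid4())],
)

# Look up the nearest stored prompt, mirroring get_cache().
results = collection.query(
    query_texts=["Hello, whats the weather in San Francisco??"],
    n_results=1,
)
distance = results["distances"][0][0]
sim = 1 - distance      # the commit treats this as a similarity score
if sim >= 0.5:          # compare against litellm.cache_similarity_threshold
    print(results["metadatas"][0][0]["model_response"])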