import chromadb

# Global cache collection instance; stays None until make_collection() runs.
cache_collection = None


# Initialize the cache collection
def make_collection():
    """Create (or re-attach to) the ChromaDB collection that backs the cache.

    The result is stored in the module-level ``cache_collection`` global;
    nothing is returned.

    Two details matter for the rest of the notebook:

    * ``get_or_create_collection`` is used instead of ``create_collection`` so
      that re-running this cell (which resets ``cache_collection`` to ``None``)
      followed by another call does not fail because "llm_responses" already
      exists on the client.
    * The collection is configured for cosine distance. The lookup code turns
      a distance into a similarity with ``1 - distance``, which is only a
      meaningful similarity score when the distance is a cosine distance.
    """
    global cache_collection
    client = chromadb.Client()
    cache_collection = client.get_or_create_collection(
        name="llm_responses",
        metadata={"hnsw:space": "cosine"},
    )
import uuid


# Add a response to the cache
def add_cache(messages, model_response):
    """Store an LLM response in the cache, keyed by the user's question.

    Args:
        messages: Conversation in the chatGPT messages format, i.e. a list of
            dicts with ``"role"`` and ``"content"`` keys.
        model_response: Response returned by the LLM. It is stored as
            ``str(model_response)`` in the entry's metadata, so a later cache
            hit yields a string rather than the original response object.

    Returns:
        None. The entry is written to the global ``cache_collection``, which
        is created on first use via ``make_collection()``.
    """
    global cache_collection
    if cache_collection is None:
        make_collection()

    user_question = message_to_user_question(messages)

    # The question text is the document ChromaDB embeds; the response rides
    # along as metadata. A random UUID keeps every entry's id unique.
    cache_collection.add(
        documents=[user_question],
        metadatas=[{"model_response": str(model_response)}],
        ids=[str(uuid.uuid4())],
    )


# HELPER: Extract user's question from messages
def message_to_user_question(messages):
    """Collect the text of every ``user`` message into a single string.

    Args:
        messages: Conversation in the chatGPT messages format.

    Returns:
        The contents of all messages whose role is ``"user"``, in order,
        joined with newlines so that consecutive turns are not fused into one
        word. Messages without a role, or whose content is missing or not a
        string, are skipped. Returns ``""`` when there is no user text.
    """
    user_turns = [
        message["content"]
        for message in messages
        if message.get("role") == "user" and isinstance(message.get("content"), str)
    ]
    return "\n".join(user_turns)
# Retrieve a response from the cache if similarity is above the threshold
def get_cache(messages, similarity_threshold):
    """Look up a cached response for the user's question.

    Args:
        messages: Conversation in the chatGPT messages format.
        similarity_threshold: Minimum similarity required for a cache hit.
            0 accepts any stored entry; higher values demand a closer match.

    Returns:
        The stored ``model_response`` value of the closest entry when its
        similarity is at least ``similarity_threshold``; otherwise ``None``
        (including when the cache is empty).

    Raises:
        Exception: Any error raised while querying the collection is printed
            and then re-raised unchanged.
    """
    try:
        global cache_collection
        if cache_collection is None:
            make_collection()

        # Nothing stored yet: there can be no hit, so skip the query entirely.
        if cache_collection.count() == 0:
            return None

        user_question = message_to_user_question(messages)

        # Query the cache for the single closest stored question
        results = cache_collection.query(
            query_texts=[user_question],
            n_results=1,
        )

        distances = results["distances"][0]
        if not distances:
            return None  # Cache is empty

        # Similarity is derived as 1 - distance. The distance can exceed 1
        # (it depends on the distance space the collection was created with),
        # which would give a negative score; clamp at 0 so a threshold of 0
        # really does accept every entry.
        similarity = max(0.0, 1 - distances[0])

        if similarity >= similarity_threshold:
            return results["metadatas"][0][0]["model_response"]  # Return cached response
        return None  # No cache hit
    except Exception as e:
        print("Error in get cache", e)
        raise  # bare raise keeps the original traceback intact
import os
import random
import time

import litellm
import matplotlib.pyplot as plt

# Colab form fields: paste your keys here. A blank field leaves any key that is
# already present in the environment untouched instead of overwriting it with "".
OPENAI_API_KEY = ""  # @param
REPLICATE_API_TOKEN = ""  # @param
if OPENAI_API_KEY:
    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
if REPLICATE_API_TOKEN:
    os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN

models = ["gpt-4", "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"]


def completion_with_cache(messages, similarity_threshold):
    """Answer ``messages`` from the cache when possible, else call an LLM.

    Args:
        messages: Conversation in the chatGPT messages format.
        similarity_threshold: Minimum similarity passed to ``get_cache`` for a
            stored answer to count as a hit.

    Returns:
        On a cache hit, the cached value returned by ``get_cache`` (the
        response as it was stored by ``add_cache``, i.e. a string). On a miss,
        the response object returned by ``litellm.completion`` for a randomly
        chosen entry of ``models``; that response is added to the cache before
        being returned.
    """
    # check cache before calling model, return if there is a hit
    cache_result = get_cache(messages, similarity_threshold)
    if cache_result is not None:
        return cache_result

    # randomly pick one of the configured models (llama2, GPT-4)
    model = random.choice(models)

    # use litellm to make completion request
    print(f"using model {model}")
    model_response = litellm.completion(model=model, messages=messages)

    # add the user question + model response to cache
    add_cache(messages, model_response)

    return model_response


# ---- Testing + Running Cache ----

# List of example user messages.
# Every entry needs its trailing comma: adjacent string literals without one
# are silently concatenated into a single prompt.
user_messages = [
    "Hello, what's the weather in San Francisco??",
    "what's the weather in San Francisco??",
    "Can you tell me about the latest news?",
    "What's the capital of France?",
    "How does photosynthesis work?",
    "capital of france?",
    "tell me a joke",
    "tell me a joke right now",
    "How do I bake a chocolate cake?",
    "What are the benefits of exercise?",
    "Tell me a joke!",
    # Add more questions here
]

similarity_threshold = 0.5  # Adjust as needed

# Heuristic: a response that arrives faster than this many seconds is counted
# as served from the cache rather than from a model API call.
CACHE_HIT_MAX_SECONDS = 1

### Testing / Measuring
cached_responses = 0
model_responses = 0

for user_message in user_messages:
    messages = [{"content": user_message, "role": "user"}]

    # perf_counter is a monotonic clock, which suits interval measurement
    start = time.perf_counter()
    response = completion_with_cache(messages=messages, similarity_threshold=similarity_threshold)
    response_time = time.perf_counter() - start

    if response_time < CACHE_HIT_MAX_SECONDS:
        cached_responses += 1
    else:
        model_responses += 1
    print(f"got response for {user_message}")

# Plotting
response_types = ["Cached", "Model"]
response_counts = [cached_responses, model_responses]

fig, ax = plt.subplots()
ax.bar(response_types, response_counts)
ax.set_ylabel("Number of Responses")
ax.set_title("Cached vs Model API Responses")
plt.show()

print(f"Cached Responses: {cached_responses}")
print(f"Model Responses: {model_responses}")
replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1\n", + "got response for tell me a joke\n", + "using model replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1\n", + "got response for tell me a joke right nowHow do I bake a chocolate cake?\n", + "using model gpt-4\n", + "got response for What are the benefits of exercise?\n", + "got response for Tell me a joke!\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cached Responses: 3\n", + "Model Responses: 6\n" + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/litellm/tests/test_cache.py b/litellm/tests/test_cache.py index 05004d1a6..d97d19142 100644 --- a/litellm/tests/test_cache.py +++ b/litellm/tests/test_cache.py @@ -1,44 +1,44 @@ -import sys, os -import traceback -from dotenv import load_dotenv -load_dotenv() -import os +# import sys, os +# import traceback +# from dotenv import load_dotenv +# load_dotenv() +# import os -sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path -import pytest -import litellm +# sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path +# import pytest +# import litellm -# set cache to True -litellm.cache = True -litellm.cache_similarity_threshold = 0.5 +# # set cache to True +# litellm.cache = True +# litellm.cache_similarity_threshold = 0.5 -user_message = "Hello, whats the weather in San Francisco??" -messages = [{ "content": user_message,"role": "user"}] +# user_message = "Hello, whats the weather in San Francisco??" 
+# messages = [{ "content": user_message,"role": "user"}] -def test_completion_with_cache_gpt4(): - try: - # in this test make the same call twice, measure the response time - # the 2nd response time should be less than half of the first, ensuring that the cache is working - import time - start = time.time() - print(litellm.cache) - response = litellm.completion(model="gpt-4", messages=messages) - end = time.time() - first_call_time = end-start - print(f"first call: {first_call_time}") +# def test_completion_with_cache_gpt4(): +# try: +# # in this test make the same call twice, measure the response time +# # the 2nd response time should be less than half of the first, ensuring that the cache is working +# import time +# start = time.time() +# print(litellm.cache) +# response = litellm.completion(model="gpt-4", messages=messages) +# end = time.time() +# first_call_time = end-start +# print(f"first call: {first_call_time}") - start = time.time() - response = litellm.completion(model="gpt-4", messages=messages) - end = time.time() - second_call_time = end-start - print(f"second call: {second_call_time}") +# start = time.time() +# response = litellm.completion(model="gpt-4", messages=messages) +# end = time.time() +# second_call_time = end-start +# print(f"second call: {second_call_time}") - if second_call_time > 1: - # the 2nd call should be less than 1s - pytest.fail(f"Cache is not working") - # Add any assertions here to check the response - print(response) - except Exception as e: - pytest.fail(f"Error occurred: {e}") +# if second_call_time > 1: +# # the 2nd call should be less than 1s +# pytest.fail(f"Cache is not working") +# # Add any assertions here to check the response +# print(response) +# except Exception as e: +# pytest.fail(f"Error occurred: {e}") -litellm.cache = False \ No newline at end of file +# litellm.cache = False \ No newline at end of file