fix liteLLM proxy

ishaan-jaff 2023-08-16 18:30:08 -07:00
parent ce4ec195a3
commit bdfcff3078
3 changed files with 149 additions and 100 deletions


@@ -1,12 +1,20 @@
from flask import Flask, request, jsonify, abort
from flask import Flask, request, jsonify, abort, Response
from flask_cors import CORS
import traceback
import litellm
from litellm import completion
import openai
from utils import handle_error, get_cache, add_cache
import os, dotenv
import logging
import json
dotenv.load_dotenv()
# TODO: set your keys in .env or here:
# os.environ["OPENAI_API_KEY"] = "" # set your openai key here
# see supported models / keys here: https://litellm.readthedocs.io/en/latest/supported/
######### LOGGING ###################
# log your data to slack, supabase
litellm.success_callback=["slack", "supabase"] # set .env SLACK_API_TOKEN, SLACK_API_SECRET, SLACK_API_CHANNEL, SUPABASE
@@ -22,16 +30,25 @@ CORS(app)
def index():
    return 'received!', 200

def data_generator(response):
    for chunk in response:
        yield f"data: {json.dumps(chunk)}\n\n"

@app.route('/chat/completions', methods=["POST"])
def api_completion():
    data = request.json
    if data.get('stream') == "True":
        data['stream'] = True # convert to boolean
    try:
        # pass in data to completion function, unpack data
        response = completion(**data)
        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
            return Response(data_generator(response), mimetype='text/event-stream')
    except Exception as e:
        # call handle_error function
        print(f"got error: {e}")
        return handle_error(data)
    return response, 200
    return response, 200 # non streaming responses
@app.route('/get_models', methods=["POST"])
def get_models():
@@ -48,45 +65,11 @@ if __name__ == "__main__":
############### Advanced ##########################
################ ERROR HANDLING #####################
# implement model fallbacks, cooldowns, and retries
# if a model fails assume it was rate limited and let it cooldown for 60s
def handle_error(data):
    import time
    # retry completion() request with fallback models
    response = None
    start_time = time.time()
    rate_limited_models = set()
    model_expiration_times = {}
    fallback_strategy=['gpt-3.5-turbo', 'command-nightly', 'claude-2']
    while response == None and time.time() - start_time < 45: # retry for 45s
        for model in fallback_strategy:
            try:
                if model in rate_limited_models: # check if model is currently cooling down
                    if model_expiration_times.get(model) and time.time() >= model_expiration_times[model]:
                        rate_limited_models.remove(model) # check if it's been 60s of cool down and remove model
                    else:
                        continue # skip model
                print(f"calling model {model}")
                response = completion(**data)
                if response != None:
                    return response
            except Exception as e:
                rate_limited_models.add(model)
                model_expiration_times[model] = time.time() + 60 # cool down this selected model
                pass
    return response
########### Pricing is tracked in Supabase ############
############ Caching ###################################
# make a new endpoint with caching
# This Cache is built using ChromaDB
# it has two functions add_cache() and get_cache()
@app.route('/chat/completions', methods=["POST"])
@app.route('/chat/completions_with_cache', methods=["POST"])
def api_completion_with_cache():
    data = request.json
    try:
@@ -100,66 +83,4 @@ def api_completion_with_cache():
    except Exception as e:
        # call handle_error function
        return handle_error(data)
    return response, 200
import uuid

cache_collection = None

# Add a response to the cache
def add_cache(messages, model_response):
    global cache_collection
    if cache_collection is None:
        make_collection()
    user_question = message_to_user_question(messages)
    # Add the user question and model response to the cache
    cache_collection.add(
        documents=[user_question],
        metadatas=[{"model_response": str(model_response)}],
        ids=[str(uuid.uuid4())]
    )
    return

# Retrieve a response from the cache if similarity is above the threshold
def get_cache(messages, similarity_threshold):
    try:
        global cache_collection
        if cache_collection is None:
            make_collection()
        user_question = message_to_user_question(messages)
        # Query the cache for the user question
        results = cache_collection.query(
            query_texts=[user_question],
            n_results=1
        )
        if len(results['distances'][0]) == 0:
            return None # Cache is empty
        distance = results['distances'][0][0]
        sim = (1 - distance)
        if sim >= similarity_threshold:
            return results['metadatas'][0][0]["model_response"] # Return cached response
        else:
            return None # No cache hit
    except Exception as e:
        print("Error in get cache", e)
        raise e

# Initialize the cache collection
def make_collection():
    import chromadb
    global cache_collection
    client = chromadb.Client()
    cache_collection = client.create_collection("llm_responses")

# HELPER: Extract user's question from messages
def message_to_user_question(messages):
    user_question = ""
    for message in messages:
        if message['role'] == 'user':
            user_question += message["content"]
    return user_question
    return response, 200
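For reference, a minimal sketch of calling this proxy from outside the repo. It assumes the server is running locally on port 5000 (the address the test script below points at) and that the requests package is installed; the payload keys simply mirror the keyword arguments forwarded to completion(**data).

import json
import requests

BASE_URL = "http://localhost:5000"  # assumed local proxy address

# Non-streaming request: the proxy unpacks this JSON straight into completion(**data)
resp = requests.post(
    f"{BASE_URL}/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello from the proxy client"}],
    },
)
print(resp.status_code, resp.text)

# Streaming request: the proxy answers with a text/event-stream of "data: {...}" lines
with requests.post(
    f"{BASE_URL}/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "stream a short haiku"}],
        "stream": True,
    },
    stream=True,
) as stream_resp:
    for line in stream_resp.iter_lines():
        if line.startswith(b"data: "):
            print(json.loads(line[len(b"data: "):]))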


@@ -0,0 +1,21 @@
import openai
import os
os.environ["OPENAI_API_KEY"] = ""
openai.api_key = os.environ["OPENAI_API_KEY"]
openai.api_base = "http://localhost:5000"

messages = [
    {
        "role": "user",
        "content": "write a 1 pg essay in liteLLM"
    }
]
response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages, stream=True)
print("got response", response)
# response is a generator
for chunk in response:
    print(chunk)

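A non-streaming variant of the same check, under the same assumptions (local proxy on port 5000, placeholder OpenAI key), for comparison:

import os
import openai

openai.api_key = os.environ.get("OPENAI_API_KEY", "")
openai.api_base = "http://localhost:5000"

# without stream=True the proxy returns the whole completion in one response
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "write a 1 pg essay in liteLLM"}],
)
print("got response", response)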

@@ -0,0 +1,107 @@
from litellm import completion
import os, dotenv
import json
dotenv.load_dotenv()
############### Advanced ##########################
########### streaming ############################
def generate_responses(response):
    for chunk in response:
        yield json.dumps({"response": chunk}) + "\n"
################ ERROR HANDLING #####################
# implement model fallbacks, cooldowns, and retries
# if a model fails assume it was rate limited and let it cooldown for 60s
def handle_error(data):
    import time
    # retry completion() request with fallback models
    response = None
    start_time = time.time()
    rate_limited_models = set()
    model_expiration_times = {}
    fallback_strategy = ['gpt-3.5-turbo', 'command-nightly', 'claude-2']
    while response is None and time.time() - start_time < 45: # retry for 45s
        for model in fallback_strategy:
            try:
                if model in rate_limited_models: # check if model is currently cooling down
                    if model_expiration_times.get(model) and time.time() >= model_expiration_times[model]:
                        rate_limited_models.remove(model) # it's been 60s of cool down, remove model
                    else:
                        continue # skip model
                print(f"calling model {model}")
                data["model"] = model # switch the request to the current fallback model
                response = completion(**data)
                if response is not None:
                    return response
            except Exception as e:
                rate_limited_models.add(model)
                model_expiration_times[model] = time.time() + 60 # cool down this selected model
    return response
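A hedged usage sketch for the fallback helper above; the data payload is hypothetical and merely mirrors the shape of the JSON body the Flask routes forward, and the non-OpenAI fallback models would also need their provider keys set in the environment.

# hypothetical request payload, shaped like what the proxy receives from a client
data = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hello"}],
}
try:
    result = completion(**data)
except Exception:
    # on failure, rotate through fallback_strategy with 60-second cooldowns
    result = handle_error(data)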
########### Pricing is tracked in Supabase ############
import uuid

cache_collection = None

# Add a response to the cache
def add_cache(messages, model_response):
    global cache_collection
    if cache_collection is None:
        make_collection()
    user_question = message_to_user_question(messages)
    # Add the user question and model response to the cache
    cache_collection.add(
        documents=[user_question],
        metadatas=[{"model_response": str(model_response)}],
        ids=[str(uuid.uuid4())]
    )
    return

# Retrieve a response from the cache if similarity is above the threshold
def get_cache(messages, similarity_threshold):
    try:
        global cache_collection
        if cache_collection is None:
            make_collection()
        user_question = message_to_user_question(messages)
        # Query the cache for the user question
        results = cache_collection.query(
            query_texts=[user_question],
            n_results=1
        )
        if len(results['distances'][0]) == 0:
            return None # Cache is empty
        distance = results['distances'][0][0]
        sim = (1 - distance)
        if sim >= similarity_threshold:
            return results['metadatas'][0][0]["model_response"] # Return cached response
        else:
            return None # No cache hit
    except Exception as e:
        print("Error in get cache", e)
        raise e

# Initialize the cache collection
def make_collection():
    import chromadb
    global cache_collection
    client = chromadb.Client()
    cache_collection = client.create_collection("llm_responses")

# HELPER: Extract user's question from messages
def message_to_user_question(messages):
    user_question = ""
    for message in messages:
        if message['role'] == 'user':
            user_question += message["content"]
    return user_question
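To show how these cache helpers fit together, a hedged sketch of a cached completion flow; the cached route's body is truncated in the hunk above, so the function name cached_completion, the threshold value, and the exact wiring are illustrative assumptions rather than the commit's code.

def cached_completion(data, similarity_threshold=0.5):  # name and threshold are illustrative
    messages = data["messages"]
    # serve a cached answer when a sufficiently similar question was already asked
    cached_response = get_cache(messages, similarity_threshold)
    if cached_response is not None:
        return cached_response
    # otherwise call the model and store the new answer in the ChromaDB cache
    response = completion(**data)
    add_cache(messages, response)
    return response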