forked from phoenix/litellm-mirror
fix liteLLM proxy
parent ce4ec195a3
commit bdfcff3078
3 changed files with 149 additions and 100 deletions
@@ -1,12 +1,20 @@
-from flask import Flask, request, jsonify, abort
+from flask import Flask, request, jsonify, abort, Response
 from flask_cors import CORS
 import traceback
 import litellm

 from litellm import completion
+import openai
+from utils import handle_error, get_cache, add_cache
 import os, dotenv
+import logging
+import json
 dotenv.load_dotenv()
+
+# TODO: set your keys in .env or here:
+# os.environ["OPENAI_API_KEY"] = "" # set your openai key here
+# see supported models / keys here: https://litellm.readthedocs.io/en/latest/supported/

 ######### LOGGING ###################
 # log your data to slack, supabase
 litellm.success_callback=["slack", "supabase"] # set .env SLACK_API_TOKEN, SLACK_API_SECRET, SLACK_API_CHANNEL, SUPABASE
@@ -22,16 +30,25 @@ CORS(app)
 def index():
     return 'received!', 200

+def data_generator(response):
+    for chunk in response:
+        yield f"data: {json.dumps(chunk)}\n\n"
+
 @app.route('/chat/completions', methods=["POST"])
 def api_completion():
     data = request.json
+    if data.get('stream') == "True":
+        data['stream'] = True # convert to boolean
     try:
         # pass in data to completion function, unpack data
         response = completion(**data)
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return Response(data_generator(response), mimetype='text/event-stream')
     except Exception as e:
         # call handle_error function
+        print(f"got error{e}")
         return handle_error(data)
-    return response, 200
+    return response, 200 # non streaming responses

 @app.route('/get_models', methods=["POST"])
 def get_models():
@@ -48,45 +65,11 @@ if __name__ == "__main__":

 ############### Advanced ##########################

-################ ERROR HANDLING #####################
-# implement model fallbacks, cooldowns, and retries
-# if a model fails assume it was rate limited and let it cooldown for 60s
-def handle_error(data):
-    import time
-    # retry completion() request with fallback models
-    response = None
-    start_time = time.time()
-    rate_limited_models = set()
-    model_expiration_times = {}
-    fallback_strategy=['gpt-3.5-turbo', 'command-nightly', 'claude-2']
-    while response == None and time.time() - start_time < 45: # retry for 45s
-        for model in fallback_strategy:
-            try:
-                if model in rate_limited_models: # check if model is currently cooling down
-                    if model_expiration_times.get(model) and time.time() >= model_expiration_times[model]:
-                        rate_limited_models.remove(model) # check if it's been 60s of cool down and remove model
-                    else:
-                        continue # skip model
-                print(f"calling model {model}")
-                response = completion(**data)
-                if response != None:
-                    return response
-            except Exception as e:
-                rate_limited_models.add(model)
-                model_expiration_times[model] = time.time() + 60 # cool down this selected model
-                pass
-    return response
-
-
-########### Pricing is tracked in Supabase ############
-
-
-
 ############ Caching ###################################
 # make a new endpoint with caching
 # This Cache is built using ChromaDB
 # it has two functions add_cache() and get_cache()
-@app.route('/chat/completions', methods=["POST"])
+@app.route('/chat/completions_with_cache', methods=["POST"])
 def api_completion_with_cache():
     data = request.json
     try:
@@ -100,66 +83,4 @@ def api_completion_with_cache():
     except Exception as e:
         # call handle_error function
         return handle_error(data)
     return response, 200
-
-import uuid
-cache_collection = None
-# Add a response to the cache
-def add_cache(messages, model_response):
-    global cache_collection
-    if cache_collection is None:
-        make_collection()
-
-    user_question = message_to_user_question(messages)
-
-    # Add the user question and model response to the cache
-    cache_collection.add(
-        documents=[user_question],
-        metadatas=[{"model_response": str(model_response)}],
-        ids=[str(uuid.uuid4())]
-    )
-    return
-
-# Retrieve a response from the cache if similarity is above the threshold
-def get_cache(messages, similarity_threshold):
-    try:
-        global cache_collection
-        if cache_collection is None:
-            make_collection()
-
-        user_question = message_to_user_question(messages)
-
-        # Query the cache for the user question
-        results = cache_collection.query(
-            query_texts=[user_question],
-            n_results=1
-        )
-
-        if len(results['distances'][0]) == 0:
-            return None # Cache is empty
-
-        distance = results['distances'][0][0]
-        sim = (1 - distance)
-
-        if sim >= similarity_threshold:
-            return results['metadatas'][0][0]["model_response"] # Return cached response
-        else:
-            return None # No cache hit
-    except Exception as e:
-        print("Error in get cache", e)
-        raise e
-
-# Initialize the cache collection
-def make_collection():
-    import chromadb
-    global cache_collection
-    client = chromadb.Client()
-    cache_collection = client.create_collection("llm_responses")
-
-# HELPER: Extract user's question from messages
-def message_to_user_question(messages):
-    user_question = ""
-    for message in messages:
-        if message['role'] == 'user':
-            user_question += message["content"]
-    return user_question
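Note on the streaming path above: data_generator frames each chunk as a server-sent event ("data: <json>\n\n"), and the handler streams only when the request carries stream as the string "True" or a boolean true. As a rough illustration (not part of this commit), a raw-HTTP client for that endpoint could look like the sketch below, assuming the proxy is running locally on port 5000 (the same base URL the new test script uses) and using the requests library:

# sketch of a raw SSE consumer for /chat/completions -- illustrative only, not from this commit
import json
import requests

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hello"}],
    "stream": "True",  # the handler converts this string to a boolean before calling completion()
}

with requests.post("http://localhost:5000/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            chunk = json.loads(line[len(b"data: "):])  # each "data:" line carries one JSON-encoded chunk
            print(chunk)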
21  cookbook/proxy-server/test_proxy_stream.py  Normal file
@@ -0,0 +1,21 @@
+import openai
+import os
+
+os.environ["OPENAI_API_KEY"] = ""
+
+openai.api_key = os.environ["OPENAI_API_KEY"]
+openai.api_base ="http://localhost:5000"
+
+messages = [
+    {
+        "role": "user",
+        "content": "write a 1 pg essay in liteLLM"
+    }
+]
+
+response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages, stream=True)
+print("got response", response)
+# response is a generator
+
+for chunk in response:
+    print(chunk)
107  cookbook/proxy-server/utils.py  Normal file
@@ -0,0 +1,107 @@
+
+from litellm import completion
+import os, dotenv
+import json
+dotenv.load_dotenv()
+############### Advanced ##########################
+
+########### streaming ############################
+def generate_responses(response):
+    for chunk in response:
+        yield json.dumps({"response": chunk}) + "\n"
+
+################ ERROR HANDLING #####################
+# implement model fallbacks, cooldowns, and retries
+# if a model fails assume it was rate limited and let it cooldown for 60s
+def handle_error(data):
+    import time
+    # retry completion() request with fallback models
+    response = None
+    start_time = time.time()
+    rate_limited_models = set()
+    model_expiration_times = {}
+    fallback_strategy=['gpt-3.5-turbo', 'command-nightly', 'claude-2']
+    while response == None and time.time() - start_time < 45: # retry for 45s
+        for model in fallback_strategy:
+            try:
+                if model in rate_limited_models: # check if model is currently cooling down
+                    if model_expiration_times.get(model) and time.time() >= model_expiration_times[model]:
+                        rate_limited_models.remove(model) # check if it's been 60s of cool down and remove model
+                    else:
+                        continue # skip model
+                print(f"calling model {model}")
+                response = completion(**data)
+                if response != None:
+                    return response
+            except Exception as e:
+                rate_limited_models.add(model)
+                model_expiration_times[model] = time.time() + 60 # cool down this selected model
+                pass
+    return response
+
+
+########### Pricing is tracked in Supabase ############
+
+
+
+import uuid
+cache_collection = None
+# Add a response to the cache
+def add_cache(messages, model_response):
+    global cache_collection
+    if cache_collection is None:
+        make_collection()
+
+    user_question = message_to_user_question(messages)
+
+    # Add the user question and model response to the cache
+    cache_collection.add(
+        documents=[user_question],
+        metadatas=[{"model_response": str(model_response)}],
+        ids=[str(uuid.uuid4())]
+    )
+    return
+
+# Retrieve a response from the cache if similarity is above the threshold
+def get_cache(messages, similarity_threshold):
+    try:
+        global cache_collection
+        if cache_collection is None:
+            make_collection()
+
+        user_question = message_to_user_question(messages)
+
+        # Query the cache for the user question
+        results = cache_collection.query(
+            query_texts=[user_question],
+            n_results=1
+        )
+
+        if len(results['distances'][0]) == 0:
+            return None # Cache is empty
+
+        distance = results['distances'][0][0]
+        sim = (1 - distance)
+
+        if sim >= similarity_threshold:
+            return results['metadatas'][0][0]["model_response"] # Return cached response
+        else:
+            return None # No cache hit
+    except Exception as e:
+        print("Error in get cache", e)
+        raise e
+
+# Initialize the cache collection
+def make_collection():
+    import chromadb
+    global cache_collection
+    client = chromadb.Client()
+    cache_collection = client.create_collection("llm_responses")
+
+# HELPER: Extract user's question from messages
+def message_to_user_question(messages):
+    user_question = ""
+    for message in messages:
+        if message['role'] == 'user':
+            user_question += message["content"]
+    return user_question
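This new utils.py takes over handle_error() and the ChromaDB cache helpers that the first file drops, but the commit does not show the body of api_completion_with_cache. For illustration only, get_cache() and add_cache() are meant to be wrapped around completion() roughly as in the sketch below; the function name cached_completion and the 0.8 similarity threshold are assumptions, not values taken from this commit:

# sketch of wiring the cache helpers around completion() -- not the commit's actual handler body
from litellm import completion
from utils import add_cache, get_cache

def cached_completion(data, similarity_threshold=0.8):  # threshold chosen for illustration
    cached = get_cache(data["messages"], similarity_threshold)
    if cached is not None:
        return cached                      # cache hit: return the stored model_response string
    response = completion(**data)          # cache miss: query the model
    add_cache(data["messages"], response)  # store the fresh answer for similar future questions
    return response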