Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
refactor: add black formatting
commit 4905929de3 (parent b87d630b0a)
156 changed files with 19723 additions and 10869 deletions
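
This commit runs the black autoformatter across the repository; the hunks below show one representative Flask proxy file. As a rough illustration (not part of the commit itself), black's library entry point reproduces the same kind of rewrite on a snippet taken from this file, assuming black's default settings (88-character lines, double-quote normalization):

# Illustration only: reformat one snippet from the diff below using black's
# documented library API (black.format_str / black.Mode). The snippet and the
# default settings are assumptions for this sketch, not part of the commit.
import black

src = "@app.route('/')\ndef index():\n    return 'received!', 200\n"
print(black.format_str(src, mode=black.Mode()))
# The output uses double quotes, matching the "+" lines in the diff below.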
@@ -4,9 +4,10 @@ from flask_cors import CORS
 import traceback
 import litellm
 from util import handle_error
-from litellm import completion
-import os, dotenv, time
+from litellm import completion
+import os, dotenv, time
 import json
 
 dotenv.load_dotenv()
 
 # TODO: set your keys in .env or here:
@@ -19,57 +20,72 @@ verbose = True
 
 # litellm.caching_with_models = True # CACHING: caching_with_models Keys in the cache are messages + model. - to learn more: https://docs.litellm.ai/docs/caching/
 ######### PROMPT LOGGING ##########
-os.environ["PROMPTLAYER_API_KEY"] = "" # set your promptlayer key here - https://promptlayer.com/
+os.environ[
+    "PROMPTLAYER_API_KEY"
+] = ""  # set your promptlayer key here - https://promptlayer.com/
 
 # set callbacks
 litellm.success_callback = ["promptlayer"]
 ############ HELPER FUNCTIONS ###################################
 
+
 def print_verbose(print_statement):
     if verbose:
         print(print_statement)
 
+
 app = Flask(__name__)
 CORS(app)
 
-@app.route('/')
+
+@app.route("/")
 def index():
-    return 'received!', 200
+    return "received!", 200
 
+
 def data_generator(response):
     for chunk in response:
         yield f"data: {json.dumps(chunk)}\n\n"
 
-@app.route('/chat/completions', methods=["POST"])
+
+@app.route("/chat/completions", methods=["POST"])
 def api_completion():
     data = request.json
-    start_time = time.time()
-    if data.get('stream') == "True":
-        data['stream'] = True # convert to boolean
+    start_time = time.time()
+    if data.get("stream") == "True":
+        data["stream"] = True  # convert to boolean
     try:
         if "prompt" not in data:
             raise ValueError("data needs to have prompt")
-        data["model"] = "togethercomputer/CodeLlama-34b-Instruct" # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct
+        data[
+            "model"
+        ] = "togethercomputer/CodeLlama-34b-Instruct"  # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct
         # COMPLETION CALL
         system_prompt = "Only respond to questions about code. Say 'I don't know' to anything outside of that."
-        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": data.pop("prompt")}]
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": data.pop("prompt")},
+        ]
         data["messages"] = messages
         print(f"data: {data}")
         response = completion(**data)
         ## LOG SUCCESS
-        end_time = time.time()
-        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-            return Response(data_generator(response), mimetype='text/event-stream')
+        end_time = time.time()
+        if (
+            "stream" in data and data["stream"] == True
+        ):  # use generate_responses to stream responses
+            return Response(data_generator(response), mimetype="text/event-stream")
     except Exception as e:
         # call handle_error function
         print_verbose(f"Got Error api_completion(): {traceback.format_exc()}")
         ## LOG FAILURE
-        end_time = time.time()
+        end_time = time.time()
         traceback_exception = traceback.format_exc()
         return handle_error(data=data)
     return response
 
-@app.route('/get_models', methods=["POST"])
+
+@app.route("/get_models", methods=["POST"])
 def get_models():
     try:
         return litellm.model_list
@@ -78,7 +94,8 @@ def get_models():
         response = {"error": str(e)}
         return response, 200
 
-if __name__ == "__main__":
-    from waitress import serve
-    serve(app, host="0.0.0.0", port=4000, threads=500)
+if __name__ == "__main__":
+    from waitress import serve
+
+    serve(app, host="0.0.0.0", port=4000, threads=500)
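For context, a hypothetical client call against the proxy changed above, assuming the waitress server from the final hunk is running locally on port 4000; the payload fields follow the handler in this diff (a required "prompt", and "stream" passed as the string "True"):

# Usage sketch, not part of the commit: exercise the /chat/completions route.
import requests

resp = requests.post(
    "http://localhost:4000/chat/completions",
    json={"prompt": "Write a function that reverses a string", "stream": "True"},
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())  # streamed responses arrive as "data: {...}" SSE lines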