Refactor proxy_server.py for readability and code consistency

This commit is contained in:
coconut49 2023-10-17 23:48:55 +08:00
parent 266b3b82f5
commit 4414594e7d
No known key found for this signature in database

View file

@ -1,11 +1,11 @@
import sys, os, platform, time, copy import sys, os, platform, time, copy
import threading import threading
import shutil, random, traceback import shutil, random, traceback
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path - for litellm local dev ) # Adds the parent directory to the system path - for litellm local dev
try: try:
import uvicorn import uvicorn
import fastapi import fastapi
@ -22,13 +22,14 @@ except ImportError:
import tomli as tomllib import tomli as tomllib
import appdirs import appdirs
import tomli_w import tomli_w
try: try:
from .llm import litellm_completion from .llm import litellm_completion
except ImportError as e: except ImportError as e:
from llm import litellm_completion from llm import litellm_completion
import random import random
list_of_messages = [ list_of_messages = [
"'The thing I wish you improved is...'", "'The thing I wish you improved is...'",
"'A feature I really want is...'", "'A feature I really want is...'",
@ -37,35 +38,36 @@ list_of_messages = [
"'I don't like how this works...'", "'I don't like how this works...'",
"'It would help me if you could add...'", "'It would help me if you could add...'",
"'This feature doesn't meet my needs because...'", "'This feature doesn't meet my needs because...'",
"'I get frustrated when the product...'", "'I get frustrated when the product...'",
] ]
def generate_feedback_box(): def generate_feedback_box():
box_width = 60 box_width = 60
# Select a random message # Select a random message
message = random.choice(list_of_messages) message = random.choice(list_of_messages)
print()
print('\033[1;37m' + '#' + '-' * box_width + '#\033[0m')
print('\033[1;37m' + '#' + ' ' * box_width + '#\033[0m')
print('\033[1;37m' + '# {:^59} #\033[0m'.format(message))
print('\033[1;37m' + '# {:^59} #\033[0m'.format('https://github.com/BerriAI/litellm/issues/new'))
print('\033[1;37m' + '#' + ' ' * box_width + '#\033[0m')
print('\033[1;37m' + '#' + '-' * box_width + '#\033[0m')
print()
print(' Thank you for using LiteLLM! - Krrish & Ishaan')
print()
print()
print()
print('\033[1;37m' + '#' + '-'*box_width + '#\033[0m')
print('\033[1;37m' + '#' + ' '*box_width + '#\033[0m')
print('\033[1;37m' + '# {:^59} #\033[0m'.format(message))
print('\033[1;37m' + '# {:^59} #\033[0m'.format('https://github.com/BerriAI/litellm/issues/new'))
print('\033[1;37m' + '#' + ' '*box_width + '#\033[0m')
print('\033[1;37m' + '#' + '-'*box_width + '#\033[0m')
print()
print(' Thank you for using LiteLLM! - Krrish & Ishaan')
print()
print()
generate_feedback_box() generate_feedback_box()
print() print()
print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m") print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m")
print() print()
print("\033[1;34mDocs: https://docs.litellm.ai/docs/proxy_server\033[0m") print("\033[1;34mDocs: https://docs.litellm.ai/docs/proxy_server\033[0m")
print() print()
import litellm import litellm
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
@ -100,24 +102,29 @@ config_dir = os.getcwd()
config_dir = appdirs.user_config_dir("litellm") config_dir = appdirs.user_config_dir("litellm")
user_config_path = os.path.join(config_dir, config_filename) user_config_path = os.path.join(config_dir, config_filename)
log_file = 'api_log.json' log_file = 'api_log.json'
#### HELPER FUNCTIONS #### #### HELPER FUNCTIONS ####
def print_verbose(print_statement): def print_verbose(print_statement):
global user_debug global user_debug
if user_debug: if user_debug:
print(print_statement) print(print_statement)
def usage_telemetry(feature: str): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
if user_telemetry: def usage_telemetry(
feature: str): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
if user_telemetry:
data = { data = {
"feature": feature # "local_proxy_server" "feature": feature # "local_proxy_server"
} }
threading.Thread(target=litellm.utils.litellm_telemetry, args=(data,), daemon=True).start() threading.Thread(target=litellm.utils.litellm_telemetry, args=(data,), daemon=True).start()
def add_keys_to_config(key, value): def add_keys_to_config(key, value):
# Check if file exists # Check if file exists
if os.path.exists(user_config_path): if os.path.exists(user_config_path):
# Load existing file # Load existing file
with open(user_config_path, "rb") as f: with open(user_config_path, "rb") as f:
config = tomllib.load(f) config = tomllib.load(f)
else: else:
# File doesn't exist, create empty config # File doesn't exist, create empty config
@ -130,21 +137,22 @@ def add_keys_to_config(key, value):
with open(user_config_path, 'wb') as f: with open(user_config_path, 'wb') as f:
tomli_w.dump(config, f) tomli_w.dump(config, f)
def save_params_to_config(data: dict):
def save_params_to_config(data: dict):
# Check if file exists # Check if file exists
if os.path.exists(user_config_path): if os.path.exists(user_config_path):
# Load existing file # Load existing file
with open(user_config_path, "rb") as f: with open(user_config_path, "rb") as f:
config = tomllib.load(f) config = tomllib.load(f)
else: else:
# File doesn't exist, create empty config # File doesn't exist, create empty config
config = {} config = {}
config.setdefault('general', {}) config.setdefault('general', {})
## general config ## general config
general_settings = data["general"] general_settings = data["general"]
for key, value in general_settings.items(): for key, value in general_settings.items():
config["general"][key] = value config["general"][key] = value
@ -161,101 +169,104 @@ def save_params_to_config(data: dict):
# Write config to file # Write config to file
with open(user_config_path, 'wb') as f: with open(user_config_path, 'wb') as f:
tomli_w.dump(config, f) tomli_w.dump(config, f)
def load_config(): def load_config():
try: global user_config, user_api_base, user_max_tokens, user_temperature, user_model
global user_config, user_api_base, user_max_tokens, user_temperature, user_model # As the .env file is typically much simpler in structure, we use load_dotenv here directly
# As the .env file is typically much simpler in structure, we use load_dotenv here directly with open(user_config_path, "rb") as f:
with open(user_config_path, "rb") as f: user_config = tomllib.load(f)
user_config = tomllib.load(f)
## load keys ## load keys
if "keys" in user_config: if "keys" in user_config:
for key in user_config["keys"]: for key in user_config["keys"]:
os.environ[key] = user_config["keys"][key] # litellm can read keys from the environment os.environ[key] = user_config["keys"][key] # litellm can read keys from the environment
## settings ## settings
if "general" in user_config: if "general" in user_config:
litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt", True) # by default add function to prompt if unsupported by provider litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt",
litellm.drop_params = user_config["general"].get("drop_params", True) # by default drop params if unsupported by provider True) # by default add function to prompt if unsupported by provider
litellm.model_fallbacks = user_config["general"].get("fallbacks", None) # fallback models in case initial completion call fails litellm.drop_params = user_config["general"].get("drop_params",
default_model = user_config["general"].get("default_model", None) # route all requests to this model. True) # by default drop params if unsupported by provider
litellm.model_fallbacks = user_config["general"].get("fallbacks",
None) # fallback models in case initial completion call fails
default_model = user_config["general"].get("default_model", None) # route all requests to this model.
if user_model is None: # `litellm --model <model-name>`` > default_model. if user_model is None: # `litellm --model <model-name>`` > default_model.
user_model = default_model user_model = default_model
## load model config - to set this run `litellm --config` ## load model config - to set this run `litellm --config`
model_config = None model_config = None
if "model" in user_config: if "model" in user_config:
if user_model in user_config["model"]: if user_model in user_config["model"]:
model_config = user_config["model"][user_model] model_config = user_config["model"][user_model]
print_verbose(f"user_config: {user_config}")
print_verbose(f"model_config: {model_config}")
print_verbose(f"user_model: {user_model}")
if model_config is None:
return
user_max_tokens = model_config.get("max_tokens", None) print_verbose(f"user_config: {user_config}")
user_temperature = model_config.get("temperature", None) print_verbose(f"model_config: {model_config}")
user_api_base = model_config.get("api_base", None) print_verbose(f"user_model: {user_model}")
if model_config is None:
## custom prompt template return
if "prompt_template" in model_config:
model_prompt_template = model_config["prompt_template"]
if len(model_prompt_template.keys()) > 0: # if user has initialized this at all
litellm.register_prompt_template(
model=user_model,
initial_prompt_value=model_prompt_template.get("MODEL_PRE_PROMPT", ""),
roles={
"system": {
"pre_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
"post_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
},
"user": {
"pre_message": model_prompt_template.get("MODEL_USER_MESSAGE_START_TOKEN", ""),
"post_message": model_prompt_template.get("MODEL_USER_MESSAGE_END_TOKEN", ""),
},
"assistant": {
"pre_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
"post_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_END_TOKEN", ""),
}
},
final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""),
)
except Exception as e:
pass
def initialize(model, alias, api_base, debug, temperature, max_tokens, max_budget, telemetry, drop_params, add_function_to_prompt, headers, save): user_max_tokens = model_config.get("max_tokens", None)
user_temperature = model_config.get("temperature", None)
user_api_base = model_config.get("api_base", None)
## custom prompt template
if "prompt_template" in model_config:
model_prompt_template = model_config["prompt_template"]
if len(model_prompt_template.keys()) > 0: # if user has initialized this at all
litellm.register_prompt_template(
model=user_model,
initial_prompt_value=model_prompt_template.get("MODEL_PRE_PROMPT", ""),
roles={
"system": {
"pre_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
"post_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
},
"user": {
"pre_message": model_prompt_template.get("MODEL_USER_MESSAGE_START_TOKEN", ""),
"post_message": model_prompt_template.get("MODEL_USER_MESSAGE_END_TOKEN", ""),
},
"assistant": {
"pre_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
"post_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_END_TOKEN", ""),
}
},
final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""),
)
def initialize(model, alias, api_base, debug, temperature, max_tokens, max_budget, telemetry, drop_params,
add_function_to_prompt, headers, save):
global user_model, user_api_base, user_debug, user_max_tokens, user_temperature, user_telemetry, user_headers global user_model, user_api_base, user_debug, user_max_tokens, user_temperature, user_telemetry, user_headers
user_model = model user_model = model
user_debug = debug user_debug = debug
load_config() load_config()
dynamic_config = {"general": {}, user_model: {}} dynamic_config = {"general": {}, user_model: {}}
if headers: # model-specific param if headers: # model-specific param
user_headers = headers user_headers = headers
dynamic_config[user_model]["headers"] = headers dynamic_config[user_model]["headers"] = headers
if api_base: # model-specific param if api_base: # model-specific param
user_api_base = api_base user_api_base = api_base
dynamic_config[user_model]["api_base"] = api_base dynamic_config[user_model]["api_base"] = api_base
if max_tokens: # model-specific param if max_tokens: # model-specific param
user_max_tokens = max_tokens user_max_tokens = max_tokens
dynamic_config[user_model]["max_tokens"] = max_tokens dynamic_config[user_model]["max_tokens"] = max_tokens
if temperature: # model-specific param if temperature: # model-specific param
user_temperature = temperature user_temperature = temperature
dynamic_config[user_model]["temperature"] = temperature dynamic_config[user_model]["temperature"] = temperature
if alias: # model-specific param if alias: # model-specific param
dynamic_config[user_model]["alias"] = alias dynamic_config[user_model]["alias"] = alias
if drop_params == True: # litellm-specific param if drop_params == True: # litellm-specific param
litellm.drop_params = True litellm.drop_params = True
dynamic_config["general"]["drop_params"] = True dynamic_config["general"]["drop_params"] = True
if add_function_to_prompt == True: # litellm-specific param if add_function_to_prompt == True: # litellm-specific param
litellm.add_function_to_prompt = True litellm.add_function_to_prompt = True
dynamic_config["general"]["add_function_to_prompt"] = True dynamic_config["general"]["add_function_to_prompt"] = True
if max_budget: # litellm-specific param if max_budget: # litellm-specific param
litellm.max_budget = max_budget litellm.max_budget = max_budget
dynamic_config["general"]["max_budget"] = max_budget dynamic_config["general"]["max_budget"] = max_budget
if save: if save:
save_params_to_config(dynamic_config) save_params_to_config(dynamic_config)
with open(user_config_path) as f: with open(user_config_path) as f:
print(f.read()) print(f.read())
@ -263,6 +274,7 @@ def initialize(model, alias, api_base, debug, temperature, max_tokens, max_budge
user_telemetry = telemetry user_telemetry = telemetry
usage_telemetry(feature="local_proxy_server") usage_telemetry(feature="local_proxy_server")
def deploy_proxy(model, api_base, debug, temperature, max_tokens, telemetry, deploy): def deploy_proxy(model, api_base, debug, temperature, max_tokens, telemetry, deploy):
import requests import requests
# Load .env file # Load .env file
@ -293,8 +305,6 @@ def deploy_proxy(model, api_base, debug, temperature, max_tokens, telemetry, dep
files = {"file": open(".env", "rb")} files = {"file": open(".env", "rb")}
# print(files) # print(files)
response = requests.post(url, data=data, files=files) response = requests.post(url, data=data, files=files)
# print(response) # print(response)
# Check the status of the request # Check the status of the request
@ -309,10 +319,11 @@ def deploy_proxy(model, api_base, debug, temperature, max_tokens, telemetry, dep
return url return url
def track_cost_callback( def track_cost_callback(
kwargs, # kwargs to completion kwargs, # kwargs to completion
completion_response, # response from completion completion_response, # response from completion
start_time, end_time # start/end time start_time, end_time # start/end time
): ):
# track cost like this # track cost like this
# { # {
@ -330,12 +341,12 @@ def track_cost_callback(
# for streaming responses # for streaming responses
if "complete_streaming_response" in kwargs: if "complete_streaming_response" in kwargs:
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
completion_response=kwargs["complete_streaming_response"] completion_response = kwargs["complete_streaming_response"]
input_text = kwargs["messages"] input_text = kwargs["messages"]
output_text = completion_response["choices"][0]["message"]["content"] output_text = completion_response["choices"][0]["message"]["content"]
response_cost = litellm.completion_cost( response_cost = litellm.completion_cost(
model = kwargs["model"], model=kwargs["model"],
messages = input_text, messages=input_text,
completion=output_text completion=output_text
) )
model = kwargs['model'] model = kwargs['model']
@ -353,7 +364,7 @@ def track_cost_callback(
with open("costs.json") as f: with open("costs.json") as f:
cost_data = json.load(f) cost_data = json.load(f)
except FileNotFoundError: except FileNotFoundError:
cost_data = {} cost_data = {}
import datetime import datetime
date = datetime.datetime.now().strftime("%b-%d-%Y") date = datetime.datetime.now().strftime("%b-%d-%Y")
if date not in cost_data: if date not in cost_data:
@ -374,47 +385,32 @@ def track_cost_callback(
except: except:
pass pass
def logger(
kwargs, # kwargs to completion
completion_response=None, # response from completion
start_time=None,
end_time=None # start/end time
):
log_event_type = kwargs['log_event_type']
try:
if log_event_type == 'pre_api_call':
inference_params = copy.deepcopy(kwargs)
timestamp = inference_params.pop('start_time')
dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
log_data = {
dt_key: {
'pre_api_call': inference_params
}
}
try:
with open(log_file, 'r') as f:
existing_data = json.load(f)
except FileNotFoundError:
existing_data = {}
existing_data.update(log_data)
def write_to_log():
with open(log_file, 'w') as f:
json.dump(existing_data, f, indent=2)
thread = threading.Thread(target=write_to_log, daemon=True) def logger(
thread.start() kwargs, # kwargs to completion
elif log_event_type == 'post_api_call': completion_response=None, # response from completion
if "stream" not in kwargs["optional_params"] or kwargs["optional_params"]["stream"] is False or kwargs.get("complete_streaming_response", False): start_time=None,
end_time=None # start/end time
):
log_event_type = kwargs['log_event_type']
try:
if log_event_type == 'pre_api_call':
inference_params = copy.deepcopy(kwargs) inference_params = copy.deepcopy(kwargs)
timestamp = inference_params.pop('start_time') timestamp = inference_params.pop('start_time')
dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23] dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
log_data = {
with open(log_file, 'r') as f: dt_key: {
existing_data = json.load(f) 'pre_api_call': inference_params
}
existing_data[dt_key]['post_api_call'] = inference_params }
try:
with open(log_file, 'r') as f:
existing_data = json.load(f)
except FileNotFoundError:
existing_data = {}
existing_data.update(log_data)
def write_to_log(): def write_to_log():
with open(log_file, 'w') as f: with open(log_file, 'w') as f:
@ -422,15 +418,35 @@ def logger(
thread = threading.Thread(target=write_to_log, daemon=True) thread = threading.Thread(target=write_to_log, daemon=True)
thread.start() thread.start()
except: elif log_event_type == 'post_api_call':
pass if "stream" not in kwargs["optional_params"] or kwargs["optional_params"]["stream"] is False or kwargs.get(
"complete_streaming_response", False):
inference_params = copy.deepcopy(kwargs)
timestamp = inference_params.pop('start_time')
dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
with open(log_file, 'r') as f:
existing_data = json.load(f)
existing_data[dt_key]['post_api_call'] = inference_params
def write_to_log():
with open(log_file, 'w') as f:
json.dump(existing_data, f, indent=2)
thread = threading.Thread(target=write_to_log, daemon=True)
thread.start()
except:
pass
litellm.input_callback = [logger] litellm.input_callback = [logger]
litellm.success_callback = [logger] litellm.success_callback = [logger]
litellm.failure_callback = [logger] litellm.failure_callback = [logger]
#### API ENDPOINTS #### #### API ENDPOINTS ####
@router.get("/models") # if project requires model list @router.get("/models") # if project requires model list
def model_list(): def model_list():
if user_model != None: if user_model != None:
return dict( return dict(
@ -440,19 +456,26 @@ def model_list():
else: else:
all_models = litellm.utils.get_valid_models() all_models = litellm.utils.get_valid_models()
return dict( return dict(
data = [{"id": model, "object": "model", "created": 1677610602, "owned_by": "openai"} for model in all_models], data=[{"id": model, "object": "model", "created": 1677610602, "owned_by": "openai"} for model in
all_models],
object="list", object="list",
) )
@router.post("/completions") @router.post("/completions")
async def completion(request: Request): async def completion(request: Request):
data = await request.json() data = await request.json()
return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature, user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug) return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature,
user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers,
user_debug=user_debug)
@router.post("/chat/completions") @router.post("/chat/completions")
async def chat_completion(request: Request): async def chat_completion(request: Request):
data = await request.json() data = await request.json()
response = litellm_completion(data, type="chat_completion", user_model=user_model, user_temperature=user_temperature, user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug) response = litellm_completion(data, type="chat_completion", user_model=user_model,
user_temperature=user_temperature, user_max_tokens=user_max_tokens,
user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
return response return response
@ -462,6 +485,7 @@ async def v1_completion(request: Request):
data = await request.json() data = await request.json()
return litellm_completion(data=data, type="completion") return litellm_completion(data=data, type="completion")
@router.post("/v1/chat/completions") @router.post("/v1/chat/completions")
async def v1_chat_completion(request: Request): async def v1_chat_completion(request: Request):
data = await request.json() data = await request.json()
@ -469,6 +493,7 @@ async def v1_chat_completion(request: Request):
response = litellm_completion(data, type="chat_completion") response = litellm_completion(data, type="chat_completion")
return response return response
def print_cost_logs(): def print_cost_logs():
with open('costs.json', 'r') as f: with open('costs.json', 'r') as f:
# print this in green # print this in green
@ -477,13 +502,16 @@ def print_cost_logs():
print("\033[0m") print("\033[0m")
return return
@router.get("/ollama_logs") @router.get("/ollama_logs")
async def retrieve_server_log(request: Request): async def retrieve_server_log(request: Request):
filepath = os.path.expanduser('~/.ollama/logs/server.log') filepath = os.path.expanduser('~/.ollama/logs/server.log')
return FileResponse(filepath) return FileResponse(filepath)
@router.get("/") @router.get("/")
async def home(request: Request): async def home(request: Request):
return "LiteLLM: RUNNING" return "LiteLLM: RUNNING"
app.include_router(router)
app.include_router(router)