feat(proxy_server): adds create-proxy feature

Krrish Dholakia 2023-10-12 18:24:09 -07:00
parent 3da89a58ae
commit b28c055896
11 changed files with 246 additions and 124 deletions


@@ -1,9 +1,10 @@
-import sys, os, platform
+import sys, os, platform, time, copy
 import threading
 import shutil, random, traceback
-sys.path.insert(
-    0, os.path.abspath("../..")
-) # Adds the parent directory to the system path
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# ) # Adds the parent directory to the system path - for litellm local dev
 try:
     import uvicorn
@@ -76,12 +77,10 @@ user_max_tokens = None
 user_temperature = None
 user_telemetry = False
 user_config = None
-config_filename = "litellm.secrets.toml"
-pkg_config_filename = "template.secrets.toml"
-# Using appdirs to determine user-specific config path
-config_dir = appdirs.user_config_dir("litellm")
+config_filename = "secrets.toml"
+config_dir = os.getcwd()
 user_config_path = os.path.join(config_dir, config_filename)
+log_file = 'api_log.json'
 #### HELPER FUNCTIONS ####
 def print_verbose(print_statement):
     global user_debug
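
Reviewer note: with this hunk the proxy stops resolving its config via appdirs and instead reads secrets.toml from whatever directory the server is launched in, writing request logs alongside it. A minimal sketch of the new resolution, assuming the server is started from a project directory (paths are illustrative):

    import os

    config_dir = os.getcwd()                                     # e.g. /home/user/my-proxy
    user_config_path = os.path.join(config_dir, "secrets.toml")  # /home/user/my-proxy/secrets.toml
    log_file = 'api_log.json'                                    # created in the same directory on first request

One consequence of this design choice: launching the proxy from a different working directory points it at a different (possibly missing) secrets.toml.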
@@ -98,15 +97,6 @@ def usage_telemetry(): # helps us know if people are using this feature. Set `li
 def load_config():
     try:
         global user_config, user_api_base, user_max_tokens, user_temperature, user_model
-        if not os.path.exists(user_config_path):
-            # If user's config doesn't exist, copy the default config from the package
-            here = os.path.abspath(os.path.dirname(__file__))
-            parent_dir = os.path.dirname(here)
-            default_config_path = os.path.join(parent_dir, pkg_config_filename)
-            # Ensure the user-specific directory exists
-            os.makedirs(config_dir, exist_ok=True)
-            # Copying the file using shutil.copy
-            shutil.copy(default_config_path, user_config_path)
         # As the .env file is typically much simpler in structure, we use load_dotenv here directly
         with open(user_config_path, "rb") as f:
             user_config = tomllib.load(f)
@@ -133,11 +123,8 @@ def load_config():
         ## load model config - to set this run `litellm --config`
         model_config = None
-        if user_model == "local":
-            model_config = user_config["local_model"]
-        elif user_model == "hosted":
-            model_config = user_config["hosted_model"]
-            litellm.max_budget = model_config.get("max_budget", None) # check if user set a budget for hosted model - e.g. gpt-4
+        if user_model in user_config["model"]:
+            model_config = user_config["model"][user_model]
         print_verbose(f"user_config: {user_config}")
         print_verbose(f"model_config: {model_config}")
@@ -317,7 +304,55 @@ def track_cost_callback(
     except:
         pass
-litellm.success_callback = [track_cost_callback]
+def logger(
+    kwargs,                    # kwargs to completion
+    completion_response=None,  # response from completion
+    start_time=None,
+    end_time=None              # start/end time
+):
+    log_event_type = kwargs['log_event_type']
+    print(f"REACHES LOGGER: {log_event_type}")
+    try:
+        if log_event_type == 'pre_api_call':
+            inference_params = copy.deepcopy(kwargs)
+            timestamp = inference_params.pop('start_time')
+            dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
+            log_data = {
+                dt_key: {
+                    'pre_api_call': inference_params
+                }
+            }
+            try:
+                with open(log_file, 'r') as f:
+                    existing_data = json.load(f)
+            except FileNotFoundError:
+                existing_data = {}
+            existing_data.update(log_data)
+            with open(log_file, 'w') as f:
+                json.dump(existing_data, f, indent=2)
+        elif log_event_type == 'post_api_call':
+            print(f"post api call kwargs: {kwargs}")
+            if "stream" not in kwargs["optional_params"] or kwargs["optional_params"]["stream"] is False or kwargs.get("complete_streaming_response", False):
+                inference_params = copy.deepcopy(kwargs)
+                timestamp = inference_params.pop('start_time')
+                dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
+                with open(log_file, 'r') as f:
+                    existing_data = json.load(f)
+                existing_data[dt_key]['post_api_call'] = inference_params
+                with open(log_file, 'w') as f:
+                    json.dump(existing_data, f, indent=2)
+    except:
+        traceback.print_exc()
+litellm.input_callback = [logger]
+litellm.success_callback = [logger]
+litellm.failure_callback = [logger]
 def litellm_completion(data, type):
     try:
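
For review context: the new logger builds api_log.json as a single JSON object keyed by each request's start timestamp (strftime "%Y%m%d%H%M%S%f"), writing the pre_api_call kwargs before the request and merging the post_api_call kwargs into the same key once a non-streaming or fully streamed response arrives. The kwargs must be JSON-serializable for json.dump to succeed; any failure is swallowed by the bare except and printed via traceback. A sketch of reading the file back, with illustrative keys:

    import json

    with open('api_log.json') as f:
        log = json.load(f)

    # One entry per request, e.g.:
    # "20231012182409123456": {
    #     "pre_api_call":  {...completion kwargs...},
    #     "post_api_call": {...kwargs including the final response...}
    # }
    for dt_key, events in log.items():
        print(dt_key, sorted(events))  # -> 20231012182409123456 ['post_api_call', 'pre_api_call']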