Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
feat(proxy_server): adds create-proxy feature
This commit is contained in:
parent 3da89a58ae
commit b28c055896

11 changed files with 246 additions and 124 deletions; only the proxy server hunks are reproduced below.
@@ -1,9 +1,10 @@
-import sys, os, platform
+import sys, os, platform, time, copy
 import threading
+import shutil, random, traceback
 sys.path.insert(
     0, os.path.abspath("../..")
 ) # Adds the parent directory to the system path
 # sys.path.insert(
 #     0, os.path.abspath("../..")
 # ) # Adds the parent directory to the system path - for litellm local dev
 
 
 try:
     import uvicorn
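The only substantive changes in this hunk are the new imports: `copy` lets the logger added later in this diff snapshot the mutable completion kwargs before popping fields out of them, and `traceback` lets callback failures be reported without crashing the proxy. A minimal sketch of that pattern, using illustrative names not taken from the diff:

import copy, traceback

def snapshot_params(kwargs):
    # Deep-copy so the pop below never mutates the caller's dict
    params = copy.deepcopy(kwargs)
    params.pop('start_time', None)
    return params

try:
    snap = snapshot_params({'model': 'gpt-3.5-turbo', 'start_time': None})
except Exception:
    traceback.print_exc()  # surface the error, but keep serving requests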
@@ -76,12 +77,10 @@ user_max_tokens = None
 user_temperature = None
 user_telemetry = False
 user_config = None
-config_filename = "litellm.secrets.toml"
 pkg_config_filename = "template.secrets.toml"
-# Using appdirs to determine user-specific config path
-config_dir = appdirs.user_config_dir("litellm")
+config_filename = "secrets.toml"
+config_dir = os.getcwd()
 user_config_path = os.path.join(config_dir, config_filename)
 
+log_file = 'api_log.json'
 #### HELPER FUNCTIONS ####
 def print_verbose(print_statement):
     global user_debug
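This hunk moves config discovery from a per-user appdirs directory to the current working directory, and renames the file from `litellm.secrets.toml` to `secrets.toml`. The effect, sketched with an illustrative path:

import os

config_dir = os.getcwd()                                    # e.g. /home/alice/my-proxy
user_config_path = os.path.join(config_dir, "secrets.toml") # /home/alice/my-proxy/secrets.toml
print(user_config_path)

The trade-off: running the proxy from a different directory now silently picks up a different (or missing) config, whereas the appdirs location was stable across invocations.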
@@ -98,15 +97,6 @@ def usage_telemetry(): # helps us know if people are using this feature. Set `li
 def load_config():
     try:
         global user_config, user_api_base, user_max_tokens, user_temperature, user_model
-        if not os.path.exists(user_config_path):
-            # If user's config doesn't exist, copy the default config from the package
-            here = os.path.abspath(os.path.dirname(__file__))
-            parent_dir = os.path.dirname(here)
-            default_config_path = os.path.join(parent_dir, pkg_config_filename)
-            # Ensure the user-specific directory exists
-            os.makedirs(config_dir, exist_ok=True)
-            # Copying the file using shutil.copy
-            shutil.copy(default_config_path, user_config_path)
         # As the .env file is typically much simpler in structure, we use load_dotenv here directly
         with open(user_config_path, "rb") as f:
             user_config = tomllib.load(f)
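`load_config` no longer copies `template.secrets.toml` out of the package when no config exists; it now assumes `secrets.toml` is already present in the working directory. Note that `tomllib.load` requires a binary-mode handle, which is why the file is opened with "rb". A self-contained sketch (assumes Python 3.11+, where `tomllib` is in the standard library):

import tomllib  # stdlib since Python 3.11; earlier versions need the tomli backport

with open("secrets.toml", "rb") as f:  # binary mode is required by tomllib.load
    user_config = tomllib.load(f)
print(user_config.get("model", {}))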
@@ -133,11 +123,8 @@ def load_config():
 
         ## load model config - to set this run `litellm --config`
         model_config = None
-        if user_model == "local":
-            model_config = user_config["local_model"]
-        elif user_model == "hosted":
-            model_config = user_config["hosted_model"]
-            litellm.max_budget = model_config.get("max_budget", None) # check if user set a budget for hosted model - e.g. gpt-4
+        if user_model in user_config["model"]:
+            model_config = user_config["model"][user_model]
 
         print_verbose(f"user_config: {user_config}")
         print_verbose(f"model_config: {model_config}")
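Model settings are now looked up under a generic `[model.<name>]` table keyed by the active model name, replacing the hardcoded `local_model`/`hosted_model` sections (the hosted-model `litellm.max_budget` assignment goes away with them). Since `tomllib` parses the file into plain dicts, the lookup reduces to this (values are illustrative):

# Roughly what tomllib would produce for a config containing a
# [model."gpt-3.5-turbo"] table with two keys
user_config = {"model": {"gpt-3.5-turbo": {"max_tokens": 256, "temperature": 0.7}}}

user_model = "gpt-3.5-turbo"
model_config = None
if user_model in user_config["model"]:
    model_config = user_config["model"][user_model]
print(model_config)  # {'max_tokens': 256, 'temperature': 0.7}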
@@ -317,7 +304,55 @@ def track_cost_callback(
     except:
         pass
 
-litellm.success_callback = [track_cost_callback]
+def logger(
+    kwargs,                    # kwargs to completion
+    completion_response=None,  # response from completion
+    start_time=None,
+    end_time=None              # start/end time
+):
+    log_event_type = kwargs['log_event_type']
+    print(f"REACHES LOGGER: {log_event_type}")
+    try:
+        if log_event_type == 'pre_api_call':
+            inference_params = copy.deepcopy(kwargs)
+            timestamp = inference_params.pop('start_time')
+            dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
+            log_data = {
+                dt_key: {
+                    'pre_api_call': inference_params
+                }
+            }
+
+            try:
+                with open(log_file, 'r') as f:
+                    existing_data = json.load(f)
+            except FileNotFoundError:
+                existing_data = {}
+
+            existing_data.update(log_data)
+
+            with open(log_file, 'w') as f:
+                json.dump(existing_data, f, indent=2)
+        elif log_event_type == 'post_api_call':
+            print(f"post api call kwargs: {kwargs}")
+            if "stream" not in kwargs["optional_params"] or kwargs["optional_params"]["stream"] is False or kwargs.get("complete_streaming_response", False):
+                inference_params = copy.deepcopy(kwargs)
+                timestamp = inference_params.pop('start_time')
+                dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
+
+                with open(log_file, 'r') as f:
+                    existing_data = json.load(f)
+
+                existing_data[dt_key]['post_api_call'] = inference_params
+
+                with open(log_file, 'w') as f:
+                    json.dump(existing_data, f, indent=2)
+    except:
+        traceback.print_exc()
+
+litellm.input_callback = [logger]
+litellm.success_callback = [logger]
+litellm.failure_callback = [logger]
 
 def litellm_completion(data, type):
     try:
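The new `logger` callback keys every request by a strftime("%Y%m%d%H%M%S%f") timestamp (20 characters, so the `[:23]` slice truncates nothing) and nests the `pre_api_call` and `post_api_call` payloads under that key, re-reading and rewriting `api_log.json` on each event. A sketch of the resulting file shape, with illustrative payloads:

import json
from datetime import datetime

# Same key format the diff uses; the formatted string is 20 chars long
dt_key = datetime(2024, 1, 1, 12, 0, 0, 123456).strftime("%Y%m%d%H%M%S%f")[:23]

# Approximate api_log.json contents after one non-streaming request
log = {
    dt_key: {
        "pre_api_call": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
        "post_api_call": {"model": "gpt-3.5-turbo", "response": "..."},
    }
}
print(json.dumps(log, indent=2))

Because each event does a full read-modify-write of the file, concurrent requests can interleave and drop entries; a lock or an append-only format would be the usual fix.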