From 52fdfe58195a514ca718df7760969cee610f82c7 Mon Sep 17 00:00:00 2001 From: coconut49 Date: Wed, 18 Oct 2023 13:30:53 +0800 Subject: [PATCH] Improve code formatting and allow configurable litellm config path via environment variable. --- litellm/proxy/proxy_cli.py | 2 +- litellm/proxy/proxy_server.py | 28 +++++++++++++++------------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index f2b29bf5f..96e089caa 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -9,7 +9,7 @@ import operator config_filename = "litellm.secrets.toml" # Using appdirs to determine user-specific config path config_dir = appdirs.user_config_dir("litellm") -user_config_path = os.path.join(config_dir, config_filename) +user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)) load_dotenv() from importlib import resources diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 573646584..0afa6be1f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -18,7 +18,8 @@ except ImportError: import subprocess import sys - subprocess.check_call([sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"]) + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"]) import uvicorn import fastapi import tomli as tomllib @@ -26,9 +27,9 @@ except ImportError: import tomli_w try: - from .llm import litellm_completion + from .llm import litellm_completion except ImportError as e: - from llm import litellm_completion # type: ignore + from llm import litellm_completion # type: ignore import random @@ -105,7 +106,7 @@ model_router = litellm.Router() config_filename = "litellm.secrets.toml" config_dir = os.getcwd() config_dir = appdirs.user_config_dir("litellm") -user_config_path = os.path.join(config_dir, config_filename) +user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)) log_file = 'api_log.json' @@ -184,7 +185,7 @@ def save_params_to_config(data: dict): def load_config(): - try: + try: global user_config, user_api_base, user_max_tokens, user_temperature, user_model, local_logging # As the .env file is typically much simpler in structure, we use load_dotenv here directly with open(user_config_path, "rb") as f: @@ -199,9 +200,9 @@ def load_config(): litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt", True) # by default add function to prompt if unsupported by provider litellm.drop_params = user_config["general"].get("drop_params", - True) # by default drop params if unsupported by provider + True) # by default drop params if unsupported by provider litellm.model_fallbacks = user_config["general"].get("fallbacks", - None) # fallback models in case initial completion call fails + None) # fallback models in case initial completion call fails default_model = user_config["general"].get("default_model", None) # route all requests to this model. local_logging = user_config["general"].get("local_logging", True) @@ -215,10 +216,10 @@ def load_config(): if user_model in user_config["model"]: model_config = user_config["model"][user_model] model_list = [] - for model in user_config["model"]: + for model in user_config["model"]: if "model_list" in user_config["model"][model]: model_list.extend(user_config["model"][model]["model_list"]) - if len(model_list) > 0: + if len(model_list) > 0: model_router.set_model_list(model_list=model_list) print_verbose(f"user_config: {user_config}") @@ -254,7 +255,7 @@ def load_config(): }, final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""), ) - except: + except: pass @@ -271,8 +272,8 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke if api_base: # model-specific param user_api_base = api_base dynamic_config[user_model]["api_base"] = api_base - if api_version: - os.environ["AZURE_API_VERSION"] = api_version # set this for azure - litellm can read this from the env + if api_version: + os.environ["AZURE_API_VERSION"] = api_version # set this for azure - litellm can read this from the env if max_tokens: # model-specific param user_max_tokens = max_tokens dynamic_config[user_model]["max_tokens"] = max_tokens @@ -290,7 +291,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke if max_budget: # litellm-specific param litellm.max_budget = max_budget dynamic_config["general"]["max_budget"] = max_budget - if debug: # litellm-specific param + if debug: # litellm-specific param litellm.set_verbose = True if save: save_params_to_config(dynamic_config) @@ -300,6 +301,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke user_telemetry = telemetry usage_telemetry(feature="local_proxy_server") + def track_cost_callback( kwargs, # kwargs to completion completion_response, # response from completion