Improve code formatting and allow configurable litellm config path via environment variable.

This commit is contained in:
coconut49 2023-10-18 13:30:53 +08:00
parent a9ebf1b6ab
commit 52fdfe5819
No known key found for this signature in database
2 changed files with 16 additions and 14 deletions

View file

@ -9,7 +9,7 @@ import operator
config_filename = "litellm.secrets.toml"
# Using appdirs to determine user-specific config path
config_dir = appdirs.user_config_dir("litellm")
user_config_path = os.path.join(config_dir, config_filename)
user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
load_dotenv()
from importlib import resources

View file

@ -18,7 +18,8 @@ except ImportError:
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"])
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"])
import uvicorn
import fastapi
import tomli as tomllib
@ -26,9 +27,9 @@ except ImportError:
import tomli_w
try:
from .llm import litellm_completion
from .llm import litellm_completion
except ImportError as e:
from llm import litellm_completion # type: ignore
from llm import litellm_completion # type: ignore
import random
@ -105,7 +106,7 @@ model_router = litellm.Router()
config_filename = "litellm.secrets.toml"
config_dir = os.getcwd()
config_dir = appdirs.user_config_dir("litellm")
user_config_path = os.path.join(config_dir, config_filename)
user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
log_file = 'api_log.json'
@ -184,7 +185,7 @@ def save_params_to_config(data: dict):
def load_config():
try:
try:
global user_config, user_api_base, user_max_tokens, user_temperature, user_model, local_logging
# As the .env file is typically much simpler in structure, we use load_dotenv here directly
with open(user_config_path, "rb") as f:
@ -199,9 +200,9 @@ def load_config():
litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt",
True) # by default add function to prompt if unsupported by provider
litellm.drop_params = user_config["general"].get("drop_params",
True) # by default drop params if unsupported by provider
True) # by default drop params if unsupported by provider
litellm.model_fallbacks = user_config["general"].get("fallbacks",
None) # fallback models in case initial completion call fails
None) # fallback models in case initial completion call fails
default_model = user_config["general"].get("default_model", None) # route all requests to this model.
local_logging = user_config["general"].get("local_logging", True)
@ -215,10 +216,10 @@ def load_config():
if user_model in user_config["model"]:
model_config = user_config["model"][user_model]
model_list = []
for model in user_config["model"]:
for model in user_config["model"]:
if "model_list" in user_config["model"][model]:
model_list.extend(user_config["model"][model]["model_list"])
if len(model_list) > 0:
if len(model_list) > 0:
model_router.set_model_list(model_list=model_list)
print_verbose(f"user_config: {user_config}")
@ -254,7 +255,7 @@ def load_config():
},
final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""),
)
except:
except:
pass
@ -271,8 +272,8 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
if api_base: # model-specific param
user_api_base = api_base
dynamic_config[user_model]["api_base"] = api_base
if api_version:
os.environ["AZURE_API_VERSION"] = api_version # set this for azure - litellm can read this from the env
if api_version:
os.environ["AZURE_API_VERSION"] = api_version # set this for azure - litellm can read this from the env
if max_tokens: # model-specific param
user_max_tokens = max_tokens
dynamic_config[user_model]["max_tokens"] = max_tokens
@ -290,7 +291,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
if max_budget: # litellm-specific param
litellm.max_budget = max_budget
dynamic_config["general"]["max_budget"] = max_budget
if debug: # litellm-specific param
if debug: # litellm-specific param
litellm.set_verbose = True
if save:
save_params_to_config(dynamic_config)
@ -300,6 +301,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
user_telemetry = telemetry
usage_telemetry(feature="local_proxy_server")
def track_cost_callback(
kwargs, # kwargs to completion
completion_response, # response from completion