Improve code formatting and allow configurable litellm config path via environment variable.

coconut49 2023-10-18 13:30:53 +08:00
parent a9ebf1b6ab
commit 52fdfe5819
2 changed files with 16 additions and 14 deletions

@@ -9,7 +9,7 @@ import operator
 config_filename = "litellm.secrets.toml"
 # Using appdirs to determine user-specific config path
 config_dir = appdirs.user_config_dir("litellm")
-user_config_path = os.path.join(config_dir, config_filename)
+user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
 load_dotenv()
 from importlib import resources
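
Taken on its own, the new lookup means an explicit LITELLM_CONFIG_PATH environment variable wins over the appdirs default; a minimal standalone sketch of the behavior (paths illustrative):

    import os
    import appdirs

    # Default: a per-user config dir, e.g. ~/.config/litellm/litellm.secrets.toml on Linux.
    default_path = os.path.join(appdirs.user_config_dir("litellm"), "litellm.secrets.toml")

    # After this commit, the env var takes precedence when set:
    #   LITELLM_CONFIG_PATH=/etc/litellm/secrets.toml
    user_config_path = os.getenv("LITELLM_CONFIG_PATH", default_path)
    print(user_config_path)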

@@ -18,7 +18,8 @@ except ImportError:
     import subprocess
     import sys
-    subprocess.check_call([sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"])
+    subprocess.check_call(
+        [sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"])
     import uvicorn
     import fastapi
     import tomli as tomllib
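
For context, the reformatted check_call sits inside a bootstrap that installs missing dependencies on first run; a minimal sketch of that try/except ImportError pattern (a single package shown for brevity):

    try:
        import uvicorn
    except ImportError:
        import subprocess
        import sys
        # Install into the current interpreter's environment, then retry the import.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "uvicorn"])
        import uvicorn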
@@ -26,9 +27,9 @@ except ImportError:
     import tomli_w
 try:
     from .llm import litellm_completion
 except ImportError as e:
     from llm import litellm_completion  # type: ignore
 import random
@@ -105,7 +106,7 @@ model_router = litellm.Router()
 config_filename = "litellm.secrets.toml"
 config_dir = os.getcwd()
 config_dir = appdirs.user_config_dir("litellm")
-user_config_path = os.path.join(config_dir, config_filename)
+user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
 log_file = 'api_log.json'
@@ -184,7 +185,7 @@ def save_params_to_config(data: dict):
 def load_config():
     try:
         global user_config, user_api_base, user_max_tokens, user_temperature, user_model, local_logging
         # As the .env file is typically much simpler in structure, we use load_dotenv here directly
         with open(user_config_path, "rb") as f:
@@ -199,9 +200,9 @@ def load_config():
         litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt",
                                                                     True)  # by default add function to prompt if unsupported by provider
         litellm.drop_params = user_config["general"].get("drop_params",
                                                          True)  # by default drop params if unsupported by provider
         litellm.model_fallbacks = user_config["general"].get("fallbacks",
                                                              None)  # fallback models in case initial completion call fails
         default_model = user_config["general"].get("default_model", None)  # route all requests to this model.
         local_logging = user_config["general"].get("local_logging", True)
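
The .get() calls above imply a [general] table in litellm.secrets.toml roughly like the following; this is a hedged sketch with illustrative values, not a documented schema -- only the key names and defaults come from the diff:

    [general]
    add_function_to_prompt = true    # default true: inline function specs if the provider lacks support
    drop_params = true               # default true: drop params a provider does not support
    fallbacks = ["gpt-3.5-turbo"]    # models to try if the initial completion call fails
    default_model = "gpt-4"          # route all requests to this model
    local_logging = true             # default true: write request logs to api_log.json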
@@ -215,10 +216,10 @@ def load_config():
         if user_model in user_config["model"]:
             model_config = user_config["model"][user_model]
         model_list = []
         for model in user_config["model"]:
             if "model_list" in user_config["model"][model]:
                 model_list.extend(user_config["model"][model]["model_list"])
         if len(model_list) > 0:
             model_router.set_model_list(model_list=model_list)
         print_verbose(f"user_config: {user_config}")
@@ -254,7 +255,7 @@ def load_config():
                 },
                 final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""),
             )
     except:
         pass
@@ -271,8 +272,8 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
     if api_base:  # model-specific param
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
         os.environ["AZURE_API_VERSION"] = api_version  # set this for azure - litellm can read this from the env
     if max_tokens:  # model-specific param
         user_max_tokens = max_tokens
         dynamic_config[user_model]["max_tokens"] = max_tokens
@@ -290,7 +291,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
     if max_budget:  # litellm-specific param
         litellm.max_budget = max_budget
         dynamic_config["general"]["max_budget"] = max_budget
     if debug:  # litellm-specific param
         litellm.set_verbose = True
     if save:
         save_params_to_config(dynamic_config)
@@ -300,6 +301,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
     user_telemetry = telemetry
     usage_telemetry(feature="local_proxy_server")
+
 def track_cost_callback(
     kwargs,  # kwargs to completion
     completion_response,  # response from completion
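
The truncated signature matches the (kwargs, completion_response, start_time, end_time) shape litellm uses for success callbacks; a minimal hedged sketch of such a callback body -- illustrative only, not the commit's code -- assuming litellm.completion_cost:

    import litellm

    def track_cost_callback(kwargs, completion_response, start_time=None, end_time=None):
        try:
            # completion_cost() estimates spend from the response's model and token usage.
            cost = litellm.completion_cost(completion_response=completion_response)
            print(f"model={kwargs.get('model')} cost=${cost:.6f}")
        except Exception:
            pass  # never let cost accounting break the request path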