Improve code formatting and allow configurable litellm config path via environment variable.

This commit is contained in:
coconut49 2023-10-18 13:30:53 +08:00
parent a9ebf1b6ab
commit 52fdfe5819
No known key found for this signature in database
2 changed files with 16 additions and 14 deletions

View file

@ -9,7 +9,7 @@ import operator
config_filename = "litellm.secrets.toml" config_filename = "litellm.secrets.toml"
# Using appdirs to determine user-specific config path # Using appdirs to determine user-specific config path
config_dir = appdirs.user_config_dir("litellm") config_dir = appdirs.user_config_dir("litellm")
user_config_path = os.path.join(config_dir, config_filename) user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
load_dotenv() load_dotenv()
from importlib import resources from importlib import resources

View file

@ -18,7 +18,8 @@ except ImportError:
import subprocess import subprocess
import sys import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"]) subprocess.check_call(
[sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"])
import uvicorn import uvicorn
import fastapi import fastapi
import tomli as tomllib import tomli as tomllib
@ -105,7 +106,7 @@ model_router = litellm.Router()
config_filename = "litellm.secrets.toml" config_filename = "litellm.secrets.toml"
config_dir = os.getcwd() config_dir = os.getcwd()
config_dir = appdirs.user_config_dir("litellm") config_dir = appdirs.user_config_dir("litellm")
user_config_path = os.path.join(config_dir, config_filename) user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
log_file = 'api_log.json' log_file = 'api_log.json'
@ -300,6 +301,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
user_telemetry = telemetry user_telemetry = telemetry
usage_telemetry(feature="local_proxy_server") usage_telemetry(feature="local_proxy_server")
def track_cost_callback( def track_cost_callback(
kwargs, # kwargs to completion kwargs, # kwargs to completion
completion_response, # response from completion completion_response, # response from completion