forked from phoenix/litellm-mirror
Improve code formatting and allow configurable litellm config path via environment variable.
This commit is contained in:
parent
a9ebf1b6ab
commit
52fdfe5819
2 changed files with 16 additions and 14 deletions
|
@ -9,7 +9,7 @@ import operator
|
|||
config_filename = "litellm.secrets.toml"
|
||||
# Using appdirs to determine user-specific config path
|
||||
config_dir = appdirs.user_config_dir("litellm")
|
||||
user_config_path = os.path.join(config_dir, config_filename)
|
||||
user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
|
||||
|
||||
load_dotenv()
|
||||
from importlib import resources
|
||||
|
|
|
@ -18,7 +18,8 @@ except ImportError:
|
|||
import subprocess
|
||||
import sys
|
||||
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"])
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"])
|
||||
import uvicorn
|
||||
import fastapi
|
||||
import tomli as tomllib
|
||||
|
@ -26,9 +27,9 @@ except ImportError:
|
|||
import tomli_w
|
||||
|
||||
try:
|
||||
from .llm import litellm_completion
|
||||
from .llm import litellm_completion
|
||||
except ImportError as e:
|
||||
from llm import litellm_completion # type: ignore
|
||||
from llm import litellm_completion # type: ignore
|
||||
|
||||
import random
|
||||
|
||||
|
@ -105,7 +106,7 @@ model_router = litellm.Router()
|
|||
config_filename = "litellm.secrets.toml"
|
||||
config_dir = os.getcwd()
|
||||
config_dir = appdirs.user_config_dir("litellm")
|
||||
user_config_path = os.path.join(config_dir, config_filename)
|
||||
user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
|
||||
log_file = 'api_log.json'
|
||||
|
||||
|
||||
|
@ -184,7 +185,7 @@ def save_params_to_config(data: dict):
|
|||
|
||||
|
||||
def load_config():
|
||||
try:
|
||||
try:
|
||||
global user_config, user_api_base, user_max_tokens, user_temperature, user_model, local_logging
|
||||
# As the .env file is typically much simpler in structure, we use load_dotenv here directly
|
||||
with open(user_config_path, "rb") as f:
|
||||
|
@ -199,9 +200,9 @@ def load_config():
|
|||
litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt",
|
||||
True) # by default add function to prompt if unsupported by provider
|
||||
litellm.drop_params = user_config["general"].get("drop_params",
|
||||
True) # by default drop params if unsupported by provider
|
||||
True) # by default drop params if unsupported by provider
|
||||
litellm.model_fallbacks = user_config["general"].get("fallbacks",
|
||||
None) # fallback models in case initial completion call fails
|
||||
None) # fallback models in case initial completion call fails
|
||||
default_model = user_config["general"].get("default_model", None) # route all requests to this model.
|
||||
|
||||
local_logging = user_config["general"].get("local_logging", True)
|
||||
|
@ -215,10 +216,10 @@ def load_config():
|
|||
if user_model in user_config["model"]:
|
||||
model_config = user_config["model"][user_model]
|
||||
model_list = []
|
||||
for model in user_config["model"]:
|
||||
for model in user_config["model"]:
|
||||
if "model_list" in user_config["model"][model]:
|
||||
model_list.extend(user_config["model"][model]["model_list"])
|
||||
if len(model_list) > 0:
|
||||
if len(model_list) > 0:
|
||||
model_router.set_model_list(model_list=model_list)
|
||||
|
||||
print_verbose(f"user_config: {user_config}")
|
||||
|
@ -254,7 +255,7 @@ def load_config():
|
|||
},
|
||||
final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""),
|
||||
)
|
||||
except:
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -271,8 +272,8 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
|
|||
if api_base: # model-specific param
|
||||
user_api_base = api_base
|
||||
dynamic_config[user_model]["api_base"] = api_base
|
||||
if api_version:
|
||||
os.environ["AZURE_API_VERSION"] = api_version # set this for azure - litellm can read this from the env
|
||||
if api_version:
|
||||
os.environ["AZURE_API_VERSION"] = api_version # set this for azure - litellm can read this from the env
|
||||
if max_tokens: # model-specific param
|
||||
user_max_tokens = max_tokens
|
||||
dynamic_config[user_model]["max_tokens"] = max_tokens
|
||||
|
@ -290,7 +291,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
|
|||
if max_budget: # litellm-specific param
|
||||
litellm.max_budget = max_budget
|
||||
dynamic_config["general"]["max_budget"] = max_budget
|
||||
if debug: # litellm-specific param
|
||||
if debug: # litellm-specific param
|
||||
litellm.set_verbose = True
|
||||
if save:
|
||||
save_params_to_config(dynamic_config)
|
||||
|
@ -300,6 +301,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
|
|||
user_telemetry = telemetry
|
||||
usage_telemetry(feature="local_proxy_server")
|
||||
|
||||
|
||||
def track_cost_callback(
|
||||
kwargs, # kwargs to completion
|
||||
completion_response, # response from completion
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue